diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index a81d535..531960c 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -621,8 +621,9 @@ class SAEHDModel(ModelBase): pred_src_srcm = nn.tf_concat(gpu_pred_src_srcm_list, 0) pred_dst_dstm = nn.tf_concat(gpu_pred_dst_dstm_list, 0) pred_src_dstm = nn.tf_concat(gpu_pred_src_dstm_list, 0) - src_loss = nn.tf_average_tensor_list(gpu_src_losses) - dst_loss = nn.tf_average_tensor_list(gpu_dst_losses) + + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.tf_average_gv_list (gpu_G_loss_gvs)) if self.options['true_face_power'] != 0: