bug fixing loss function

This commit is contained in:
Jan 2021-11-25 08:28:20 +01:00
commit 39e380bca8

View file

@ -524,7 +524,7 @@ class AMPModel(ModelBase):
if self.options['loss_function'] == 'MS-SSIM':
gpu_dst_loss = 10 * nn.MsSsim(bs_per_gpu, input_ch, resolution)(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0)
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_dst_masked - gpu_pred_dst_dst_masked ), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_dst_masked - gpu_pred_dst_dst_masked ), axis=[1,2,3])
elif self.options['loss_function'] == 'MS-SSIM+L1':
gpu_dst_loss = 10 * nn.MsSsim(bs_per_gpu, input_ch, resolution, use_l1=True)(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0)