diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 7748ea3..5108323 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -478,7 +478,7 @@ class SAEHDModel(ModelBase): if self.options['ms_ssim_loss']: # TODO - Done - src_loss = 10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt) + src_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt)) else: src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) ) src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) ) @@ -495,14 +495,14 @@ class SAEHDModel(ModelBase): if bg_style_power != 0: if self.options['ms_ssim_loss']: # TODO - Done - src_loss += MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked) + src_loss += K.mean(10 * bg_style_power * MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked)) else: src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked )) src_loss += K.mean( (10*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked )) if self.options['ms_ssim_loss']: # TODO - Done - dst_loss = 10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) + dst_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt)) else: dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) ) dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) ) @@ -533,8 +533,8 @@ class SAEHDModel(ModelBase): if self.options['learn_mask']: if self.options['ms_ssim_loss']: # TODO - Done - src_mask_loss = MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm) - dst_mask_loss = MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm) + src_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm)) + dst_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm)) else: src_mask_loss = K.mean(K.square(self.model.target_srcm-self.model.pred_src_srcm)) dst_mask_loss = K.mean(K.square(self.model.target_dstm-self.model.pred_dst_dstm))