diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index e2a9d01..5108323 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -478,7 +478,7 @@ class SAEHDModel(ModelBase): if self.options['ms_ssim_loss']: # TODO - Done - src_loss = K.mean(10 * MsSSIM(max_value=1.0, power_factors=(1.0,))(target_src_masked_opt, pred_src_src_masked_opt)) + src_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt)) else: src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) ) src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) ) @@ -495,14 +495,14 @@ class SAEHDModel(ModelBase): if bg_style_power != 0: if self.options['ms_ssim_loss']: # TODO - Done - src_loss += K.mean(10 * bg_style_power * MsSSIM(max_value=1.0, power_factors=(1.0,))(psd_target_dst_anti_masked, target_dst_anti_masked)) + src_loss += K.mean(10 * bg_style_power * MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked)) else: src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked )) src_loss += K.mean( (10*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked )) if self.options['ms_ssim_loss']: # TODO - Done - dst_loss = K.mean(10 * MsSSIM(max_value=1.0, power_factors=(1.0,))(target_dst_masked_opt, pred_dst_dst_masked_opt)) + dst_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt)) else: dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) ) dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) ) @@ -533,8 +533,8 @@ class SAEHDModel(ModelBase): if self.options['learn_mask']: if self.options['ms_ssim_loss']: # TODO - Done - src_mask_loss = K.mean(MsSSIM(max_value=1.0, power_factors=(1.0,))(self.model.target_srcm, self.model.pred_src_srcm)) - dst_mask_loss = K.mean(MsSSIM(max_value=1.0, power_factors=(1.0,))(self.model.target_dstm, self.model.pred_dst_dstm)) + src_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm)) + dst_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm)) else: src_mask_loss = K.mean(K.square(self.model.target_srcm-self.model.pred_src_srcm)) dst_mask_loss = K.mean(K.square(self.model.target_dstm-self.model.pred_dst_dstm)) diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index b86de76..5c16375 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -365,14 +365,15 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator k1=self.k1, k2=self.k2) loss = (1.0 - mssim_val) / 2.0 return loss - loss = 0.0 - # im_size = K.shape(y_pred)[-2] - for i, weight in enumerate(self.power_factors): - size = 2**i - dssim = self.dssim(K.pool2d(y_true, (size, size), strides=(size, size), pool_mode='avg'), - K.pool2d(y_pred, (size, size), strides=(size, size), pool_mode='avg')) - loss += dssim**weight - return loss/len(self.power_factors) + else: + loss = 0.0 + # im_size = K.shape(y_pred)[-2] + for i, weight in enumerate(self.power_factors): + size = 2**i + dssim = self.dssim(K.pool2d(y_true, (size, size), strides=(size, size), pool_mode='avg'), + K.pool2d(y_pred, (size, size), strides=(size, size), pool_mode='avg')) + loss += dssim**weight + return loss/len(self.power_factors) nnlib.MsSSIM = MsSSIM