This commit is contained in:
Jeremy Hummel 2019-10-19 23:00:19 -04:00
commit 372ff2c624

View file

@ -478,7 +478,7 @@ class SAEHDModel(ModelBase):
if self.options['ms_ssim_loss']:
# TODO - Done
src_loss = 10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt)
src_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt))
else:
src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
@ -495,14 +495,14 @@ class SAEHDModel(ModelBase):
if bg_style_power != 0:
if self.options['ms_ssim_loss']:
# TODO - Done
src_loss += MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked)
src_loss += K.mean(10 * bg_style_power * MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked))
else:
src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))
src_loss += K.mean( (10*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked ))
if self.options['ms_ssim_loss']:
# TODO - Done
dst_loss = 10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt)
dst_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt))
else:
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
@ -533,8 +533,8 @@ class SAEHDModel(ModelBase):
if self.options['learn_mask']:
if self.options['ms_ssim_loss']:
# TODO - Done
src_mask_loss = MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm)
dst_mask_loss = MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm)
src_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm))
dst_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm))
else:
src_mask_loss = K.mean(K.square(self.model.target_srcm-self.model.pred_src_srcm))
dst_mask_loss = K.mean(K.square(self.model.target_dstm-self.model.pred_dst_dstm))