mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 22:34:25 -07:00
Use mean
This commit is contained in:
parent
6ac21f16c5
commit
372ff2c624
1 changed files with 5 additions and 5 deletions
|
@ -478,7 +478,7 @@ class SAEHDModel(ModelBase):
|
||||||
|
|
||||||
if self.options['ms_ssim_loss']:
|
if self.options['ms_ssim_loss']:
|
||||||
# TODO - Done
|
# TODO - Done
|
||||||
src_loss = 10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt)
|
src_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_src_masked_opt, pred_src_src_masked_opt))
|
||||||
else:
|
else:
|
||||||
src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
|
src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
|
||||||
src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
|
src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
|
||||||
|
@ -495,14 +495,14 @@ class SAEHDModel(ModelBase):
|
||||||
if bg_style_power != 0:
|
if bg_style_power != 0:
|
||||||
if self.options['ms_ssim_loss']:
|
if self.options['ms_ssim_loss']:
|
||||||
# TODO - Done
|
# TODO - Done
|
||||||
src_loss += MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked)
|
src_loss += K.mean(10 * bg_style_power * MsSSIM(max_value=1.0)(psd_target_dst_anti_masked, target_dst_anti_masked))
|
||||||
else:
|
else:
|
||||||
src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))
|
src_loss += K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked, target_dst_anti_masked ))
|
||||||
src_loss += K.mean( (10*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked ))
|
src_loss += K.mean( (10*bg_style_power)*K.square( psd_target_dst_anti_masked - target_dst_anti_masked ))
|
||||||
|
|
||||||
if self.options['ms_ssim_loss']:
|
if self.options['ms_ssim_loss']:
|
||||||
# TODO - Done
|
# TODO - Done
|
||||||
dst_loss = 10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt)
|
dst_loss = K.mean(10 * MsSSIM(max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt))
|
||||||
else:
|
else:
|
||||||
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
|
dst_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_opt, pred_dst_dst_masked_opt) )
|
||||||
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
|
dst_loss += K.mean( 10*K.square( target_dst_masked_opt - pred_dst_dst_masked_opt ) )
|
||||||
|
@ -533,8 +533,8 @@ class SAEHDModel(ModelBase):
|
||||||
if self.options['learn_mask']:
|
if self.options['learn_mask']:
|
||||||
if self.options['ms_ssim_loss']:
|
if self.options['ms_ssim_loss']:
|
||||||
# TODO - Done
|
# TODO - Done
|
||||||
src_mask_loss = MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm)
|
src_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_srcm, self.model.pred_src_srcm))
|
||||||
dst_mask_loss = MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm)
|
dst_mask_loss = K.mean(MsSSIM(max_value=1.0)(self.model.target_dstm, self.model.pred_dst_dstm))
|
||||||
else:
|
else:
|
||||||
src_mask_loss = K.mean(K.square(self.model.target_srcm-self.model.pred_src_srcm))
|
src_mask_loss = K.mean(K.square(self.model.target_srcm-self.model.pred_src_srcm))
|
||||||
dst_mask_loss = K.mean(K.square(self.model.target_dstm-self.model.pred_dst_dstm))
|
dst_mask_loss = K.mean(K.square(self.model.target_dstm-self.model.pred_dst_dstm))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue