From d5fd69906d2eac8ba1f5359dfd14a1adf7b8a5b3 Mon Sep 17 00:00:00 2001 From: jh Date: Fri, 12 Mar 2021 08:17:04 -0800 Subject: [PATCH] feat: first test of ms_ssim --- models/Model_SAEHD/Model.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 65bca2d..2989d6e 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -425,12 +425,15 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur) - if resolution < 256: - gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + if self.options['ms_ssim_loss']: + gpu_src_loss = tf.reduce_mean ( 10*nn.ms_ssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, resolution), axis=[1]) else: - gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) - gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) - gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) + if resolution < 256: + gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) if eyes_prio or mouth_prio: if eyes_prio and mouth_prio: