diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index 6824488..c136009 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -308,15 +308,8 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03 nn.dssim = dssim -def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0, +def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)): - img_dtype = img1.dtype - if img_dtype != img2.dtype: - raise ValueError("img1.dtype != img2.dtype") - - if img_dtype != tf.float32: - img1 = tf.cast(img1, tf.float32) - img2 = tf.cast(img2, tf.float32) # restrict mssim factors to those greater/equal to kernel size power_factors = [power_factors[i] for i in range(len(power_factors)) if resolution//(2**i) >= kernel_size] @@ -325,15 +318,28 @@ def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value= if sum(power_factors) < 1.0: power_factors = [x/sum(power_factors) for x in power_factors] - # Transpose images from NCHW to NHWC - img1_t = tf.transpose(img1, [0, 2, 3, 1]) - img2_t = tf.transpose(img2, [0, 2, 3, 1]) - ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors, - filter_size=kernel_size, k1=k1, k2=k2) - loss = (1.0 - ms_ssim_val) / 2.0 + def loss(img1, img2): + img_dtype = img1.dtype + if img_dtype != img2.dtype: + raise ValueError("img1.dtype != img2.dtype") + + if img_dtype != tf.float32: + img1 = tf.cast(img1, tf.float32) + img2 = tf.cast(img2, tf.float32) + + + + # Transpose images from NCHW to NHWC + img1_t = tf.transpose(img1, [0, 2, 3, 1]) + img2_t = tf.transpose(img2, [0, 2, 3, 1]) + ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors, + filter_size=kernel_size, k1=k1, k2=k2) + ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0 + + if img_dtype != tf.float32: + ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype) + return ms_ssim_loss - if img_dtype != tf.float32: - loss = tf.cast(loss, img_dtype) return loss nn.ms_ssim = ms_ssim diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 2db427d..16d63da 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -426,7 +426,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur) if self.options['ms_ssim_loss']: - gpu_src_loss = tf.reduce_mean ( 10*nn.ms_ssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, resolution)) + gpu_src_loss = tf.reduce_mean ( 10*nn.ms_ssim(resolution)(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt)) else: if resolution < 256: gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])