From 9b98607b334136a42ecd7a7173647103dd806331 Mon Sep 17 00:00:00 2001
From: jh
Date: Fri, 12 Mar 2021 17:52:57 -0800
Subject: [PATCH] test: try as layer

---
Reviewer notes: MsSsim.__call__ now feeds the transposed (NHWC) tensors to
tf.image.ssim_multiscale, forwards kernel_size as filter_size, and defaults
max_val to 1.0 so the call in Model.py (which passes no max_val) works.
MsSsim.__init__ rejects a resolution smaller than kernel_size.
Blob ids on the index lines were not regenerated.

 core/leras/layers/MsSsim.py | 42 +++++++++++++++++++++
 core/leras/ops/__init__.py  | 76 ++++++++++++++++++-------------------
 models/Model_SAEHD/Model.py |  2 +-
 3 files changed, 81 insertions(+), 39 deletions(-)
 create mode 100644 core/leras/layers/MsSsim.py

diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py
new file mode 100644
index 0000000..d1ad64c
--- /dev/null
+++ b/core/leras/layers/MsSsim.py
@@ -0,0 +1,42 @@
+from core.leras import nn
+tf = nn.tf
+
+class MsSsim(nn.LayerBase):
+    """
+    Multi-scale SSIM dissimilarity built on tf.image.ssim_multiscale.
+
+    resolution   image side length, used to pick how many scales are evaluated
+    kernel_size  side length of the SSIM filter window
+    """
+    default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
+
+    def __init__(self, resolution, kernel_size=11, **kwargs):
+        # restrict mssim factors to the scales whose image side is greater/equal to kernel size
+        power_factors = [p for i, p in enumerate(self.default_power_factors) if resolution//(2**i) >= kernel_size]
+        if not power_factors:
+            raise ValueError('resolution (%s) must be at least kernel_size (%s)' % (resolution, kernel_size))
+        # normalize power factors if reduced because of size
+        total = sum(power_factors)
+        if total < 1.0:
+            power_factors = [x/total for x in power_factors]
+        self.power_factors = power_factors
+        self.kernel_size = kernel_size
+
+        super().__init__(**kwargs)
+
+    def __call__(self, y_true, y_pred, max_val=1.0):
+        """
+        Returns (1 - MS-SSIM) / 2 of the two image batches.
+
+        y_true, y_pred  images in NCHW layout
+        max_val         dynamic range of the pixel values, 1.0 by default
+        """
+        # Transpose images from NCHW to NHWC
+        y_true_t = tf.transpose(y_true, [0, 2, 3, 1])
+        y_pred_t = tf.transpose(y_pred, [0, 2, 3, 1])
+        # pass the transposed tensors and the filter size the power factors were selected for
+        loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors,
+                                        filter_size=self.kernel_size)
+        return (1.0 - loss) / 2.0
+
+nn.MsSsim = MsSsim
diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py
index b460f17..f34461f 100644
--- a/core/leras/ops/__init__.py
+++ b/core/leras/ops/__init__.py
@@ -308,44 +308,44 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
 
 nn.dssim = dssim
 
-def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
-            power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
-
-    # restrict mssim factors to those greater/equal to kernel size
-    power_factors = [power_factors[i] for i in range(len(power_factors)) if resolution//(2**i) >= kernel_size]
-
-    # normalize power factors if reduced because of size
-    if sum(power_factors) < 1.0:
-        power_factors = [x/sum(power_factors) for x in power_factors]
-
-    img_dtype = img1.dtype
-    if img_dtype != img2.dtype:
-        raise ValueError("img1.dtype != img2.dtype")
-
-    if img_dtype != tf.float32:
-        img1 = tf.cast(img1, tf.float32)
-        img2 = tf.cast(img2, tf.float32)
-
-    # Transpose images from NCHW to NHWC
-    img1_t = tf.transpose(img1, [0, 2, 3, 1])
-    img2_t = tf.transpose(img2, [0, 2, 3, 1])
-
-    def assign_device(op):
-        if op.type != 'ListDiff':
-            return '/gpu:0'
-        else:
-            return '/cpu:0'
-
-    with tf.device(assign_device):
-        ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
-                                               filter_size=kernel_size, k1=k1, k2=k2)
-    ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0
-
-    if img_dtype != tf.float32:
-        ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype)
-    return ms_ssim_loss
-
-nn.ms_ssim = ms_ssim
+# def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
+#             power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
+#
+#     # restrict mssim factors to those greater/equal to kernel size
+#     power_factors = [power_factors[i] for i in range(len(power_factors)) if resolution//(2**i) >= kernel_size]
+#
+#     # normalize power factors if reduced because of size
+#     if sum(power_factors) < 1.0:
+#         power_factors = [x/sum(power_factors) for x in power_factors]
+#
+#     img_dtype = img1.dtype
+#     if img_dtype != img2.dtype:
+#         raise ValueError("img1.dtype != img2.dtype")
+#
+#     if img_dtype != tf.float32:
+#         img1 = tf.cast(img1, tf.float32)
+#         img2 = tf.cast(img2, tf.float32)
+#
+#     # Transpose images from NCHW to NHWC
+#     img1_t = tf.transpose(img1, [0, 2, 3, 1])
+#     img2_t = tf.transpose(img2, [0, 2, 3, 1])
+#
+#     def assign_device(op):
+#         if op.type != 'ListDiff':
+#             return '/gpu:0'
+#         else:
+#             return '/cpu:0'
+#
+#     with tf.device(assign_device):
+#         ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
+#                                                filter_size=kernel_size, k1=k1, k2=k2)
+#     ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0
+#
+#     if img_dtype != tf.float32:
+#         ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype)
+#     return ms_ssim_loss
+#
+# nn.ms_ssim = ms_ssim
 
 
 def space_to_depth(x, size):
diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index 23065ec..c57ec6e 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -426,7 +426,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                     gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
 
                     if self.options['ms_ssim_loss']:
-                        gpu_src_loss = tf.reduce_mean ( 10*nn.ms_ssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, resolution))
+                        gpu_src_loss = tf.reduce_mean ( 10*nn.MsSsim(resolution)(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt))
                     else:
                         if resolution < 256:
                             gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])