def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
            power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
    """
    Multi-scale structural dissimilarity loss: (1 - MS-SSIM(img1, img2)) / 2.

    Thin wrapper around tf.image.ssim_multiscale that drops the scales at which
    the image would be smaller than the SSIM filter window.

    arguments

     img1, img2     image tensors of the same dtype; cast to float32 for the
                    computation when they are not float32 already

     resolution     side length of the images in pixels; scale i is treated
                    as having side resolution // 2**i

     kernel_size    SSIM filter size, also the minimum side a scale must have
                    to be kept

     k1, k2         SSIM constants, passed through to tf.image.ssim_multiscale

     max_value      dynamic range of the images, passed through as max_value

     power_factors  per-scale weights, finest scale first

    returns the loss tensor, cast back to the dtype of the inputs

    raises ValueError if the input dtypes differ, or if no scale is large
    enough for kernel_size (resolution < kernel_size or empty power_factors)

    NOTE(review): tf.image.ssim_multiscale documents channels-last input; the
    images are forwarded here without any transpose - confirm callers pass
    NHWC tensors when nn.data_format is "NCHW".
    """
    img_dtype = img1.dtype
    if img_dtype != img2.dtype:
        raise ValueError("img1.dtype != img2.dtype")

    # Keep only the scales whose (halved i times) side still fits the filter window.
    scale_weights = [weight for i, weight in enumerate(power_factors)
                     if resolution // (2 ** i) >= kernel_size]
    if not scale_weights:
        # Fail early with a readable message instead of handing an empty
        # weight list to tf.image.ssim_multiscale.
        raise ValueError(f"ms_ssim: resolution {resolution} leaves no scale >= "
                         f"kernel_size {kernel_size}")

    # Re-normalize the remaining weights if dropping scales left them summing below 1.
    weight_sum = sum(scale_weights)
    if weight_sum < 1.0:
        scale_weights = [weight / weight_sum for weight in scale_weights]

    needs_cast = img_dtype != tf.float32
    if needs_cast:
        img1 = tf.cast(img1, tf.float32)
        img2 = tf.cast(img2, tf.float32)

    ms_ssim_val = tf.image.ssim_multiscale(img1, img2, max_value=max_value,
                                           power_factors=scale_weights,
                                           filter_size=kernel_size, k1=k1, k2=k2)
    loss = (1.0 - ms_ssim_val) / 2.0

    if needs_cast:
        loss = tf.cast(loss, img_dtype)
    return loss

# Registered under its own name: assigning to nn.dssim would replace the
# single-scale dssim registered above, whose signature (max_val, filter_size, ...)
# is not compatible with ms_ssim's.
nn.ms_ssim = ms_ssim
tf_suppress_lower_mean(t, eps=0.00001): q = tf.clip_by_value(q-t_mean_eps, 0, eps) q = q * (t/eps) return q -""" \ No newline at end of file +"""