feat: ms_ssim implementation

This commit is contained in:
jh 2021-03-12 07:57:09 -08:00
commit f06ff1eae8

View file

@ -307,6 +307,35 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
nn.dssim = dssim
def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
img_dtype = img1.dtype
if img_dtype != img2.dtype:
raise ValueError("img1.dtype != img2.dtype")
if img_dtype != tf.float32:
img1 = tf.cast(img1, tf.float32)
img2 = tf.cast(img2, tf.float32)
# restrict mssim factors to those greater/equal to kernel size
power_factors = [power_factors[i] for i in range(len(power_factors)) if resolution//(2**i) >= kernel_size]
# normalize power factors if reduced because of size
if sum(power_factors) < 1.0:
power_factors = [x/sum(power_factors) for x in power_factors]
ms_ssim_val = tf.image.ssim_multiscale(img1, img2, max_value=max_value, power_factors=power_factors,
filter_size=kernel_size, k1=k1, k2=k2)
loss = (1.0 - ms_ssim_val) / 2.0
if img_dtype != tf.float32:
loss = tf.cast(loss, img_dtype)
return loss
nn.dssim = ms_ssim
def space_to_depth(x, size):
if nn.data_format == "NHWC":
# match NCHW version in order to switch data_format without problems
@ -385,7 +414,7 @@ def total_variation_mse(images):
"""
pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]
pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]
tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) +
tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) )
return tot_var
@ -400,4 +429,4 @@ def tf_suppress_lower_mean(t, eps=0.00001):
q = tf.clip_by_value(q-t_mean_eps, 0, eps)
q = q * (t/eps)
return q
"""
"""