fix: more testing

This commit is contained in:
jh 2021-03-12 11:28:01 -08:00
commit 28279971a1

View file

@ -308,7 +308,7 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
nn.dssim = dssim
def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
# restrict mssim factors to those greater/equal to kernel size
@ -318,7 +318,6 @@ def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
if sum(power_factors) < 1.0:
power_factors = [x/sum(power_factors) for x in power_factors]
def loss(img1, img2):
img_dtype = img1.dtype
if img_dtype != img2.dtype:
raise ValueError("img1.dtype != img2.dtype")
@ -327,8 +326,6 @@ def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
img1 = tf.cast(img1, tf.float32)
img2 = tf.cast(img2, tf.float32)
# Transpose images from NCHW to NHWC
img1_t = tf.transpose(img1, [0, 2, 3, 1])
img2_t = tf.transpose(img2, [0, 2, 3, 1])
@ -340,8 +337,6 @@ def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype)
return ms_ssim_loss
return loss
nn.ms_ssim = ms_ssim