From 28279971a17a5300ced4e129b8bf953063007b44 Mon Sep 17 00:00:00 2001 From: jh Date: Fri, 12 Mar 2021 11:28:01 -0800 Subject: [PATCH] fix: more testing --- core/leras/ops/__init__.py | 37 ++++++++++++++++--------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index c136009..583a67d 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -308,7 +308,7 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03 nn.dssim = dssim -def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0, +def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0, power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)): # restrict mssim factors to those greater/equal to kernel size @@ -318,29 +318,24 @@ def ms_ssim(resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0, if sum(power_factors) < 1.0: power_factors = [x/sum(power_factors) for x in power_factors] - def loss(img1, img2): - img_dtype = img1.dtype - if img_dtype != img2.dtype: - raise ValueError("img1.dtype != img2.dtype") + img_dtype = img1.dtype + if img_dtype != img2.dtype: + raise ValueError("img1.dtype != img2.dtype") - if img_dtype != tf.float32: - img1 = tf.cast(img1, tf.float32) - img2 = tf.cast(img2, tf.float32) + if img_dtype != tf.float32: + img1 = tf.cast(img1, tf.float32) + img2 = tf.cast(img2, tf.float32) + # Transpose images from NCHW to NHWC + img1_t = tf.transpose(img1, [0, 2, 3, 1]) + img2_t = tf.transpose(img2, [0, 2, 3, 1]) + ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors, + filter_size=kernel_size, k1=k1, k2=k2) + ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0 - - # Transpose images from NCHW to NHWC - img1_t = tf.transpose(img1, [0, 2, 3, 1]) - img2_t = tf.transpose(img2, [0, 2, 3, 1]) - ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors, - filter_size=kernel_size, k1=k1, k2=k2) - ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0 - - if img_dtype != tf.float32: - ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype) - return ms_ssim_loss - - return loss + if img_dtype != tf.float32: + ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype) + return ms_ssim_loss nn.ms_ssim = ms_ssim