From 8314e1b01a0df115572a990c58ae8ea7a038de49 Mon Sep 17 00:00:00 2001 From: jh Date: Sat, 24 Apr 2021 00:40:27 -0700 Subject: [PATCH] Adds functions to generate gaussian weights, and mix MS-SSIM with L1 loss --- core/leras/layers/MsSsim.py | 30 +++++++++++++++++++++++------- core/leras/ops/__init__.py | 17 +++++++++++++++-- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py index 03b9a14..1653243 100644 --- a/core/leras/layers/MsSsim.py +++ b/core/leras/layers/MsSsim.py @@ -4,15 +4,20 @@ tf = nn.tf class MsSsim(nn.LayerBase): default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) + default_l1_alpha = 0.84 - def __init__(self, resolution, kernel_size=11, **kwargs): + def __init__(self, batch_size, in_ch, resolution, kernel_size=11, use_l1=False, **kwargs): # restrict mssim factors to those greater/equal to kernel size power_factors = [p for i, p in enumerate(self.default_power_factors) if resolution//(2**i) >= kernel_size] # normalize power factors if reduced because of size if sum(power_factors) < 1.0: power_factors = [x/sum(power_factors) for x in power_factors] self.power_factors = power_factors + self.num_scales = len(power_factors) self.kernel_size = kernel_size + self.use_l1 = use_l1 + if use_l1: + self.gaussian_weights = nn.get_gaussian_weights(batch_size, in_ch, resolution, num_scales=self.num_scales) super().__init__(**kwargs) @@ -21,14 +26,25 @@ class MsSsim(nn.LayerBase): y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1]) y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1]) - if tf.__version__ >= "1.14": - ms_ssim_val = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size) - else: - ms_ssim_val = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) - # ssim_multiscale returns values in range [0, 1] (where 1 is completely identical) # subtract from 1 to get loss - return 1.0 - ms_ssim_val + if tf.__version__ >= "1.14": + ms_ssim_loss = 1.0 - tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size) + else: + ms_ssim_loss = 1.0 - tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors) + + # If use L1 is enabled, use mix of ms-ssim and L1 (weighted by gaussian filters) + # H. Zhao, O. Gallo, I. Frosio and J. Kautz, "Loss Functions for Image Restoration With Neural Networks," + # in IEEE Transactions on Computational Imaging, vol. 3, no. 1, pp. 47-57, March 2017, + # doi: 10.1109/TCI.2016.2644865. + # https://research.nvidia.com/publication/loss-functions-image-restoration-neural-networks + + if self.use_l1: + diff = tf.tile(tf.expand_dims(tf.abs(y_true - y_pred), axis=0), multiples=[self.num_scales, 1, 1, 1, 1]) + l1_loss = tf.reduce_mean(tf.reduce_sum(self.gaussian_weights[-1, :, :, :, :] * diff, axis=[0, 3, 4]), axis=[1]) + return self.default_l1_alpha * ms_ssim_loss + (1 - self.default_l1_alpha) * l1_loss + + return ms_ssim_loss nn.MsSsim = MsSsim diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index 500a22a..6e6310b 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -237,6 +237,19 @@ def gaussian_blur(input, radius=2.0): return x nn.gaussian_blur = gaussian_blur +def get_gaussian_weights(batch_size, in_ch, resolution, num_scale=5, sigma=(0.5, 1., 2., 4., 8.)): + w = np.empty((num_scale, batch_size, in_ch, resolution, resolution)) + for i in range(num_scale): + gaussian = np.exp(-1.*np.arange(-(resolution/2-0.5), resolution/2+0.5)**2/(2*sigma[i]**2)) + gaussian = np.outer(gaussian, gaussian.reshape((resolution, 1))) # extend to 2D + gaussian = gaussian/np.sum(gaussian) # normalization + gaussian = np.reshape(gaussian, (1, 1, resolution, resolution)) # reshape to 3D + gaussian = np.tile(gaussian, (batch_size, in_ch, 1, 1)) + w[i, :, :, :, :] = gaussian + return w + +nn.get_gaussian_weights = get_gaussian_weights + def style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1): def sd(content, style, loss_weight): content_nc = content.shape[ nn.conv2d_ch_axis ] @@ -385,7 +398,7 @@ def total_variation_mse(images): """ pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :] pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :] - + tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) + tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) ) return tot_var @@ -400,4 +413,4 @@ def tf_suppress_lower_mean(t, eps=0.00001): q = tf.clip_by_value(q-t_mean_eps, 0, eps) q = q * (t/eps) return q -""" \ No newline at end of file +"""