Adds functions to generate Gaussian weights and to mix MS-SSIM with L1 loss

jh 2021-04-24 00:40:27 -07:00
commit 8314e1b01a
2 changed files with 38 additions and 9 deletions


@@ -4,15 +4,20 @@ tf = nn.tf
 class MsSsim(nn.LayerBase):
     default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
+    default_l1_alpha = 0.84

-    def __init__(self, resolution, kernel_size=11, **kwargs):
+    def __init__(self, batch_size, in_ch, resolution, kernel_size=11, use_l1=False, **kwargs):
         # restrict mssim factors to those greater/equal to kernel size
         power_factors = [p for i, p in enumerate(self.default_power_factors) if resolution//(2**i) >= kernel_size]
         # normalize power factors if reduced because of size
         if sum(power_factors) < 1.0:
             power_factors = [x/sum(power_factors) for x in power_factors]
         self.power_factors = power_factors
+        self.num_scales = len(power_factors)
         self.kernel_size = kernel_size
+        self.use_l1 = use_l1
+        if use_l1:
+            self.gaussian_weights = nn.get_gaussian_weights(batch_size, in_ch, resolution, num_scale=self.num_scales)

         super().__init__(**kwargs)
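For reference, the scale-restriction logic in the new __init__ can be exercised on its own. The snippet below is an illustrative standalone sketch (the resolutions passed to it are made up, not from the commit): at resolution 64 only the first three scales (64, 32, 16 pixels) still fit an 11-pixel SSIM window, so three power factors survive and are renormalized to sum to 1.

# Standalone sketch of the power-factor restriction above (illustrative, not part of the diff).
default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

def restrict_power_factors(resolution, kernel_size=11):
    # keep only the scales whose downsampled resolution still fits the SSIM filter kernel
    factors = [p for i, p in enumerate(default_power_factors) if resolution // (2 ** i) >= kernel_size]
    # renormalize if any scales were dropped
    if sum(factors) < 1.0:
        factors = [x / sum(factors) for x in factors]
    return factors

print(restrict_power_factors(64))    # 3 factors kept (scales 64, 32, 16), rescaled to sum to 1
print(restrict_power_factors(256))   # all 5 default factors kept unchanged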
@@ -21,14 +26,25 @@ class MsSsim(nn.LayerBase):
         y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1])
         y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1])

-        if tf.__version__ >= "1.14":
-            ms_ssim_val = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size)
-        else:
-            ms_ssim_val = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors)
-        # ssim_multiscale returns values in range [0, 1] (where 1 is completely identical)
-        # subtract from 1 to get loss
-        return 1.0 - ms_ssim_val
+        if tf.__version__ >= "1.14":
+            ms_ssim_loss = 1.0 - tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size)
+        else:
+            ms_ssim_loss = 1.0 - tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors)
+
+        # If use_l1 is enabled, use a mix of MS-SSIM and L1 (weighted by Gaussian filters):
+        # H. Zhao, O. Gallo, I. Frosio and J. Kautz, "Loss Functions for Image Restoration With Neural Networks,"
+        # IEEE Transactions on Computational Imaging, vol. 3, no. 1, pp. 47-57, March 2017,
+        # doi: 10.1109/TCI.2016.2644865.
+        # https://research.nvidia.com/publication/loss-functions-image-restoration-neural-networks
+        if self.use_l1:
+            diff = tf.tile(tf.expand_dims(tf.abs(y_true - y_pred), axis=0), multiples=[self.num_scales, 1, 1, 1, 1])
+            l1_loss = tf.reduce_mean(tf.reduce_sum(self.gaussian_weights[-1, :, :, :, :] * diff, axis=[0, 3, 4]), axis=[1])
+            return self.default_l1_alpha * ms_ssim_loss + (1 - self.default_l1_alpha) * l1_loss
+
+        return ms_ssim_loss

 nn.MsSsim = MsSsim
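A minimal usage sketch of the updated layer follows. It assumes DeepFaceLab's leras nn module has been initialized, that inputs are float32 NCHW tensors in [0, 1], and that the layer is called as ms_ssim(y_true, y_pred, max_val); the call signature is not shown in the hunk above, so that part is an assumption, and the batch size, channel count and resolution are illustrative.

# Hypothetical usage sketch, not part of the commit.
ms_ssim = nn.MsSsim(batch_size=8, in_ch=3, resolution=128, use_l1=True)

# y_true / y_pred: float32 NCHW tensors in [0, 1], shape (8, 3, 128, 128)
loss = ms_ssim(y_true, y_pred, 1.0)   # assumed call signature; per-sample loss, shape (8,)

# With use_l1=True the result follows the mix proposed by Zhao et al. (2017):
#   loss = 0.84 * (1 - MS-SSIM) + (1 - 0.84) * Gaussian-weighted L1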


@@ -237,6 +237,19 @@ def gaussian_blur(input, radius=2.0):
     return x
 nn.gaussian_blur = gaussian_blur

+def get_gaussian_weights(batch_size, in_ch, resolution, num_scale=5, sigma=(0.5, 1., 2., 4., 8.)):
+    w = np.empty((num_scale, batch_size, in_ch, resolution, resolution))
+    for i in range(num_scale):
+        gaussian = np.exp(-1.*np.arange(-(resolution/2-0.5), resolution/2+0.5)**2/(2*sigma[i]**2))
+        gaussian = np.outer(gaussian, gaussian.reshape((resolution, 1)))  # extend to 2D
+        gaussian = gaussian/np.sum(gaussian)                              # normalize
+        gaussian = np.reshape(gaussian, (1, 1, resolution, resolution))   # reshape to 4D (1, 1, H, W)
+        gaussian = np.tile(gaussian, (batch_size, in_ch, 1, 1))           # repeat per batch item and channel
+        w[i, :, :, :, :] = gaussian
+    return w
+nn.get_gaussian_weights = get_gaussian_weights
+
 def style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):
     def sd(content, style, loss_weight):
         content_nc = content.shape[ nn.conv2d_ch_axis ]
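For context, the new get_gaussian_weights helper only needs NumPy and can be sanity-checked in isolation. A minimal sketch follows (the shape values are illustrative, and it assumes the function from the hunk above is in scope):

import numpy as np

# Illustrative check of the helper added above (not part of the commit).
w = get_gaussian_weights(batch_size=4, in_ch=3, resolution=64, num_scale=5)
print(w.shape)                                 # (5, 4, 3, 64, 64): (scale, batch, channel, H, W)
print(np.allclose(w.sum(axis=(3, 4)), 1.0))    # each 64x64 Gaussian kernel is normalized to sum to 1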