fix num_scale arg

This commit is contained in:
jh 2021-04-24 01:06:50 -07:00
commit 55f5fbd199

View file

@ -13,11 +13,11 @@ class MsSsim(nn.LayerBase):
if sum(power_factors) < 1.0: if sum(power_factors) < 1.0:
power_factors = [x/sum(power_factors) for x in power_factors] power_factors = [x/sum(power_factors) for x in power_factors]
self.power_factors = power_factors self.power_factors = power_factors
self.num_scales = len(power_factors) self.num_scale = len(power_factors)
self.kernel_size = kernel_size self.kernel_size = kernel_size
self.use_l1 = use_l1 self.use_l1 = use_l1
if use_l1: if use_l1:
self.gaussian_weights = nn.get_gaussian_weights(batch_size, in_ch, resolution, num_scales=self.num_scales) self.gaussian_weights = nn.get_gaussian_weights(batch_size, in_ch, resolution, num_scale=self.num_scale)
super().__init__(**kwargs) super().__init__(**kwargs)
@ -40,7 +40,7 @@ class MsSsim(nn.LayerBase):
# https://research.nvidia.com/publication/loss-functions-image-restoration-neural-networks # https://research.nvidia.com/publication/loss-functions-image-restoration-neural-networks
if self.use_l1: if self.use_l1:
diff = tf.tile(tf.expand_dims(tf.abs(y_true - y_pred), axis=0), multiples=[self.num_scales, 1, 1, 1, 1]) diff = tf.tile(tf.expand_dims(tf.abs(y_true - y_pred), axis=0), multiples=[self.num_scale, 1, 1, 1, 1])
l1_loss = tf.reduce_mean(tf.reduce_sum(self.gaussian_weights[-1, :, :, :, :] * diff, axis=[0, 3, 4]), axis=[1]) l1_loss = tf.reduce_mean(tf.reduce_sum(self.gaussian_weights[-1, :, :, :, :] * diff, axis=[0, 3, 4]), axis=[1])
return self.default_l1_alpha * ms_ssim_loss + (1 - self.default_l1_alpha) * l1_loss return self.default_l1_alpha * ms_ssim_loss + (1 - self.default_l1_alpha) * l1_loss