diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index 648407c..500a22a 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -140,7 +140,7 @@ nn.resize2d_bilinear = resize2d_bilinear def resize2d_nearest(x, size=2): if size in [-1,0,1]: return x - + if size > 0: raise Exception("") @@ -150,7 +150,7 @@ def resize2d_nearest(x, size=2): else: x = x[:,::-size,::-size,:] return x - + h = x.shape[nn.conv2d_spatial_axes[0]].value w = x.shape[nn.conv2d_spatial_axes[1]].value @@ -268,9 +268,9 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03 img_dtype = img1.dtype img1 = tf.cast(img1, tf.float32) img2 = tf.cast(img2, tf.float32) - + filter_size = max(1, filter_size) - + kernel = np.arange(0, filter_size, dtype=np.float32) kernel -= (filter_size - 1 ) / 2.0 kernel = kernel**2 @@ -342,10 +342,13 @@ def depth_to_space(x, size): x = tf.transpose(x, (0, 3, 4, 1, 5, 2)) x = tf.reshape(x, (-1, oc, oh, ow)) return x - - nn.depth_to_space = depth_to_space +def pixel_norm(x, power = 1.0): + return x * power * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=nn.conv2d_spatial_axes, keepdims=True) + 1e-06) +nn.pixel_norm = pixel_norm + + def rgb_to_lab(srgb): srgb_pixels = tf.reshape(srgb, [-1, 3]) linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32) @@ -376,6 +379,18 @@ def rgb_to_lab(srgb): return tf.reshape(lab_pixels, tf.shape(srgb)) nn.rgb_to_lab = rgb_to_lab +def total_variation_mse(images): + """ + Same as generic total_variation, but MSE diff instead of MAE + """ + pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :] + pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :] + + tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) + + tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) ) + return tot_var +nn.total_variation_mse = total_variation_mse + """ def tf_suppress_lower_mean(t, eps=0.00001): if t.shape.ndims != 1: