leras ops: pixelnorm, total_variation_mse

This commit is contained in:
iperov 2020-12-30 14:32:07 +04:00
commit 241d1a9c35

View file

@ -342,10 +342,13 @@ def depth_to_space(x, size):
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
x = tf.reshape(x, (-1, oc, oh, ow))
return x
nn.depth_to_space = depth_to_space
def pixel_norm(x, power = 1.0):
return x * power * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=nn.conv2d_spatial_axes, keepdims=True) + 1e-06)
nn.pixel_norm = pixel_norm
def rgb_to_lab(srgb):
srgb_pixels = tf.reshape(srgb, [-1, 3])
linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
@ -376,6 +379,18 @@ def rgb_to_lab(srgb):
return tf.reshape(lab_pixels, tf.shape(srgb))
nn.rgb_to_lab = rgb_to_lab
def total_variation_mse(images):
"""
Same as generic total_variation, but MSE diff instead of MAE
"""
pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :]
pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :]
tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) +
tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) )
return tot_var
nn.total_variation_mse = total_variation_mse
"""
def tf_suppress_lower_mean(t, eps=0.00001):
if t.shape.ndims != 1: