pixel_norm op

This commit is contained in:
iperov 2021-05-25 14:26:48 +04:00
parent 757283d10e
commit e6e2ee7466

View file

@ -204,7 +204,7 @@ def random_binomial(shape, p=0.0, dtype=None, seed=None):
seed = np.random.randint(10e6) seed = np.random.randint(10e6)
return array_ops.where( return array_ops.where(
random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p, random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p,
array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype))
nn.random_binomial = random_binomial nn.random_binomial = random_binomial
def gaussian_blur(input, radius=2.0): def gaussian_blur(input, radius=2.0):
@ -391,6 +391,11 @@ def total_variation_mse(images):
return tot_var return tot_var
nn.total_variation_mse = total_variation_mse nn.total_variation_mse = total_variation_mse
def pixel_norm(x, axes):
return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) + 1e-06)
nn.pixel_norm = pixel_norm
""" """
def tf_suppress_lower_mean(t, eps=0.00001): def tf_suppress_lower_mean(t, eps=0.00001):
if t.shape.ndims != 1: if t.shape.ndims != 1: