update leras

This commit is contained in:
Colombo 2020-02-26 13:32:32 +04:00
parent a1cc297cef
commit f07bbd5fb0

View file

@ -117,9 +117,27 @@ def initialize_tensor_ops(nn):
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
nn.tf_upsample2d = tf_upsample2d
def tf_upsample2d_bilinear(x, size=2):
return tf.image.resize_images(x, (x.shape[1]*size, x.shape[2]*size) )
nn.tf_upsample2d_bilinear = tf_upsample2d_bilinear
def tf_resize2d_bilinear(x, size=2):
h = x.shape[nn.conv2d_spatial_axes[0]].value
w = x.shape[nn.conv2d_spatial_axes[1]].value
if nn.data_format == "NCHW":
x = tf.transpose(x, (0,2,3,1))
if size > 0:
new_size = (h*size,w*size)
else:
new_size = (h//-size,w//-size)
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR)
if nn.data_format == "NCHW":
x = tf.transpose(x, (0,3,1,2))
return x
nn.tf_resize2d_bilinear = tf_resize2d_bilinear
def tf_flatten(x):
if nn.data_format == "NHWC":
@ -181,16 +199,11 @@ def initialize_tensor_ops(nn):
padding = None
gauss_kernel = gauss_kernel[:,:,None,None]
outputs = []
for i in range(input.shape[nn.conv2d_ch_axis]):
x = input[:,:,:,i:i+1] if nn.data_format == "NHWC" \
else input[:,i:i+1,:,:]
if padding is not None:
x = input
k = tf.tile (gauss_kernel, (1,1,x.shape[nn.conv2d_ch_axis],1) )
x = tf.pad(x, padding )
outputs += [ tf.nn.conv2d(x, tf.constant(gauss_kernel, dtype=input.dtype ), strides=[1,1,1,1], padding="VALID", data_format=nn.data_format) ]
return tf.concat (outputs, axis=nn.conv2d_ch_axis)
x = tf.nn.depthwise_conv2d(x, k, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format)
return x
nn.tf_gaussian_blur = tf_gaussian_blur
def tf_style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):