From f07bbd5fb0b48174cc36ba19652f0f08e3f4b22c Mon Sep 17 00:00:00 2001 From: Colombo Date: Wed, 26 Feb 2020 13:32:32 +0400 Subject: [PATCH] update leras --- core/leras/tensor_ops.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/core/leras/tensor_ops.py b/core/leras/tensor_ops.py index 9b0f177..604ebdf 100644 --- a/core/leras/tensor_ops.py +++ b/core/leras/tensor_ops.py @@ -117,9 +117,27 @@ def initialize_tensor_ops(nn): return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) nn.tf_upsample2d = tf_upsample2d - def tf_upsample2d_bilinear(x, size=2): - return tf.image.resize_images(x, (x.shape[1]*size, x.shape[2]*size) ) - nn.tf_upsample2d_bilinear = tf_upsample2d_bilinear + def tf_resize2d_bilinear(x, size=2): + h = x.shape[nn.conv2d_spatial_axes[0]].value + w = x.shape[nn.conv2d_spatial_axes[1]].value + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,2,3,1)) + + if size > 0: + new_size = (h*size,w*size) + else: + new_size = (h//-size,w//-size) + + x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR) + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,3,1,2)) + + return x + nn.tf_resize2d_bilinear = tf_resize2d_bilinear + + def tf_flatten(x): if nn.data_format == "NHWC": @@ -181,16 +199,11 @@ def initialize_tensor_ops(nn): padding = None gauss_kernel = gauss_kernel[:,:,None,None] - outputs = [] - for i in range(input.shape[nn.conv2d_ch_axis]): - x = input[:,:,:,i:i+1] if nn.data_format == "NHWC" \ - else input[:,i:i+1,:,:] - - if padding is not None: - x = tf.pad (x, padding) - outputs += [ tf.nn.conv2d(x, tf.constant(gauss_kernel, dtype=input.dtype ), strides=[1,1,1,1], padding="VALID", data_format=nn.data_format) ] - - return tf.concat (outputs, axis=nn.conv2d_ch_axis) + x = input + k = tf.tile (gauss_kernel, (1,1,x.shape[nn.conv2d_ch_axis],1) ) + x = tf.pad(x, padding ) + x = tf.nn.depthwise_conv2d(x, k, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format) + return x nn.tf_gaussian_blur = tf_gaussian_blur def tf_style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1):