diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py index 7cc30ac..ab42fa1 100644 --- a/core/leras/ops/__init__.py +++ b/core/leras/ops/__init__.py @@ -137,7 +137,39 @@ def resize2d_bilinear(x, size=2): return x nn.resize2d_bilinear = resize2d_bilinear +def resize2d_nearest(x, size=2): + if size in [-1,0,1]: + return x + + """ + if size > 0: + raise Exception("") + else: + if nn.data_format == "NCHW": + x = x[:,:,::-size,::-size] + else: + x = x[:,::-size,::-size,:] + return x + """ + + h = x.shape[nn.conv2d_spatial_axes[0]].value + w = x.shape[nn.conv2d_spatial_axes[1]].value + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,2,3,1)) + + if size > 0: + new_size = (h*size,w*size) + else: + new_size = (h//-size,w//-size) + + x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,3,1,2)) + + return x +nn.resize2d_nearest = resize2d_nearest def flatten(x): if nn.data_format == "NHWC": @@ -237,7 +269,9 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03 img_dtype = img1.dtype img1 = tf.cast(img1, tf.float32) img2 = tf.cast(img2, tf.float32) - + + filter_size = max(1, filter_size) + kernel = np.arange(0, filter_size, dtype=np.float32) kernel -= (filter_size - 1 ) / 2.0 kernel = kernel**2