upd leras ops

This commit is contained in:
Colombo 2020-07-03 19:32:14 +04:00
parent 11dc7d41a0
commit 9a540e644c

View file

@ -137,7 +137,39 @@ def resize2d_bilinear(x, size=2):
return x
nn.resize2d_bilinear = resize2d_bilinear
def resize2d_nearest(x, size=2):
if size in [-1,0,1]:
return x
"""
if size > 0:
raise Exception("")
else:
if nn.data_format == "NCHW":
x = x[:,:,::-size,::-size]
else:
x = x[:,::-size,::-size,:]
return x
"""
h = x.shape[nn.conv2d_spatial_axes[0]].value
w = x.shape[nn.conv2d_spatial_axes[1]].value
if nn.data_format == "NCHW":
x = tf.transpose(x, (0,2,3,1))
if size > 0:
new_size = (h*size,w*size)
else:
new_size = (h//-size,w//-size)
x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
if nn.data_format == "NCHW":
x = tf.transpose(x, (0,3,1,2))
return x
nn.resize2d_nearest = resize2d_nearest
def flatten(x):
if nn.data_format == "NHWC":
@ -237,7 +269,9 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
img_dtype = img1.dtype
img1 = tf.cast(img1, tf.float32)
img2 = tf.cast(img2, tf.float32)
filter_size = max(1, filter_size)
kernel = np.arange(0, filter_size, dtype=np.float32)
kernel -= (filter_size - 1 ) / 2.0
kernel = kernel**2