fix leras

This commit is contained in:
Colombo 2020-02-19 07:00:29 +04:00
parent 5ce6f7ef45
commit 6aa585fa33

View file

@ -318,7 +318,12 @@ def initialize_layers(nn):
class BlurPool(LayerBase):
def __init__(self, filt_size=3, stride=2, **kwargs ):
self.strides = [1,stride,stride,1]
if nn.data_format == "NHWC":
self.strides = [1,stride,stride,1]
else:
self.strides = [1,1,stride,stride]
self.filt_size = filt_size
pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]
@ -352,9 +357,9 @@ def initialize_layers(nn):
self.k = tf.constant (self.a, dtype=nn.tf_floatx )
def __call__(self, x):
k = tf.tile (self.k, (1,1,x.shape[-1],1) )
k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) )
x = tf.pad(x, self.padding )
x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID')
x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format)
return x
nn.BlurPool = BlurPool