From 6aa585fa3300a0fe853c62ed31abc913261909e7 Mon Sep 17 00:00:00 2001 From: Colombo Date: Wed, 19 Feb 2020 07:00:29 +0400 Subject: [PATCH] fix leras --- core/leras/layers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/core/leras/layers.py b/core/leras/layers.py index 276e003..0a8191f 100644 --- a/core/leras/layers.py +++ b/core/leras/layers.py @@ -318,7 +318,12 @@ def initialize_layers(nn): class BlurPool(LayerBase): def __init__(self, filt_size=3, stride=2, **kwargs ): - self.strides = [1,stride,stride,1] + + if nn.data_format == "NHWC": + self.strides = [1,stride,stride,1] + else: + self.strides = [1,1,stride,stride] + self.filt_size = filt_size pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ] @@ -352,9 +357,9 @@ def initialize_layers(nn): self.k = tf.constant (self.a, dtype=nn.tf_floatx ) def __call__(self, x): - k = tf.tile (self.k, (1,1,x.shape[-1],1) ) + k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) ) x = tf.pad(x, self.padding ) - x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID') + x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format) return x nn.BlurPool = BlurPool