diff --git a/core/leras/layers/Conv2D.py b/core/leras/layers/Conv2D.py index ae37c50..fd121f0 100644 --- a/core/leras/layers/Conv2D.py +++ b/core/leras/layers/Conv2D.py @@ -23,28 +23,13 @@ class Conv2D(nn.LayerBase): if padding == "SAME": padding = ( (kernel_size - 1) * dilations + 1 ) // 2 elif padding == "VALID": - padding = 0 + padding = None else: raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs") - - if isinstance(padding, int): - if padding != 0: - if nn.data_format == "NHWC": - padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] - else: - padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] - else: - padding = None - - if nn.data_format == "NHWC": - strides = [1,strides,strides,1] else: - strides = [1,1,strides,strides] - - if nn.data_format == "NHWC": - dilations = [1,dilations,dilations,1] - else: - dilations = [1,1,dilations,dilations] + padding = int(padding) + + self.in_ch = in_ch self.out_ch = out_ch @@ -93,9 +78,26 @@ class Conv2D(nn.LayerBase): if self.use_wscale: weight = weight * self.wscale - if self.padding is not None: - x = tf.pad (x, self.padding, mode='CONSTANT') + padding = self.padding + if padding is not None: + if nn.data_format == "NHWC": + padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] + else: + padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] + x = tf.pad (x, padding, mode='CONSTANT') + + strides = self.strides + if nn.data_format == "NHWC": + strides = [1,strides,strides,1] + else: + strides = [1,1,strides,strides] + dilations = self.dilations + if nn.data_format == "NHWC": + dilations = [1,dilations,dilations,1] + else: + dilations = [1,1,dilations,dilations] + x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format) if self.use_bias: if nn.data_format == "NHWC":