This commit is contained in:
iperov 2021-06-09 19:17:18 +04:00
parent 24ba84d4a5
commit 5dc027a8b0

View file

@ -23,28 +23,13 @@ class Conv2D(nn.LayerBase):
if padding == "SAME": if padding == "SAME":
padding = ( (kernel_size - 1) * dilations + 1 ) // 2 padding = ( (kernel_size - 1) * dilations + 1 ) // 2
elif padding == "VALID": elif padding == "VALID":
padding = 0 padding = None
else: else:
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs") raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
if isinstance(padding, int):
if padding != 0:
if nn.data_format == "NHWC":
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
else:
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
else:
padding = None
if nn.data_format == "NHWC":
strides = [1,strides,strides,1]
else: else:
strides = [1,1,strides,strides] padding = int(padding)
if nn.data_format == "NHWC":
dilations = [1,dilations,dilations,1]
else:
dilations = [1,1,dilations,dilations]
self.in_ch = in_ch self.in_ch = in_ch
self.out_ch = out_ch self.out_ch = out_ch
@ -93,9 +78,26 @@ class Conv2D(nn.LayerBase):
if self.use_wscale: if self.use_wscale:
weight = weight * self.wscale weight = weight * self.wscale
if self.padding is not None: padding = self.padding
x = tf.pad (x, self.padding, mode='CONSTANT') if padding is not None:
if nn.data_format == "NHWC":
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
else:
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
x = tf.pad (x, padding, mode='CONSTANT')
strides = self.strides
if nn.data_format == "NHWC":
strides = [1,strides,strides,1]
else:
strides = [1,1,strides,strides]
dilations = self.dilations
if nn.data_format == "NHWC":
dilations = [1,dilations,dilations,1]
else:
dilations = [1,1,dilations,dilations]
x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format) x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
if self.use_bias: if self.use_bias:
if nn.data_format == "NHWC": if nn.data_format == "NHWC":