mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
_
This commit is contained in:
parent
24ba84d4a5
commit
5dc027a8b0
1 changed files with 23 additions and 21 deletions
|
@ -23,28 +23,13 @@ class Conv2D(nn.LayerBase):
|
||||||
if padding == "SAME":
|
if padding == "SAME":
|
||||||
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
|
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
|
||||||
elif padding == "VALID":
|
elif padding == "VALID":
|
||||||
padding = 0
|
padding = None
|
||||||
else:
|
else:
|
||||||
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
|
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
|
||||||
|
|
||||||
if isinstance(padding, int):
|
|
||||||
if padding != 0:
|
|
||||||
if nn.data_format == "NHWC":
|
|
||||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
|
||||||
else:
|
|
||||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
|
||||||
else:
|
|
||||||
padding = None
|
|
||||||
|
|
||||||
if nn.data_format == "NHWC":
|
|
||||||
strides = [1,strides,strides,1]
|
|
||||||
else:
|
else:
|
||||||
strides = [1,1,strides,strides]
|
padding = int(padding)
|
||||||
|
|
||||||
if nn.data_format == "NHWC":
|
|
||||||
dilations = [1,dilations,dilations,1]
|
|
||||||
else:
|
|
||||||
dilations = [1,1,dilations,dilations]
|
|
||||||
|
|
||||||
self.in_ch = in_ch
|
self.in_ch = in_ch
|
||||||
self.out_ch = out_ch
|
self.out_ch = out_ch
|
||||||
|
@ -93,9 +78,26 @@ class Conv2D(nn.LayerBase):
|
||||||
if self.use_wscale:
|
if self.use_wscale:
|
||||||
weight = weight * self.wscale
|
weight = weight * self.wscale
|
||||||
|
|
||||||
if self.padding is not None:
|
padding = self.padding
|
||||||
x = tf.pad (x, self.padding, mode='CONSTANT')
|
if padding is not None:
|
||||||
|
if nn.data_format == "NHWC":
|
||||||
|
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||||
|
else:
|
||||||
|
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||||
|
x = tf.pad (x, padding, mode='CONSTANT')
|
||||||
|
|
||||||
|
strides = self.strides
|
||||||
|
if nn.data_format == "NHWC":
|
||||||
|
strides = [1,strides,strides,1]
|
||||||
|
else:
|
||||||
|
strides = [1,1,strides,strides]
|
||||||
|
|
||||||
|
dilations = self.dilations
|
||||||
|
if nn.data_format == "NHWC":
|
||||||
|
dilations = [1,dilations,dilations,1]
|
||||||
|
else:
|
||||||
|
dilations = [1,1,dilations,dilations]
|
||||||
|
|
||||||
x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
|
x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
|
||||||
if self.use_bias:
|
if self.use_bias:
|
||||||
if nn.data_format == "NHWC":
|
if nn.data_format == "NHWC":
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue