mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-31 04:00:11 -07:00
fix leras
This commit is contained in:
parent
5ce6f7ef45
commit
6aa585fa33
1 changed files with 8 additions and 3 deletions
|
@ -318,7 +318,12 @@ def initialize_layers(nn):
|
|||
|
||||
class BlurPool(LayerBase):
|
||||
def __init__(self, filt_size=3, stride=2, **kwargs ):
|
||||
self.strides = [1,stride,stride,1]
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
self.strides = [1,stride,stride,1]
|
||||
else:
|
||||
self.strides = [1,1,stride,stride]
|
||||
|
||||
self.filt_size = filt_size
|
||||
pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]
|
||||
|
||||
|
@ -352,9 +357,9 @@ def initialize_layers(nn):
|
|||
self.k = tf.constant (self.a, dtype=nn.tf_floatx )
|
||||
|
||||
def __call__(self, x):
|
||||
k = tf.tile (self.k, (1,1,x.shape[-1],1) )
|
||||
k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) )
|
||||
x = tf.pad(x, self.padding )
|
||||
x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID')
|
||||
x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format)
|
||||
return x
|
||||
nn.BlurPool = BlurPool
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue