mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 04:59:27 -07:00
fix leras
This commit is contained in:
parent
5ce6f7ef45
commit
6aa585fa33
1 changed files with 8 additions and 3 deletions
|
@ -318,7 +318,12 @@ def initialize_layers(nn):
|
||||||
|
|
||||||
class BlurPool(LayerBase):
|
class BlurPool(LayerBase):
|
||||||
def __init__(self, filt_size=3, stride=2, **kwargs ):
|
def __init__(self, filt_size=3, stride=2, **kwargs ):
|
||||||
|
|
||||||
|
if nn.data_format == "NHWC":
|
||||||
self.strides = [1,stride,stride,1]
|
self.strides = [1,stride,stride,1]
|
||||||
|
else:
|
||||||
|
self.strides = [1,1,stride,stride]
|
||||||
|
|
||||||
self.filt_size = filt_size
|
self.filt_size = filt_size
|
||||||
pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]
|
pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]
|
||||||
|
|
||||||
|
@ -352,9 +357,9 @@ def initialize_layers(nn):
|
||||||
self.k = tf.constant (self.a, dtype=nn.tf_floatx )
|
self.k = tf.constant (self.a, dtype=nn.tf_floatx )
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
k = tf.tile (self.k, (1,1,x.shape[-1],1) )
|
k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) )
|
||||||
x = tf.pad(x, self.padding )
|
x = tf.pad(x, self.padding )
|
||||||
x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID')
|
x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format)
|
||||||
return x
|
return x
|
||||||
nn.BlurPool = BlurPool
|
nn.BlurPool = BlurPool
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue