mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-14 02:37:00 -07:00
update leras
This commit is contained in:
parent
e1635ff760
commit
e0a55ff1c3
3 changed files with 39 additions and 4 deletions
|
@ -129,6 +129,14 @@ def initialize_tensor_ops(nn):
|
|||
|
||||
nn.tf_flatten = tf_flatten
|
||||
|
||||
def tf_max_pool(x, kernel_size, strides):
|
||||
if nn.data_format == "NHWC":
|
||||
return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], "VALID", data_format=nn.data_format)
|
||||
else:
|
||||
return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], "VALID", data_format=nn.data_format)
|
||||
|
||||
nn.tf_max_pool = tf_max_pool
|
||||
|
||||
def tf_reshape_4D(x, w,h,c):
|
||||
if nn.data_format == "NHWC":
|
||||
# match NCHW version in order to switch data_format without problems
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue