update leras

This commit is contained in:
Colombo 2020-02-17 18:26:19 +04:00
commit e0a55ff1c3
3 changed files with 39 additions and 4 deletions

View file

@ -129,6 +129,14 @@ def initialize_tensor_ops(nn):
nn.tf_flatten = tf_flatten
def tf_max_pool(x, kernel_size, strides):
if nn.data_format == "NHWC":
return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], "VALID", data_format=nn.data_format)
else:
return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], "VALID", data_format=nn.data_format)
nn.tf_max_pool = tf_max_pool
def tf_reshape_4D(x, w,h,c):
if nn.data_format == "NHWC":
# match NCHW version in order to switch data_format without problems