This commit is contained in:
Colombo 2020-01-26 12:56:21 +04:00
commit c485e1718a
4 changed files with 10 additions and 22 deletions

View file

@ -182,23 +182,17 @@ class nn():
nn.conv2d_spatial_axes = [2,3]
@staticmethod
def get4Dshape ( w, h, c, data_format=None ):
def get4Dshape ( w, h, c ):
"""
returns 4D shape based on current data_format
"""
if data_format is None:
data_format = nn.data_format
if data_format == "NHWC":
if nn.data_format == "NHWC":
return (None,h,w,c)
else:
return (None,c,h,w)
@staticmethod
def to_data_format( x, to_data_format, from_data_format=None):
if from_data_format is None:
from_data_format = nn.data_format
def to_data_format( x, to_data_format, from_data_format):
if to_data_format == from_data_format:
return x