This commit is contained in:
Colombo 2020-01-26 12:56:21 +04:00
commit c485e1718a
4 changed files with 10 additions and 22 deletions

View file

@ -182,23 +182,17 @@ class nn():
nn.conv2d_spatial_axes = [2,3]
@staticmethod
def get4Dshape ( w, h, c, data_format=None ):
def get4Dshape ( w, h, c ):
"""
returns 4D shape based on current data_format
"""
if data_format is None:
data_format = nn.data_format
if data_format == "NHWC":
if nn.data_format == "NHWC":
return (None,h,w,c)
else:
return (None,c,h,w)
@staticmethod
def to_data_format( x, to_data_format, from_data_format=None):
if from_data_format is None:
from_data_format = nn.data_format
def to_data_format( x, to_data_format, from_data_format):
if to_data_format == from_data_format:
return x

View file

@ -35,7 +35,7 @@ def initialize_tensor_ops(nn):
gv = [*zip(grads,vars)]
for g,v in gv:
if g is None:
raise Exception("No gradient for variable {v.name}")
raise Exception(f"No gradient for variable {v.name}")
return gv
nn.tf_gradients = tf_gradients