fix depth_to_space for tf2.4.0. Removing compute_output_shape in leras, because it uses CPU device, which does not support all ops.

This commit is contained in:
iperov 2020-12-11 11:28:33 +04:00
parent bbf3a71a96
commit b9c9e7cffd
5 changed files with 44 additions and 44 deletions

View file

@ -116,41 +116,32 @@ class ModelBase(nn.Saveable):
return self.forward(*args, **kwargs)
def compute_output_shape(self, shapes):
if not self.built:
self.build()
# def compute_output_shape(self, shapes):
# if not self.built:
# self.build()
not_list = False
if not isinstance(shapes, list):
not_list = True
shapes = [shapes]
# not_list = False
# if not isinstance(shapes, list):
# not_list = True
# shapes = [shapes]
with tf.device('/CPU:0'):
# CPU tensors will not impact any performance, only slightly RAM "leakage"
phs = []
for dtype,sh in shapes:
phs += [ tf.placeholder(dtype, sh) ]
# with tf.device('/CPU:0'):
# # CPU tensors will not impact any performance, only slightly RAM "leakage"
# phs = []
# for dtype,sh in shapes:
# phs += [ tf.placeholder(dtype, sh) ]
result = self.__call__(phs[0] if not_list else phs)
# result = self.__call__(phs[0] if not_list else phs)
if not isinstance(result, list):
result = [result]
# if not isinstance(result, list):
# result = [result]
result_shapes = []
# result_shapes = []
for t in result:
result_shapes += [ t.shape.as_list() ]
# for t in result:
# result_shapes += [ t.shape.as_list() ]
return result_shapes[0] if not_list else result_shapes
def compute_output_channels(self, shapes):
shape = self.compute_output_shape(shapes)
shape_len = len(shape)
if shape_len == 4:
if nn.data_format == "NCHW":
return shape[1]
return shape[-1]
# return result_shapes[0] if not_list else result_shapes
def build_for_run(self, shapes_list):
if not isinstance(shapes_list, list):