fix depth_to_space for tf2.4.0. Removing compute_output_shape in leras, because it uses CPU device, which does not support all ops.

This commit is contained in:
iperov 2020-12-11 11:28:33 +04:00
commit b9c9e7cffd
5 changed files with 44 additions and 44 deletions

View file

@ -333,7 +333,7 @@ def depth_to_space(x, size):
x = tf.reshape(x, (-1, oh, ow, oc, ))
return x
else:
return tf.depth_to_space(x, size, data_format=nn.data_format)
b,c,h,w = x.shape.as_list()
oh, ow = h * size, w * size
oc = c // (size * size)
@ -342,7 +342,7 @@ def depth_to_space(x, size):
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
x = tf.reshape(x, (-1, oc, oh, ow))
return x
return tf.depth_to_space(x, size, data_format=nn.data_format)
nn.depth_to_space = depth_to_space