fix depth_to_space for tf2.4.0. Removing compute_output_shape in leras, because it uses CPU device, which does not support all ops.

This commit is contained in:
iperov 2020-12-11 11:28:33 +04:00
commit b9c9e7cffd
5 changed files with 44 additions and 44 deletions

View file

@ -72,11 +72,22 @@ class DeepFakeArchi(nn.ArchiBase):
return x
class Encoder(nn.ModelBase):
def on_build(self, in_ch, e_ch):
self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
def __init__(self, in_ch, e_ch, **kwargs ):
self.in_ch = in_ch
self.e_ch = e_ch
super().__init__(**kwargs)
def on_build(self):
self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
def forward(self, inp):
return nn.flatten(self.down1(inp))
def get_out_res(self, res):
return res // (2**4)
def get_out_ch(self):
return 512
lowest_dense_res = resolution // (32 if 'd' in opts else 16)
@ -104,9 +115,8 @@ class DeepFakeArchi(nn.ArchiBase):
x = self.upscale1(x)
return x
@staticmethod
def get_code_res():
return lowest_dense_res
def get_out_res(self):
return lowest_dense_res * 2
def get_out_ch(self):
return self.ae_out_ch

View file

@ -116,41 +116,32 @@ class ModelBase(nn.Saveable):
return self.forward(*args, **kwargs)
def compute_output_shape(self, shapes):
if not self.built:
self.build()
# def compute_output_shape(self, shapes):
# if not self.built:
# self.build()
not_list = False
if not isinstance(shapes, list):
not_list = True
shapes = [shapes]
# not_list = False
# if not isinstance(shapes, list):
# not_list = True
# shapes = [shapes]
with tf.device('/CPU:0'):
# CPU tensors will not impact any performance, only slightly RAM "leakage"
phs = []
for dtype,sh in shapes:
phs += [ tf.placeholder(dtype, sh) ]
# with tf.device('/CPU:0'):
# # CPU tensors will not impact any performance, only slightly RAM "leakage"
# phs = []
# for dtype,sh in shapes:
# phs += [ tf.placeholder(dtype, sh) ]
result = self.__call__(phs[0] if not_list else phs)
# result = self.__call__(phs[0] if not_list else phs)
if not isinstance(result, list):
result = [result]
# if not isinstance(result, list):
# result = [result]
result_shapes = []
# result_shapes = []
for t in result:
result_shapes += [ t.shape.as_list() ]
# for t in result:
# result_shapes += [ t.shape.as_list() ]
return result_shapes[0] if not_list else result_shapes
def compute_output_channels(self, shapes):
shape = self.compute_output_shape(shapes)
shape_len = len(shape)
if shape_len == 4:
if nn.data_format == "NCHW":
return shape[1]
return shape[-1]
# return result_shapes[0] if not_list else result_shapes
def build_for_run(self, shapes_list):
if not isinstance(shapes_list, list):

View file

@ -333,7 +333,7 @@ def depth_to_space(x, size):
x = tf.reshape(x, (-1, oh, ow, oc, ))
return x
else:
return tf.depth_to_space(x, size, data_format=nn.data_format)
b,c,h,w = x.shape.as_list()
oh, ow = h * size, w * size
oc = c // (size * size)
@ -342,7 +342,7 @@ def depth_to_space(x, size):
x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
x = tf.reshape(x, (-1, oc, oh, ow))
return x
return tf.depth_to_space(x, size, data_format=nn.data_format)
nn.depth_to_space = depth_to_space