mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-16 10:03:41 -07:00
fix depth_to_space for tf2.4.0. Removing compute_output_shape in leras, because it uses CPU device, which does not support all ops.
This commit is contained in:
parent
bbf3a71a96
commit
b9c9e7cffd
5 changed files with 44 additions and 44 deletions
|
@ -116,41 +116,32 @@ class ModelBase(nn.Saveable):
|
|||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def compute_output_shape(self, shapes):
|
||||
if not self.built:
|
||||
self.build()
|
||||
# def compute_output_shape(self, shapes):
|
||||
# if not self.built:
|
||||
# self.build()
|
||||
|
||||
not_list = False
|
||||
if not isinstance(shapes, list):
|
||||
not_list = True
|
||||
shapes = [shapes]
|
||||
# not_list = False
|
||||
# if not isinstance(shapes, list):
|
||||
# not_list = True
|
||||
# shapes = [shapes]
|
||||
|
||||
with tf.device('/CPU:0'):
|
||||
# CPU tensors will not impact any performance, only slightly RAM "leakage"
|
||||
phs = []
|
||||
for dtype,sh in shapes:
|
||||
phs += [ tf.placeholder(dtype, sh) ]
|
||||
# with tf.device('/CPU:0'):
|
||||
# # CPU tensors will not impact any performance, only slightly RAM "leakage"
|
||||
# phs = []
|
||||
# for dtype,sh in shapes:
|
||||
# phs += [ tf.placeholder(dtype, sh) ]
|
||||
|
||||
result = self.__call__(phs[0] if not_list else phs)
|
||||
# result = self.__call__(phs[0] if not_list else phs)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
# if not isinstance(result, list):
|
||||
# result = [result]
|
||||
|
||||
result_shapes = []
|
||||
# result_shapes = []
|
||||
|
||||
for t in result:
|
||||
result_shapes += [ t.shape.as_list() ]
|
||||
# for t in result:
|
||||
# result_shapes += [ t.shape.as_list() ]
|
||||
|
||||
return result_shapes[0] if not_list else result_shapes
|
||||
|
||||
def compute_output_channels(self, shapes):
|
||||
shape = self.compute_output_shape(shapes)
|
||||
shape_len = len(shape)
|
||||
|
||||
if shape_len == 4:
|
||||
if nn.data_format == "NCHW":
|
||||
return shape[1]
|
||||
return shape[-1]
|
||||
# return result_shapes[0] if not_list else result_shapes
|
||||
|
||||
def build_for_run(self, shapes_list):
|
||||
if not isinstance(shapes_list, list):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue