Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-07 21:42:08 -07:00

commit b9c9e7cffd (parent bbf3a71a96)

    fix depth_to_space for tf2.4.0. Removing compute_output_shape in leras, because it uses CPU device, which does not support all ops.

5 changed files with 44 additions and 44 deletions

@@ -72,11 +72,22 @@ class DeepFakeArchi(nn.ArchiBase):
                 return x
 
         class Encoder(nn.ModelBase):
-            def on_build(self, in_ch, e_ch):
-                self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
+            def __init__(self, in_ch, e_ch, **kwargs ):
+                self.in_ch = in_ch
+                self.e_ch = e_ch
+                super().__init__(**kwargs)
+
+            def on_build(self):
+                self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
 
             def forward(self, inp):
                 return nn.flatten(self.down1(inp))
 
+            def get_out_res(self, res):
+                return res // (2**4)
+
+            def get_out_ch(self):
+                return 512
+
         lowest_dense_res = resolution // (32 if 'd' in opts else 16)
 
@@ -104,9 +115,8 @@ class DeepFakeArchi(nn.ArchiBase):
                 x = self.upscale1(x)
                 return x
 
-            @staticmethod
-            def get_code_res():
-                return lowest_dense_res
+            def get_out_res(self):
+                return lowest_dense_res * 2
 
             def get_out_ch(self):
                 return self.ae_out_ch
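
With the hyperparameters stored in __init__ and the new get_out_res/get_out_ch accessors, callers can size downstream layers arithmetically instead of tracing a graph. A minimal sketch of that arithmetic, assuming only what the diff above shows (four stride-2 downscales and a hard-coded 512-channel output); the helper names are illustrative, not DeepFaceLab API:

def encoder_out_res(resolution):
    # mirrors Encoder.get_out_res: four downscales halve the resolution each time
    return resolution // (2 ** 4)

def encoder_out_ch():
    # mirrors Encoder.get_out_ch: hard-coded deepest channel count
    return 512

def flattened_encoder_size(resolution):
    # what the model files now compute as get_out_ch() * get_out_res(resolution) ** 2
    return encoder_out_ch() * encoder_out_res(resolution) ** 2

print(flattened_encoder_size(96))   # 512 * 6**2 = 18432
print(flattened_encoder_size(128))  # 512 * 8**2 = 32768
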
@@ -116,41 +116,32 @@ class ModelBase(nn.Saveable):
 
         return self.forward(*args, **kwargs)
 
-    def compute_output_shape(self, shapes):
-        if not self.built:
-            self.build()
+    # def compute_output_shape(self, shapes):
+    #     if not self.built:
+    #         self.build()
 
-        not_list = False
-        if not isinstance(shapes, list):
-            not_list = True
-            shapes = [shapes]
+    #     not_list = False
+    #     if not isinstance(shapes, list):
+    #         not_list = True
+    #         shapes = [shapes]
 
-        with tf.device('/CPU:0'):
-            # CPU tensors will not impact any performance, only slightly RAM "leakage"
-            phs = []
-            for dtype,sh in shapes:
-                phs += [ tf.placeholder(dtype, sh) ]
+    #     with tf.device('/CPU:0'):
+    #         # CPU tensors will not impact any performance, only slightly RAM "leakage"
+    #         phs = []
+    #         for dtype,sh in shapes:
+    #             phs += [ tf.placeholder(dtype, sh) ]
 
-            result = self.__call__(phs[0] if not_list else phs)
+    #         result = self.__call__(phs[0] if not_list else phs)
 
-            if not isinstance(result, list):
-                result = [result]
+    #         if not isinstance(result, list):
+    #             result = [result]
 
-            result_shapes = []
+    #         result_shapes = []
 
-            for t in result:
-                result_shapes += [ t.shape.as_list() ]
+    #         for t in result:
+    #             result_shapes += [ t.shape.as_list() ]
 
-            return result_shapes[0] if not_list else result_shapes
+    #         return result_shapes[0] if not_list else result_shapes
 
-    def compute_output_channels(self, shapes):
-        shape = self.compute_output_shape(shapes)
-        shape_len = len(shape)
-
-        if shape_len == 4:
-            if nn.data_format == "NCHW":
-                return shape[1]
-        return shape[-1]
-
     def build_for_run(self, shapes_list):
         if not isinstance(shapes_list, list):
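
The dropped compute_output_shape ran the model once on placeholders pinned to /CPU:0 just to read the output shapes; per the commit message, the CPU device does not support all of the ops these models use, so that trace can fail outright. A small standalone check of that failure mode (my own sketch, not repository code): depth_to_space with NCHW data forced onto the CPU typically raises, since the NCHW variant of the op is not available there.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(1, 16, 8, 8).astype(np.float32))  # NCHW: 16 = 2*2*4 channels
try:
    with tf.device('/CPU:0'):
        y = tf.nn.depth_to_space(x, 2, data_format='NCHW')
    print('ok:', y.shape)
except Exception as e:   # usually tf.errors.InvalidArgumentError on CPU-only setups
    print('no CPU support for NCHW depth_to_space:', e)

The replacement is to derive the needed numbers (output channels, output resolution) directly from the architecture, which is what the model files below now do.
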
@@ -333,7 +333,7 @@ def depth_to_space(x, size):
         x = tf.reshape(x, (-1, oh, ow, oc, ))
         return x
     else:
-
+        return tf.depth_to_space(x, size, data_format=nn.data_format)
         b,c,h,w = x.shape.as_list()
         oh, ow = h * size, w * size
         oc = c // (size * size)
@@ -342,7 +342,7 @@ def depth_to_space(x, size):
         x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
         x = tf.reshape(x, (-1, oc, oh, ow))
         return x
-        return tf.depth_to_space(x, size, data_format=nn.data_format)
+
 
 nn.depth_to_space = depth_to_space
 
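
After this change the NCHW branch returns the native op immediately; the manual reshape/transpose path below it stays in the file and performs the same pixel-shuffle rearrangement. A standalone NumPy sketch (mine, not repository code) checking that the NCHW path kept in the diff agrees with a plain NHWC depth_to_space, i.e. that switching data_format does not change the result:

import numpy as np

def depth_to_space_nhwc(x, size):
    b, h, w, c = x.shape
    oh, ow, oc = h * size, w * size, c // (size * size)
    x = x.reshape(b, h, w, size, size, oc)   # split depth into (size, size, oc)
    x = x.transpose(0, 1, 3, 2, 4, 5)        # interleave the blocks into rows/cols
    return x.reshape(b, oh, ow, oc)

def depth_to_space_nchw(x, size):            # mirrors the else-branch above
    b, c, h, w = x.shape
    oh, ow, oc = h * size, w * size, c // (size * size)
    x = x.reshape(b, size, size, oc, h, w)
    x = x.transpose(0, 3, 4, 1, 5, 2)        # -> (b, oc, h, size, w, size)
    return x.reshape(b, oc, oh, ow)

x_nchw = np.random.rand(2, 16, 4, 4).astype(np.float32)
out_nchw = depth_to_space_nchw(x_nchw, 2)
out_nhwc = depth_to_space_nhwc(x_nchw.transpose(0, 2, 3, 1), 2)
assert np.allclose(out_nchw, out_nhwc.transpose(0, 3, 1, 2))
print(out_nchw.shape)  # (2, 4, 8, 8)
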
@@ -56,10 +56,10 @@ class QModel(ModelBase):
         # Initializing model classes
         with tf.device (models_opt_device):
             self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
-            encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
+            encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
 
             self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
-            inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
+            inter_out_ch = self.inter.get_out_ch()
 
             self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
             self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
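
A worked example (my own sketch) of the replacement arithmetic for Quick96, which trains at a fixed 96x96 resolution; the ae_dims value is illustrative:

resolution = 96
encoder_out_res = resolution // (2 ** 4)        # Encoder.get_out_res(96) -> 6
encoder_out_ch = 512 * encoder_out_res ** 2     # get_out_ch() * get_out_res(resolution)**2 -> 18432

ae_dims = 128                                   # illustrative value
inter_out_ch = ae_dims                          # Inter.get_out_ch() returns ae_out_ch,
                                                # and QModel builds Inter with ae_out_ch=ae_dims
print(encoder_out_ch, inter_out_ch)             # 18432 128
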
@@ -233,10 +233,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         with tf.device (models_opt_device):
             if 'df' in archi_type:
                 self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
-                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
+                encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
 
                 self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
-                inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
+                inter_out_ch = self.inter.get_out_ch()
 
                 self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
                 self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
@@ -248,19 +248,18 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
                 if self.is_training:
                     if self.options['true_face_power'] != 0:
-                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
+                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_out_res(), name='dis' )
                         self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
 
             elif 'liae' in archi_type:
                 self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
-                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
+                encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
 
                 self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
                 self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
 
-                inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
-                inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
-                inters_out_ch = inter_AB_out_ch+inter_B_out_ch
+                inter_out_ch = self.inter_AB.get_out_ch()
+                inters_out_ch = inter_out_ch*2
                 self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')
 
                 self.model_filename_list += [ [self.encoder, 'encoder.npy'],
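
In the 'liae' branch both inters are built with ae_out_ch=ae_dims*2, so their outputs have identical channel counts, and the decoder is built with in_ch=inters_out_ch, the channel count of their channel-wise concatenation; that is why a single get_out_ch() call, doubled, replaces the two traced channel counts. The code discriminator line is the same substitution: Inter.get_out_res() returns lowest_dense_res * 2, the value the old get_code_res()*2 expression produced. A shape-level sketch (mine, with an illustrative ae_dims and spatial size):

import numpy as np

ae_dims = 256
inter_out_ch = ae_dims * 2            # Inter.get_out_ch() == ae_out_ch == ae_dims*2
inters_out_ch = inter_out_ch * 2      # channels after concatenating inter_AB and inter_B

res = 12                              # stand-in spatial size of the inter output
a = np.zeros((1, inter_out_ch, res, res))   # inter_AB output (NCHW)
b = np.zeros((1, inter_out_ch, res, res))   # inter_B output (NCHW)
assert np.concatenate([a, b], axis=1).shape[1] == inters_out_ch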