diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py
index ffd9201..e2ee408 100644
--- a/core/leras/archis/DeepFakeArchi.py
+++ b/core/leras/archis/DeepFakeArchi.py
@@ -72,11 +72,22 @@ class DeepFakeArchi(nn.ArchiBase):
                 return x
 
         class Encoder(nn.ModelBase):
-            def on_build(self, in_ch, e_ch):
-                self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5)
+            def __init__(self, in_ch, e_ch, **kwargs ):
+                self.in_ch = in_ch
+                self.e_ch = e_ch
+                super().__init__(**kwargs)
+
+            def on_build(self):
+                self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5)
 
             def forward(self, inp):
                 return nn.flatten(self.down1(inp))
+
+            def get_out_res(self, res):
+                return res // (2**4)
+
+            def get_out_ch(self):
+                return 512
 
         lowest_dense_res = resolution // (32 if 'd' in opts else 16)
 
@@ -104,9 +115,8 @@ class DeepFakeArchi(nn.ArchiBase):
                 x = self.upscale1(x)
                 return x
 
-            @staticmethod
-            def get_code_res():
-                return lowest_dense_res
+            def get_out_res(self):
+                return lowest_dense_res * 2
 
             def get_out_ch(self):
                 return self.ae_out_ch
diff --git a/core/leras/models/ModelBase.py b/core/leras/models/ModelBase.py
index 77ac284..cc558a4 100644
--- a/core/leras/models/ModelBase.py
+++ b/core/leras/models/ModelBase.py
@@ -116,41 +116,32 @@ class ModelBase(nn.Saveable):
 
         return self.forward(*args, **kwargs)
 
-    def compute_output_shape(self, shapes):
-        if not self.built:
-            self.build()
+    # def compute_output_shape(self, shapes):
+    #     if not self.built:
+    #         self.build()
 
-        not_list = False
-        if not isinstance(shapes, list):
-            not_list = True
-            shapes = [shapes]
+    #     not_list = False
+    #     if not isinstance(shapes, list):
+    #         not_list = True
+    #         shapes = [shapes]
 
-        with tf.device('/CPU:0'):
-            # CPU tensors will not impact any performance, only slightly RAM "leakage"
-            phs = []
-            for dtype,sh in shapes:
-                phs += [ tf.placeholder(dtype, sh) ]
+    #     with tf.device('/CPU:0'):
+    #         # CPU tensors will not impact any performance, only slightly RAM "leakage"
+    #         phs = []
+    #         for dtype,sh in shapes:
+    #             phs += [ tf.placeholder(dtype, sh) ]
 
-            result = self.__call__(phs[0] if not_list else phs)
+    #         result = self.__call__(phs[0] if not_list else phs)
 
-            if not isinstance(result, list):
-                result = [result]
+    #         if not isinstance(result, list):
+    #             result = [result]
 
-            result_shapes = []
+    #         result_shapes = []
 
-            for t in result:
-                result_shapes += [ t.shape.as_list() ]
+    #         for t in result:
+    #             result_shapes += [ t.shape.as_list() ]
 
-            return result_shapes[0] if not_list else result_shapes
-
-    def compute_output_channels(self, shapes):
-        shape = self.compute_output_shape(shapes)
-        shape_len = len(shape)
-
-        if shape_len == 4:
-            if nn.data_format == "NCHW":
-                return shape[1]
-            return shape[-1]
+    #     return result_shapes[0] if not_list else result_shapes
 
     def build_for_run(self, shapes_list):
         if not isinstance(shapes_list, list):
diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py
index 443549a..648407c 100644
--- a/core/leras/ops/__init__.py
+++ b/core/leras/ops/__init__.py
@@ -333,7 +333,7 @@ def depth_to_space(x, size):
         x = tf.reshape(x, (-1, oh, ow, oc, ))
         return x
     else:
-
+        return tf.depth_to_space(x, size, data_format=nn.data_format)
         b,c,h,w = x.shape.as_list()
         oh, ow = h * size, w * size
         oc = c // (size * size)
@@ -342,7 +342,7 @@ def depth_to_space(x, size):
         x = tf.transpose(x, (0, 3, 4, 1, 5, 2))
         x = tf.reshape(x, (-1, oc, oh, ow))
         return x
-    return tf.depth_to_space(x, size, data_format=nn.data_format)
+
 
 
 nn.depth_to_space = depth_to_space
diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py
index c85354f..3c39e46 100644
--- a/models/Model_Quick96/Model.py
+++ b/models/Model_Quick96/Model.py
@@ -56,10 +56,10 @@ class QModel(ModelBase):
         # Initializing model classes
         with tf.device (models_opt_device):
             self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
-            encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
+            encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
 
             self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
-            inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
+            inter_out_ch = self.inter.get_out_ch()
 
             self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
             self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index 3803504..4a227ed 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -233,10 +233,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         with tf.device (models_opt_device):
             if 'df' in archi_type:
                 self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
-                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
+                encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
 
                 self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter')
-                inter_out_ch = self.inter.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
+                inter_out_ch = self.inter.get_out_ch()
 
                 self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src')
                 self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst')
@@ -248,19 +248,18 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
                 if self.is_training:
                     if self.options['true_face_power'] != 0:
-                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_code_res()*2, name='dis' )
+                        self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=model_archi.Inter.get_out_res(), name='dis' )
                         self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ]
 
             elif 'liae' in archi_type:
                 self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder')
-                encoder_out_ch = self.encoder.compute_output_channels ( (nn.floatx, bgr_shape))
+                encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2
 
                 self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB')
                 self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B')
 
-                inter_AB_out_ch = self.inter_AB.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
-                inter_B_out_ch = self.inter_B.compute_output_channels ( (nn.floatx, (None,encoder_out_ch)))
-                inters_out_ch = inter_AB_out_ch+inter_B_out_ch
+                inter_out_ch = self.inter_AB.get_out_ch()
+                inters_out_ch = inter_out_ch*2
                 self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder')
 
                 self.model_filename_list += [ [self.encoder, 'encoder.npy'],
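
Illustration only, not part of the patch: a minimal sketch of what the new static shape helpers compute in place of the removed compute_output_channels() placeholder probing. The concrete numbers (resolution=96, ae_dims=128) are assumed Quick96-style settings, not values taken from this diff.

# sketch.py -- standalone illustration of the new helpers' arithmetic
resolution = 96      # assumed input resolution
ae_dims    = 128     # assumed Inter ae_out_ch

# Encoder.get_out_res(res) == res // (2**4) and Encoder.get_out_ch() == 512 per the patch
encoder_out_res = resolution // (2**4)          # 6
encoder_out_ch  = 512 * encoder_out_res**2      # 18432 -> passed as Inter in_ch

# Inter.get_out_ch() simply returns ae_out_ch, so no placeholder graph is built
inter_out_ch = ae_dims                          # 128 -> passed as Decoder in_ch

print(encoder_out_res, encoder_out_ch, inter_out_ch)   # 6 18432 128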