From 3c6bbe22b90c6a0f7bdd0b14ce38a3da5ddda350 Mon Sep 17 00:00:00 2001 From: iperov Date: Sat, 26 Jun 2021 10:55:33 +0400 Subject: [PATCH] AMP: added Inter dimensions param. The model is not changed. Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU. --- models/Model_AMP/Model.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 5b75036..1588f56 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -36,6 +36,12 @@ class AMPModel(ModelBase): default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + + inter_dims = self.load_or_def_option('inter_dims', None) + if inter_dims is None: + inter_dims = self.options['ae_dims'] + default_inter_dims = self.options['inter_dims'] = inter_dims + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) @@ -77,7 +83,9 @@ class AMPModel(ModelBase): default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) if self.is_first_run(): - self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 ) + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) self.options['e_dims'] = e_dims + e_dims % 2 @@ -139,6 +147,7 @@ class AMPModel(ModelBase): input_ch=3 ae_dims = self.ae_dims = self.options['ae_dims'] + inter_dims = self.inter_dims = self.options['inter_dims'] e_dims = self.options['e_dims'] d_dims = self.options['d_dims'] d_mask_dims = self.options['d_mask_dims'] @@ -203,13 +212,13 @@ class AMPModel(ModelBase): def forward(self, inp): x = inp x = self.dense2(x) - x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, ae_dims) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, inter_dims) return x class Decoder(nn.ModelBase): def on_build(self ): - self.upscale0 = Upscale(ae_dims, d_dims*8, kernel_size=3) + self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) @@ -219,7 +228,7 @@ class AMPModel(ModelBase): self.res2 = ResidualBlock(d_dims*4, kernel_size=3) self.res3 = ResidualBlock(d_dims*2, kernel_size=3) - self.upscalem0 = Upscale(ae_dims, d_mask_dims*8, kernel_size=3) + self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) @@ -402,9 +411,9 @@ class AMPModel(ModelBase): gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) gpu_dst_code = gpu_dst_inter_dst_code - ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32) - gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, ae_dims_slice , lowest_dense_res, lowest_dense_res]), - tf.slice(gpu_dst_inter_dst_code, [0,ae_dims_slice,0,0], [-1,ae_dims-ae_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 ) + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , lowest_dense_res, lowest_dense_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,ae_dims-inter_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 ) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) @@ -558,9 +567,9 @@ class AMPModel(ModelBase): gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code) gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code) - ae_dims_slice = tf.cast(ae_dims*self.morph_value_t[0], tf.int32) - gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, ae_dims_slice , lowest_dense_res, lowest_dense_res]), - tf.slice(gpu_dst_inter_dst_code, [0,ae_dims_slice,0,0], [-1,ae_dims-ae_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 ) + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , lowest_dense_res, lowest_dense_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,ae_dims-inter_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 ) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)