From 0748b8d04351c65d4c6b49e5a9bd24258f476ce3 Mon Sep 17 00:00:00 2001
From: iperov
Date: Sat, 17 Jul 2021 21:58:57 +0400
Subject: [PATCH] AMP: removed eyes mouth prio option(default enabled), removed
 masked training option(default enabled).

---
 models/Model_AMP/Model.py | 372 ++++++++++++++------------------------
 1 file changed, 137 insertions(+), 235 deletions(-)

diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py
index 37789a3..0a5bc18 100644
--- a/models/Model_AMP/Model.py
+++ b/models/Model_AMP/Model.py
@@ -18,44 +18,26 @@ class AMPModel(ModelBase):
     def on_initialize_options(self):
         device_config = nn.getCurrentDeviceConfig()
 
-        lowest_vram = 2
-        if len(device_config.devices) != 0:
-            lowest_vram = device_config.devices.get_worst_device().total_mem_gb
-
-        if lowest_vram >= 4:
-            suggest_batch_size = 8
-        else:
-            suggest_batch_size = 4
-
-        yn_str = {True:'y',False:'n'}
-        min_res = 64
-        max_res = 640
-
-        #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
         default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
         default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
 
         default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
-
+
         inter_dims = self.load_or_def_option('inter_dims', None)
         if inter_dims is None:
             inter_dims = self.options['ae_dims']
-        default_inter_dims = self.options['inter_dims'] = inter_dims
-
+        default_inter_dims = self.options['inter_dims'] = inter_dims
+
         default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
         default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None)
         default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None)
         default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5)
-        default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True)
-        default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', True)
         default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False)
         default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
         default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
         default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
 
-        #default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
-
         ask_override = self.ask_override()
 
         if self.is_first_run() or ask_override:
@@ -64,12 +46,11 @@ class AMPModel(ModelBase):
             self.ask_autobackup_hour()
             self.ask_write_preview_history()
             self.ask_target_iter()
             self.ask_random_src_flip()
             self.ask_random_dst_flip()
-            self.ask_batch_size(suggest_batch_size)
-            #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters. Lr-dropout should be enabled.')
-
+            self.ask_batch_size(8)
+
         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
-            resolution = np.clip ( (resolution // 32) * 32, min_res, max_res)
+            resolution = np.clip ( (resolution // 32) * 32, 64, 640)
             self.options['resolution'] = resolution
             self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower()
@@ -93,15 +74,10 @@ class AMPModel(ModelBase):
             d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 )
             self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2
 
-            morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="The smaller the value, the more src-like facial expressions will appear. The larger the value, the less space there is to train a large dst faceset in the neural network. Typical fine value is 0.33"), 0.1, 0.5 )
+            morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 )
             self.options['morph_factor'] = morph_factor
-
         if self.is_first_run() or ask_override:
-            if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head':
-                self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.")
-
-            self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.')
 
             self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
 
             default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
@@ -125,10 +101,7 @@ class AMPModel(ModelBase):
             self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.")
             self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
 
-            #self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, uniform_yaw=Y")
-
         self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims'])
-        #self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
 
     #override
     def on_initialize(self):
@@ -138,30 +111,38 @@ class AMPModel(ModelBase):
         nn.initialize(data_format=self.model_data_format)
         tf = nn.tf
 
-        self.resolution = resolution = self.options['resolution']
-        input_ch=3
-        ae_dims = self.ae_dims = self.options['ae_dims']
-        inter_dims = self.inter_dims = self.options['inter_dims']
-        e_dims = self.options['e_dims']
-        d_dims = self.options['d_dims']
+        resolution = self.resolution = self.options['resolution']
+        e_dims = self.options['e_dims']
+        ae_dims = self.options['ae_dims']
+        inter_dims = self.inter_dims = self.options['inter_dims']
+        inter_res = self.inter_res = resolution // 32
+        d_dims = self.options['d_dims']
         d_mask_dims = self.options['d_mask_dims']
-        inter_res = self.inter_res = resolution // 32
-        use_fp16 = False# self.options['use_fp16']
-
-        ae_use_fp16 = use_fp16
-        ae_conv_dtype = tf.float16 if use_fp16 else tf.float32
-
+        face_type = self.face_type = {'f' : FaceType.FULL,
+                                      'wf' : FaceType.WHOLE_FACE,
+                                      'head' : FaceType.HEAD}[ self.options['face_type'] ]
+        morph_factor = self.options['morph_factor']
+        gan_power = self.gan_power = self.options['gan_power']
+        random_warp = self.options['random_warp']
+
+        ct_mode = self.options['ct_mode']
+        if ct_mode == 'none':
+            ct_mode = None
+
+        use_fp16 = self.is_exporting
+        conv_dtype = tf.float16 if use_fp16 else tf.float32
+
         class Downscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=5 ):
-                self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=ae_conv_dtype)
+                self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
 
             def forward(self, x):
                 return tf.nn.leaky_relu(self.conv1(x), 0.1)
 
         class Upscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
+                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
 
             def forward(self, x):
                 x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
@@ -169,8 +150,8 @@ class AMPModel(ModelBase):
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
-                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
 
             def forward(self, inp):
                 x = self.conv1(inp)
@@ -191,7 +172,7 @@ class AMPModel(ModelBase):
                 self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )
 
             def forward(self, x):
-                if ae_use_fp16:
+                if use_fp16:
                     x = tf.cast(x, tf.float16)
                 x = self.down1(x)
                 x = self.res1(x)
@@ -200,7 +181,7 @@ class AMPModel(ModelBase):
                 x = self.down4(x)
                 x = self.down5(x)
                 x = self.res5(x)
-                if ae_use_fp16:
+                if use_fp16:
                     x = tf.cast(x, tf.float32)
                 x = nn.pixel_norm(nn.flatten(x), axes=-1)
                 x = self.dense1(x)
@@ -235,15 +216,15 @@ class AMPModel(ModelBase):
                 self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
                 self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
                 self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
-                self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
+                self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
 
-                self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
-                self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
-                self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
-                self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
+                self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
+                self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
 
             def forward(self, z):
-                if ae_use_fp16:
+                if use_fp16:
                     z = tf.cast(z, tf.float16)
 
                 x = self.upscale0(z)
@@ -259,60 +240,22 @@ class AMPModel(ModelBase):
                                                                  self.out_conv1(x),
                                                                  self.out_conv2(x),
                                                                  self.out_conv3(x)), nn.conv2d_ch_axis), 2) )
-
                 m = self.upscalem0(z)
                 m = self.upscalem1(m)
                 m = self.upscalem2(m)
                 m = self.upscalem3(m)
                 m = self.upscalem4(m)
                 m = tf.nn.sigmoid(self.out_convm(m))
-
-                if ae_use_fp16:
+
+                if use_fp16:
                     x = tf.cast(x, tf.float32)
                     m = tf.cast(m, tf.float32)
                 return x, m
 
-        self.face_type = {'f' : FaceType.FULL,
-                          'wf' : FaceType.WHOLE_FACE,
-                          'head' : FaceType.HEAD}[ self.options['face_type'] ]
-
-        if 'eyes_prio' in self.options:
-            self.options.pop('eyes_prio')
-
-        eyes_mouth_prio = self.options['eyes_mouth_prio']
-
-
-        morph_factor = self.options['morph_factor']
-
-        gan_power = self.gan_power = self.options['gan_power']
-        random_warp = self.options['random_warp']
-        random_src_flip = self.random_src_flip
-        random_dst_flip = self.random_dst_flip
-
-        #pretrain = self.pretrain = self.options['pretrain']
-        #if self.pretrain_just_disabled:
-        #    self.set_iter(0)
-        #    self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power']
-        #    random_warp = False if self.pretrain else self.options['random_warp']
-        #    random_src_flip = self.random_src_flip if not self.pretrain else True
-        #    random_dst_flip = self.random_dst_flip if not self.pretrain else True
-
-        #    if self.pretrain:
-        #        self.options_show_override['gan_power'] = 0.0
-        #        self.options_show_override['random_warp'] = False
-        #        self.options_show_override['lr_dropout'] = 'n'
-        #        self.options_show_override['uniform_yaw'] = True
-
-        masked_training = self.options['masked_training']
-
-        ct_mode = self.options['ct_mode']
-        if ct_mode == 'none':
-            ct_mode = None
-
         models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu']
         models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0'
         optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
 
+        input_ch=3
         bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
         mask_shape = nn.get4Dshape(resolution,resolution,1)
         self.model_filename_list = []
@@ -333,7 +276,6 @@ class AMPModel(ModelBase):
         self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t')
 
         # Initializing model classes
-
         with tf.device (models_opt_device):
             self.encoder = Encoder(name='encoder')
             self.inter_src = Inter(name='inter_src')
@@ -346,30 +288,21 @@ class AMPModel(ModelBase):
                                       [self.decoder , 'decoder.npy'] ]
 
         if self.is_training:
-            if gan_power != 0:
-                self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
-                self.model_filename_list += [ [self.GAN, 'GAN.npy'] ]
-
             # Initialize optimizers
-            lr=5e-5
-            lr_dropout = 0.3
-
             clipnorm = 1.0 if self.options['clipgrad'] else 0.0
             self.all_weights = self.encoder.get_weights() + self.decoder.get_weights()
-            #if pretrain:
-            #    self.trainable_weights = self.encoder.get_weights() + self.decoder.get_weights()
-            #else:
-            self.trainable_weights = self.encoder.get_weights() + self.decoder.get_weights()
 
-            self.src_dst_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
+            self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=0.3, clipnorm=clipnorm, name='src_dst_opt')
             self.src_dst_opt.initialize_variables (self.all_weights, vars_on_cpu=optimizer_vars_on_cpu)
             self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
 
             if gan_power != 0:
-                self.GAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
-                self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)#+self.D_src_x2.get_weights()
-                self.model_filename_list += [ (self.GAN_opt, 'GAN_opt.npy') ]
+                self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
+                self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=0.3, clipnorm=clipnorm, name='GAN_opt')
+                self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
+                self.model_filename_list += [ [self.GAN, 'GAN.npy'],
+                                              [self.GAN_opt, 'GAN_opt.npy'] ]
 
         if self.is_training:
             # Adjust batch size for multiple GPU
@@ -387,8 +320,8 @@ class AMPModel(ModelBase):
 
             gpu_src_losses = []
             gpu_dst_losses = []
-            gpu_G_loss_gvs = []
-            gpu_D_src_dst_loss_gvs = []
+            gpu_G_loss_gradients = []
+            gpu_GAN_loss_grads = []
 
             for gpu_id in range(gpu_count):
                 with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
@@ -408,16 +341,8 @@ class AMPModel(ModelBase):
                     gpu_src_code = self.encoder (gpu_warped_src)
                     gpu_dst_code = self.encoder (gpu_warped_dst)
 
-                    # if pretrain:
-                    #     gpu_src_inter_src_code = self.inter_src (gpu_src_code)
-                    #     gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
-                    #     gpu_src_code = gpu_src_inter_src_code * nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
-                    #     gpu_dst_code = gpu_src_dst_code = gpu_dst_inter_dst_code * nn.random_binomial( [bs_per_gpu, gpu_dst_inter_dst_code.shape.as_list()[1], 1,1] , p=morph_factor)
-                    # else:
-                    gpu_src_inter_src_code = self.inter_src (gpu_src_code)
-                    gpu_src_inter_dst_code = self.inter_dst (gpu_src_code)
-                    gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
-                    gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
+                    gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code)
+                    gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code)
 
                     inter_rnd_binomial = nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor)
                     gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial)
@@ -431,62 +356,50 @@ class AMPModel(ModelBase):
                     gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
                     gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
 
-                    gpu_pred_src_src_list.append(gpu_pred_src_src)
-                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst)
-                    gpu_pred_src_dst_list.append(gpu_pred_src_dst)
+                    gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
+                    gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
+                    gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
 
-                    gpu_pred_src_srcm_list.append(gpu_pred_src_srcm)
-                    gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm)
-                    gpu_pred_src_dstm_list.append(gpu_pred_src_dstm)
+                    gpu_target_srcm_blur = tf.clip_by_value( nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ), 0, 0.5) * 2
+                    gpu_target_dstm_blur = tf.clip_by_value(nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ), 0, 0.5) * 2
 
-                    gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) )
-                    gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2
+                    gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur
+                    gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur
 
-                    gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) )
-                    gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2
+                    gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur
+                    gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur
+                    gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur
+                    gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur
 
-                    gpu_target_dst_anti_masked = gpu_target_dst*(1.0-gpu_target_dstm_blur)
-                    gpu_target_src_anti_masked = gpu_target_src*(1.0-gpu_target_srcm_blur)
-                    gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src
-                    gpu_target_dst_masked_opt = gpu_target_dst*gpu_target_dstm_blur if masked_training else gpu_target_dst
+                    gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur
+                    gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur
+                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur
+                    gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur
 
-                    gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src
-                    gpu_pred_src_src_anti_masked = gpu_pred_src_src*(1.0-gpu_target_srcm_blur)
-                    gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst
-                    gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*(1.0-gpu_target_dstm_blur)
+                    # Structural loss
+                    gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
+                    gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
+                    gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
+                    gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
 
-                    if resolution < 256:
-                        gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
-                    else:
-                        gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
-                        gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
-                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
-
-                    if eyes_mouth_prio:
-                        gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3])
-
-                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
-                    gpu_dst_loss += 0.1*tf.reduce_mean(tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
-
-                    gpu_dst_losses += [gpu_dst_loss]
+                    # Pixel loss
+                    gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3])
+                    gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3])
 
-                    #if not pretrain:
-                    if resolution < 256:
-                        gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
-                    else:
-                        gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
-                        gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
-                    gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
-
-                    if eyes_mouth_prio:
-                        gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3])
+                    # Eyes+mouth prio loss
+                    gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3])
+                    gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3])
 
+                    # Mask loss
                     gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] )
-                    #else:
-                    #    gpu_src_loss = gpu_dst_loss
+                    gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )
+
+                    # dst-dst background weak loss
+                    gpu_dst_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] )
+                    gpu_dst_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked)
 
                     gpu_src_losses += [gpu_src_loss]
-
-                    #if pretrain:
-                    #    gpu_G_loss = gpu_dst_loss
-                    #else:
+                    gpu_dst_losses += [gpu_dst_loss]
 
                     gpu_G_loss = gpu_src_loss + gpu_dst_loss
 
                     def DLossOnes(logits):
@@ -496,30 +409,28 @@ class AMPModel(ModelBase):
                         return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3])
 
                     if gan_power != 0:
-                        gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked_opt)
-                        gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked_opt)
-                        gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked_opt)
-                        gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked_opt)
+                        gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked)
+                        gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked)
+                        gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked)
+                        gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked)
 
-                        gpu_D_src_dst_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \
-                                              DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
-                                              DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \
-                                              DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
-                                             ) * ( 1.0 / 8)
+                        gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \
+                                        DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \
+                                        DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \
+                                        DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2)
+                                       ) * (1.0 / 8)
 
-                        gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.GAN.get_weights() ) ]
+                        gpu_GAN_loss_grads += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ]
 
                         gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \
                                        DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2)
                                       ) * gan_power
 
-                        if masked_training:
-                            # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
-                            gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
-                            gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
-
-                    gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.trainable_weights ) ]
+                        # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan
+                        gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src)
+                        gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] )
+
+                    gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.encoder.get_weights() + self.decoder.get_weights() ) ]
 
             # Average losses and gradients, and create optimizer update ops
             with tf.device(f'/CPU:0'):
@@ -533,15 +444,15 @@ class AMPModel(ModelBase):
             with tf.device (models_opt_device):
                 src_loss = tf.concat(gpu_src_losses, 0)
                 dst_loss = tf.concat(gpu_dst_losses, 0)
-                src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs))
+                train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients))
 
                 if gan_power != 0:
-                    src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) )
+                    GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_grads) )
 
             # Initializing training and view functions
-            def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
-                              warped_dst, target_dst, target_dstm, target_dstm_em, ):
-                s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op],
+            def train(warped_src, target_src, target_srcm, target_srcm_em, \
+                      warped_dst, target_dst, target_dstm, target_dstm_em, ):
+                s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op],
                                           feed_dict={self.warped_src :warped_src,
                                                      self.target_src :target_src,
                                                      self.target_srcm:target_srcm,
@@ -552,20 +463,20 @@ class AMPModel(ModelBase):
                                                      self.target_dstm_em:target_dstm_em,
                                                      })
                 return s, d
-            self.src_dst_train = src_dst_train
+            self.train = train
 
             if gan_power != 0:
-                def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \
-                                    warped_dst, target_dst, target_dstm, target_dstm_em, ):
-                    nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
-                                                                           self.target_src :target_src,
-                                                                           self.target_srcm:target_srcm,
-                                                                           self.target_srcm_em:target_srcm_em,
-                                                                           self.warped_dst :warped_dst,
-                                                                           self.target_dst :target_dst,
-                                                                           self.target_dstm:target_dstm,
-                                                                           self.target_dstm_em:target_dstm_em})
-                self.D_src_dst_train = D_src_dst_train
+                def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \
+                              warped_dst, target_dst, target_dstm, target_dstm_em, ):
+                    nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src,
+                                                               self.target_src :target_src,
+                                                               self.target_srcm:target_srcm,
+                                                               self.target_srcm_em:target_srcm_em,
+                                                               self.warped_dst :warped_dst,
+                                                               self.target_dst :target_dst,
+                                                               self.target_dstm:target_dstm,
+                                                               self.target_dstm_em:target_dstm_em})
+                self.GAN_train = GAN_train
 
             def AE_view(warped_src, warped_dst, morph_value):
                 return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
@@ -576,8 +487,8 @@ class AMPModel(ModelBase):
             #Initializing merge function
             with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'):
                 gpu_dst_code = self.encoder (self.warped_dst)
-                gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code)
-                gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code)
+                gpu_dst_inter_src_code = self.inter_src (gpu_dst_code)
+                gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code)
 
                 inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
                 gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]),
@@ -593,16 +504,10 @@ class AMPModel(ModelBase):
 
         # Loading/initializing all models/optimizers weights
         for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
-            # if self.pretrain_just_disabled:
-            #     do_init = False
-            #     if model == self.inter_src or model == self.inter_dst:
-            #         do_init = True
-            # else:
             do_init = self.is_first_run()
             if self.is_training and gan_power != 0 and model == self.GAN:
                 if self.gan_model_changed:
                     do_init = True
-
             if not do_init:
                 do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
             if do_init:
@@ -614,7 +519,7 @@ class AMPModel(ModelBase):
             training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path()
             training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path()
 
-            random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
+            random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain
 
             cpu_count = min(multiprocessing.cpu_count(), 8)
             src_generators_count = cpu_count // 2
@@ -624,21 +529,21 @@ class AMPModel(ModelBase):
 
             self.set_training_data_generators ([
                     SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                        sample_process_options=SampleProcessor.Options(random_flip=random_src_flip),
-                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                        sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip),
+                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                               ],
                         uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
                         generators_count=src_generators_count ),
 
                     SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                        sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip),
-                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                        sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip),
+                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution},
                                               ],
                         uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain,
                         generators_count=dst_generators_count )
@@ -646,15 +551,12 @@ class AMPModel(ModelBase):
             self.last_src_samples_loss = []
             self.last_dst_samples_loss = []
 
-            #if self.pretrain_just_disabled:
-            #    self.update_sample_for_preview(force_new=True)
-
     def export_dfm (self):
         output_path=self.get_strpath_storage_for_file('model.dfm')
-
+
         io.log_info(f'Dumping .dfm to {output_path}')
-
+
         tf = nn.tf
         with tf.device (nn.tf_default_device_name):
             warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
@@ -679,13 +581,13 @@ class AMPModel(ModelBase):
             tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
             tf.identity(gpu_pred_src_dst, name='out_celeb_face')
             tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
-
+
             output_graph_def = tf.graph_util.convert_variables_to_constants(
-                nn.tf_sess,
-                tf.get_default_graph().as_graph_def(),
+                nn.tf_sess,
+                tf.get_default_graph().as_graph_def(),
                 ['out_face_mask','out_celeb_face','out_celeb_face_mask']
-            )
-
+            )
+
         import tf2onnx
         with tf.device("/CPU:0"):
             model_proto, _ = tf2onnx.convert._convert_common(
@@ -717,7 +619,7 @@ class AMPModel(ModelBase):
         ( (warped_src, target_src, target_srcm, target_srcm_em), \
           (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
 
-        src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
+        src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
         for i in range(bs):
             self.last_src_samples_loss.append ( (src_loss[i], warped_src[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
@@ -737,12 +639,12 @@ class AMPModel(ModelBase):
             target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
             target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] )
 
-            src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
+            src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
             self.last_src_samples_loss = []
             self.last_dst_samples_loss = []
 
         if self.gan_power != 0:
-            self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
+            self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
         return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
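
Note on the morph mechanism this patch consolidates (illustration only, not part of the patch): during training, each sample's latent is a per-channel random mix of its inter_src and inter_dst codes (nn.random_binomial with p=morph_factor); at merge time, the first inter_dims*morph_value channels are taken from the inter_src code and the rest from the inter_dst code (the tf.slice/tf.concat pair around inter_dims_slice). A minimal standalone NumPy sketch of both paths; shapes and names are illustrative assumptions, not the project's nn API:

    # Illustrative NumPy sketch of AMP's inter-code morphing (hypothetical
    # shapes, NCHW layout; the real code uses the repo's nn.* TF helpers).
    import numpy as np

    rng = np.random.default_rng(0)
    bs, inter_dims, inter_res = 4, 256, 7        # assumed sizes
    morph_factor = 0.5

    inter_src_code = rng.normal(size=(bs, inter_dims, inter_res, inter_res))
    inter_dst_code = rng.normal(size=(bs, inter_dims, inter_res, inter_res))

    # Training path: a per-sample, per-channel binomial mask picks each channel
    # from the src-inter or dst-inter code (mirrors nn.random_binomial(..., p=morph_factor)).
    mask = rng.binomial(1, morph_factor, size=(bs, inter_dims, 1, 1)).astype(float)
    train_code = inter_src_code * mask + inter_dst_code * (1.0 - mask)

    # Merge path: morph_value in [0..1] selects a contiguous channel slice
    # (mirrors the tf.slice/tf.concat pair built around inter_dims_slice).
    def morph(src_code, dst_code, morph_value):
        n = int(inter_dims * morph_value)        # inter_dims_slice
        return np.concatenate((src_code[:, :n], dst_code[:, n:]), axis=1)

    merged_code = morph(inter_src_code, inter_dst_code, 0.33)
    assert merged_code.shape == train_code.shape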
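The loss hunk hard-wires what used to be optional: masked training and the 300-weighted eyes/mouth priority term now always apply. A pure-NumPy stand-in for the per-sample weighting follows; the nn.dssim and nn.gaussian_blur operators are deliberately omitted (a binary mask stands in for the blurred one), so this shows only the shape of the computation, not the real operators:

    # Stand-in for the always-on loss mix after this patch (DSSIM terms and
    # gaussian blur elided); all sizes are assumptions for illustration.
    import numpy as np

    def mse(a, b):
        return np.mean(np.square(a - b), axis=(1, 2, 3))

    def l1(a, b):
        return np.mean(np.abs(a - b), axis=(1, 2, 3))

    rng = np.random.default_rng(1)
    bs, res = 2, 224
    target = rng.random((bs, 3, res, res))
    pred = rng.random((bs, 3, res, res))
    mask = (rng.random((bs, 1, res, res)) > 0.5).astype(float)     # face mask
    mask_em = (rng.random((bs, 1, res, res)) > 0.9).astype(float)  # eyes+mouth mask

    mask_blur = np.clip(mask, 0, 0.5) * 2   # the patch blurs first, then clips and doubles
    anti_blur = 1.0 - mask_blur

    loss  = 10 * mse(target * mask_blur, pred * mask_blur)    # pixel loss, masked
    loss += 300 * l1(target * mask_em, pred * mask_em)        # eyes+mouth priority
    loss += 0.1 * mse(target * anti_blur, pred * anti_blur)   # weak background term (dst side)
    print(loss.shape)  # (bs,)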
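Likewise, the GAN pieces renamed here (gpu_GAN_loss, GAN_train) follow the usual patch-discriminator recipe: the discriminator is pushed toward ones on real masked crops and zeros on generated ones, while the generator receives the "ones" loss on its own outputs, scaled by gan_power. A sigmoid cross-entropy sketch on raw logits, with assumed logit shapes:

    # Numerically stable sigmoid cross-entropy, matching the formula behind
    # tf.nn.sigmoid_cross_entropy_with_logits.
    import numpy as np

    def sce(labels, logits):
        return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

    def d_loss_ones(logits):   # cf. DLossOnes
        return np.mean(sce(np.ones_like(logits), logits), axis=(1, 2, 3))

    def d_loss_zeros(logits):  # cf. DLossZeros
        return np.mean(sce(np.zeros_like(logits), logits), axis=(1, 2, 3))

    rng = np.random.default_rng(2)
    real_logits = rng.normal(size=(4, 1, 28, 28))   # assumed patch-logit shape
    fake_logits = rng.normal(size=(4, 1, 28, 28))
    gan_power = 0.1

    D_loss = (d_loss_ones(real_logits) + d_loss_zeros(fake_logits)) * 0.5
    G_loss = d_loss_ones(fake_logits) * gan_power   # generator wants D to say "real"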