diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 46d065a..0fbcddb 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -28,7 +28,9 @@ class AMPModel(ModelBase): default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5) default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) - default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + default_dst_denoise = self.options['rtm_dst_denoise'] = self.load_or_def_option('rtm_dst_denoise', False) + default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) @@ -73,6 +75,9 @@ class AMPModel(ModelBase): if self.is_first_run() or ask_override: self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + self.options['rtm_dst_denoise'] = io.input_bool ("Denoise RTM DST faceset.", default_dst_denoise, help_message='Used in RTM(ReadyToMerge) training with RTM DST faceset. Removes high frequency noise keeping edges. Result is better face syncronization with any face. Can be enabled at any time.') + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) @@ -120,7 +125,10 @@ class AMPModel(ModelBase): morph_factor = self.options['morph_factor'] gan_power = self.gan_power = self.options['gan_power'] random_warp = self.options['random_warp'] - + + blur_out_mask = self.options['blur_out_mask'] + rtm_dst_denoise = self.options['rtm_dst_denoise'] + ct_mode = self.options['ct_mode'] if ct_mode == 'none': ct_mode = None @@ -128,7 +136,7 @@ class AMPModel(ModelBase): use_fp16 = False if self.is_exporting: use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') - + conv_dtype = tf.float16 if use_fp16 else tf.float32 class Downscale(nn.ModelBase): @@ -289,11 +297,11 @@ class AMPModel(ModelBase): # Initialize optimizers clipnorm = 1.0 if self.options['clipgrad'] else 0.0 lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0 - - self.all_weights = self.encoder.get_weights() + self.decoder.get_weights() + + self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') - self.src_dst_opt.initialize_variables (self.all_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu) self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] if gan_power != 0: @@ -320,7 +328,13 @@ class AMPModel(ModelBase): gpu_src_losses = [] gpu_dst_losses = [] gpu_G_loss_gradients = [] - gpu_GAN_loss_grads = [] + gpu_GAN_loss_gradients = [] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) for gpu_id in range(gpu_count): with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): @@ -342,39 +356,63 @@ class AMPModel(ModelBase): gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) - - inter_rnd_binomial = nn.random_binomial( [bs_per_gpu, gpu_src_inter_src_code.shape.as_list()[1], 1,1] , p=morph_factor) + + inter_dims_bin = int(inter_dims*morph_factor) + inter_rnd_binomial = tf.stack([tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), + tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 ) for _ in range(bs_per_gpu)], 0) + + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) + gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) gpu_dst_code = gpu_dst_inter_dst_code inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) - gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), - tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) - + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) - + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) - gpu_target_srcm_blur = tf.clip_by_value( nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ), 0, 0.5) * 2 - gpu_target_dstm_blur = tf.clip_by_value(nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ), 0, 0.5) * 2 - + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) + gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) + + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur + + if blur_out_mask: + #gpu_target_src = gpu_target_src*gpu_target_srcm_blur + nn.gaussian_blur(gpu_target_src, resolution // 32)*gpu_target_srcm_anti_blur + #gpu_target_dst = gpu_target_dst*gpu_target_dstm_blur + nn.gaussian_blur(gpu_target_dst, resolution // 32)*gpu_target_dstm_anti_blur + bg_blur_div = 128 + + gpu_target_src = gpu_target_src*gpu_target_srcm + \ + tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div), + (1-nn.gaussian_blur(gpu_target_srcm, resolution / bg_blur_div) ) ) * gpu_target_srcm_anti + gpu_target_dst = gpu_target_dst*gpu_target_dstm + \ + tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, resolution / bg_blur_div), + (1-nn.gaussian_blur(gpu_target_dstm, resolution / bg_blur_div)) ) * gpu_target_dstm_anti + + gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur - + gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur - + # Structural loss gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) @@ -393,20 +431,14 @@ class AMPModel(ModelBase): gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) - # dst-dst background weak loss - gpu_dst_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) - gpu_dst_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) - gpu_src_losses += [gpu_src_loss] gpu_dst_losses += [gpu_dst_loss] gpu_G_loss = gpu_src_loss + gpu_dst_loss - - def DLossOnes(logits): - return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) - - def DLossZeros(logits): - return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) - + # dst-dst background weak loss + gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) + + if gan_power != 0: gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) @@ -419,17 +451,17 @@ class AMPModel(ModelBase): DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) ) * (1.0 / 8) - gpu_GAN_loss_grads += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] + gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) ) * gan_power - + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) - gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.encoder.get_weights() + self.decoder.get_weights() ) ] + gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] # Average losses and gradients, and create optimizer update ops with tf.device(f'/CPU:0'): @@ -444,9 +476,9 @@ class AMPModel(ModelBase): src_loss = tf.concat(gpu_src_losses, 0) dst_loss = tf.concat(gpu_dst_losses, 0) train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) - + if gan_power != 0: - GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_grads) ) + GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) # Initializing training and view functions def train(warped_src, target_src, target_srcm, target_srcm_em, \ @@ -520,11 +552,13 @@ class AMPModel(ModelBase): random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain - cpu_count = min(multiprocessing.cpu_count(), 8) + cpu_count = multiprocessing.cpu_count() src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if ct_mode is not None: src_generators_count = int(src_generators_count * 1.5) + + self.set_training_data_generators ([ SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), @@ -541,6 +575,7 @@ class AMPModel(ModelBase): sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip), output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'denoise_filter' : rtm_dst_denoise, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, ], @@ -616,13 +651,13 @@ class AMPModel(ModelBase): bs = self.get_batch_size() ( (warped_src, target_src, target_srcm, target_srcm_em), \ - (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + (warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = self.generate_next_samples() - src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em) for i in range(bs): self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) ) - self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) ) + self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dst_train[i], target_dstm[i], target_dstm_em[i]) ) if len(self.last_src_samples_loss) >= bs*16: src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True) @@ -633,22 +668,23 @@ class AMPModel(ModelBase): target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] ) target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) - target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) - target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) + target_dst_train = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) + target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) + target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] ) - src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em) + src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst_train, target_dstm, target_dstm_em) self.last_src_samples_loss = [] self.last_dst_samples_loss = [] if self.gan_power != 0: - self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst_train, target_dstm, target_dstm_em) return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override def onGetPreview(self, samples, for_history=False): ( (warped_src, target_src, target_srcm, target_srcm_em), - (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + (warped_dst, target_dst, target_dst_train, target_dstm, target_dstm_em) ) = samples S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ]