diff --git a/doc/manual_en_google_translated.docx b/doc/manual_en_google_translated.docx index 430b640..6fb1d5e 100644 Binary files a/doc/manual_en_google_translated.docx and b/doc/manual_en_google_translated.docx differ diff --git a/doc/manual_en_google_translated.pdf b/doc/manual_en_google_translated.pdf index e5350d0..09ff703 100644 Binary files a/doc/manual_en_google_translated.pdf and b/doc/manual_en_google_translated.pdf differ diff --git a/doc/manual_ru.pdf b/doc/manual_ru.pdf index 0ad79fd..1b35473 100644 Binary files a/doc/manual_ru.pdf and b/doc/manual_ru.pdf differ diff --git a/doc/manual_ru_source.docx b/doc/manual_ru_source.docx index 64a9561..4f96ba6 100644 Binary files a/doc/manual_ru_source.docx and b/doc/manual_ru_source.docx differ diff --git a/main.py b/main.py index 0f1a68a..c480a9e 100644 --- a/main.py +++ b/main.py @@ -120,6 +120,7 @@ if __name__ == "__main__": os_utils.set_process_lowest_prio() args = {'training_data_src_dir' : arguments.training_data_src_dir, 'training_data_dst_dir' : arguments.training_data_dst_dir, + 'pretraining_data_dir' : arguments.pretraining_data_dir, 'model_path' : arguments.model_dir, 'model_name' : arguments.model_name, 'no_preview' : arguments.no_preview, @@ -133,8 +134,9 @@ if __name__ == "__main__": Trainer.main(args, device_args) p = subparsers.add_parser( "train", help="Trainer") - p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of src-set.") - p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of dst-set.") + p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.") + p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.") + p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", help="Optional dir of extracted faceset that will be used in pretraining mode.") p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.") p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model") p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.") diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index f70b641..bde87e3 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -19,6 +19,7 @@ def trainerThread (s2c, c2s, args, device_args): training_data_src_path = Path( args.get('training_data_src_dir', '') ) training_data_dst_path = Path( args.get('training_data_dst_dir', '') ) + pretraining_data_path = Path( args.get('pretraining_data_dir', '') ) model_path = Path( args.get('model_path', '') ) model_name = args.get('model_name', '') save_interval_min = 15 @@ -40,6 +41,7 @@ def trainerThread (s2c, c2s, args, device_args): model_path, training_data_src_path=training_data_src_path, training_data_dst_path=training_data_dst_path, + pretraining_data_path=pretraining_data_path, debug=debug, device_args=device_args) diff --git a/models/ModelBase.py b/models/ModelBase.py index 453a462..e1f25f5 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -20,9 +20,13 @@ You can implement your own model. Check examples. class ModelBase(object): - def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None, - ask_write_preview_history=True, ask_target_iter=True, ask_batch_size=True, ask_sort_by_yaw=True, - ask_random_flip=True, ask_src_scale_mod=True): + def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None, + ask_write_preview_history=True, + ask_target_iter=True, + ask_batch_size=True, + ask_sort_by_yaw=True, + ask_random_flip=True, + ask_src_scale_mod=True): device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1) device_args['cpu_only'] = device_args.get('cpu_only',False) @@ -46,7 +50,8 @@ class ModelBase(object): self.training_data_src_path = training_data_src_path self.training_data_dst_path = training_data_dst_path - + self.pretraining_data_path = pretraining_data_path + self.src_images_paths = None self.dst_images_paths = None self.src_yaw_images_paths = None @@ -85,7 +90,7 @@ class ModelBase(object): else: self.options['write_preview_history'] = self.options.get('write_preview_history', False) - if self.iter == 0 and self.options['write_preview_history'] and io.is_support_windows(): + if (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_support_windows(): choose_preview_history = io.input_bool("Choose image for the preview history? (y/n skip:%s) : " % (yn_str[False]) , False) else: choose_preview_history = False diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 09ac424..a077877 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -91,7 +91,12 @@ class SAEModel(ModelBase): self.options['pixel_loss'] = self.options.get('pixel_loss', False) self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power) - + + if is_first_run: + self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: both decoders.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.") + else: + self.options['pretrain'] = False + #override def onInitialize(self): exec(nnlib.import_all(), locals(), globals()) @@ -102,6 +107,10 @@ class SAEModel(ModelBase): ae_dims = self.options['ae_dims'] e_ch_dims = self.options['e_ch_dims'] d_ch_dims = self.options['d_ch_dims'] + self.pretrain = self.options['pretrain'] = self.options.get('pretrain', False) + if not self.pretrain: + self.options.pop('pretrain') + d_residual_blocks = True bgr_shape = (resolution, resolution, 3) mask_shape = (resolution, resolution, 1) @@ -123,16 +132,9 @@ class SAEModel(ModelBase): target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)] target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)] - padding = 'zero' - norm = '' - - if '-s' in self.options['archi']: - norm = 'bn' - - common_flow_kwargs = { 'padding': padding, - 'norm': norm, - 'act':'' } - + common_flow_kwargs = { 'padding': 'zero', + 'norm': 'norm', + 'act':'' } models_list = [] weights_to_load = [] if 'liae' in self.options['archi']: @@ -216,14 +218,17 @@ class SAEModel(ModelBase): pred_dst_dstm = self.decoder_dstm(warped_dst_code) pred_src_dstm = self.decoder_srcm(warped_dst_code) - if self.is_first_run() and self.options.get('ca_weights',False): - conv_weights_list = [] - for model in models_list: - for layer in model.layers: - if type(layer) == keras.layers.Conv2D: - conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights - CAInitializerMP ( conv_weights_list ) - + if self.is_first_run(): + if self.options.get('ca_weights',False): + conv_weights_list = [] + for model in models_list: + for layer in model.layers: + if type(layer) == keras.layers.Conv2D: + conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights + CAInitializerMP ( conv_weights_list ) + else: + self.load_weights_safe(weights_to_load) + pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ] if self.options['learn_mask']: @@ -259,7 +264,7 @@ class SAEModel(ModelBase): psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] - + if self.is_training_mode: self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) @@ -323,9 +328,8 @@ class SAEModel(ModelBase): else: self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] ) - self.load_weights_safe(weights_to_load)#, [ [self.src_dst_opt, 'src_dst_opt'], [self.src_dst_mask_opt, 'src_dst_mask_opt']]) + else: - self.load_weights_safe(weights_to_load) if self.options['learn_mask']: self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ]) else: @@ -339,17 +343,28 @@ class SAEModel(ModelBase): t = SampleProcessor.Types face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF - output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':resolution} ] - output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution // (2**i) } for i in range(ms_count)] + t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE + + output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ] + output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i) } for i in range(ms_count)] output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution': resolution // (2**i) } for i in range(ms_count)] + training_data_src_path = self.training_data_src_path + training_data_dst_path = self.training_data_dst_path + sort_by_yaw = self.sort_by_yaw + + if self.pretrain: + training_data_src_path = self.pretraining_data_path + training_data_dst_path = self.pretraining_data_path + sort_by_yaw = False + self.set_training_data_generators ([ - SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None, + SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ), output_sample_types=output_sample_types ), - SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), output_sample_types=output_sample_types ) ]) @@ -362,20 +377,30 @@ class SAEModel(ModelBase): ar = [] if 'liae' in self.options['archi']: ar += [[self.encoder, 'encoder.h5'], - [self.inter_B, 'inter_B.h5'], - [self.inter_AB, 'inter_AB.h5'], - [self.decoder, 'decoder.h5'] - ] + [self.inter_B, 'inter_B.h5'], + [self.decoder, 'decoder.h5'] + ] + + if not self.pretrain or self.iter == 0: + ar += [ [self.inter_AB, 'inter_AB.h5'], + ] + if self.options['learn_mask']: ar += [ [self.decoderm, 'decoderm.h5'] ] + elif 'df' in self.options['archi']: - ar += [[self.encoder, 'encoder.h5'], - [self.decoder_src, 'decoder_src.h5'], - [self.decoder_dst, 'decoder_dst.h5'] - ] - if self.options['learn_mask']: - ar += [ [self.decoder_srcm, 'decoder_srcm.h5'], - [self.decoder_dstm, 'decoder_dstm.h5'] ] + ar += [ [self.encoder, 'encoder.h5'], + ] + + if not self.pretrain or self.iter == 0: + ar += [ [self.decoder_src, 'decoder_src.h5'], + [self.decoder_dst, 'decoder_dst.h5'] + ] + + if self.options['learn_mask']: + if not self.pretrain or self.iter == 0: + ar += [ [self.decoder_srcm, 'decoder_srcm.h5'], + [self.decoder_dstm, 'decoder_dstm.h5'] ] self.save_weights_safe(ar)