SAE: added option 'Pretrain the model?'

Pretrain the model with a large set of varied faces. This technique may help when the src/dst data have very different face shapes or lighting conditions, at the cost of the result looking more morphed. To reduce the morph effect, some model files are saved in their initialized state and not updated during pretraining: LIAE: inter_AB.h5; DF: both decoder .h5 files. The longer you pretrain the model, the more morphed the face will look. Afterwards, save and run the training again.
iperov 2019-05-01 19:55:27 +04:00
parent 659aa5705a
commit 2a8dd788dc
8 changed files with 78 additions and 44 deletions
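For orientation, a hedged usage sketch (not part of the commit): the new --pretraining-data-dir flag maps to the 'pretraining_data_dir' key that main.py now forwards to the trainer. The import path follows the repo layout; the directory paths and device_args contents below are illustrative assumptions.

# Hypothetical invocation mirroring the wiring this commit adds to main.py.
from mainscripts import Trainer

args = {'training_data_src_dir' : 'workspace/data_src/aligned',
        'training_data_dst_dir' : 'workspace/data_dst/aligned',
        'pretraining_data_dir'  : 'workspace/pretrain_faces',   # new in this commit
        'model_path'            : 'workspace/model',
        'model_name'            : 'SAE',
        'no_preview'            : True}

Trainer.main(args, device_args={'cpu_only': False})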

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

main.py

@@ -120,6 +120,7 @@ if __name__ == "__main__":
         os_utils.set_process_lowest_prio()
         args = {'training_data_src_dir' : arguments.training_data_src_dir,
                 'training_data_dst_dir' : arguments.training_data_dst_dir,
+                'pretraining_data_dir' : arguments.pretraining_data_dir,
                 'model_path' : arguments.model_dir,
                 'model_name' : arguments.model_name,
                 'no_preview' : arguments.no_preview,
@@ -133,8 +134,9 @@ if __name__ == "__main__":
         Trainer.main(args, device_args)

     p = subparsers.add_parser( "train", help="Trainer")
-    p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of src-set.")
-    p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of dst-set.")
+    p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.")
+    p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.")
+    p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", help="Optional dir of extracted faceset that will be used in pretraining mode.")
     p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
     p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
     p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.")

mainscripts/Trainer.py

@@ -19,6 +19,7 @@ def trainerThread (s2c, c2s, args, device_args):
     training_data_src_path = Path( args.get('training_data_src_dir', '') )
     training_data_dst_path = Path( args.get('training_data_dst_dir', '') )
+    pretraining_data_path = Path( args.get('pretraining_data_dir', '') )
     model_path = Path( args.get('model_path', '') )
     model_name = args.get('model_name', '')
     save_interval_min = 15
@@ -40,6 +41,7 @@ def trainerThread (s2c, c2s, args, device_args):
                         model_path,
                         training_data_src_path=training_data_src_path,
                         training_data_dst_path=training_data_dst_path,
+                        pretraining_data_path=pretraining_data_path,
                         debug=debug,
                         device_args=device_args)

models/ModelBase.py

@@ -20,9 +20,13 @@ You can implement your own model. Check examples.
 class ModelBase(object):

-    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None,
-                 ask_write_preview_history=True, ask_target_iter=True, ask_batch_size=True, ask_sort_by_yaw=True,
-                 ask_random_flip=True, ask_src_scale_mod=True):
+    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
+                 ask_write_preview_history=True,
+                 ask_target_iter=True,
+                 ask_batch_size=True,
+                 ask_sort_by_yaw=True,
+                 ask_random_flip=True,
+                 ask_src_scale_mod=True):

         device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
         device_args['cpu_only'] = device_args.get('cpu_only',False)
@@ -46,7 +50,8 @@ class ModelBase(object):
         self.training_data_src_path = training_data_src_path
         self.training_data_dst_path = training_data_dst_path
+        self.pretraining_data_path = pretraining_data_path

         self.src_images_paths = None
         self.dst_images_paths = None
         self.src_yaw_images_paths = None
@@ -85,7 +90,7 @@ class ModelBase(object):
         else:
             self.options['write_preview_history'] = self.options.get('write_preview_history', False)

-        if self.iter == 0 and self.options['write_preview_history'] and io.is_support_windows():
+        if (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_support_windows():
             choose_preview_history = io.input_bool("Choose image for the preview history? (y/n skip:%s) : " % (yn_str[False]) , False)
         else:
             choose_preview_history = False

models/Model_SAE/Model.py

@@ -91,7 +91,12 @@ class SAEModel(ModelBase):
             self.options['pixel_loss'] = self.options.get('pixel_loss', False)
             self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
+
+        if is_first_run:
+            self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with a large set of varied faces. This technique may help when the src/dst facesets have very different face shapes or lighting conditions, at the cost of the result looking more morphed. To reduce the morph effect, some model files are saved in their initialized state and not updated during pretraining: LIAE: inter_AB.h5; DF: both decoder .h5 files. The longer you pretrain the model, the more morphed the face will look. Afterwards, save and run the training again.")
+        else:
+            self.options['pretrain'] = False

     #override
     def onInitialize(self):
         exec(nnlib.import_all(), locals(), globals())
@@ -102,6 +107,10 @@ class SAEModel(ModelBase):
         ae_dims = self.options['ae_dims']
         e_ch_dims = self.options['e_ch_dims']
         d_ch_dims = self.options['d_ch_dims']
+        self.pretrain = self.options['pretrain'] = self.options.get('pretrain', False)
+        if not self.pretrain:
+            self.options.pop('pretrain')
+
         d_residual_blocks = True
         bgr_shape = (resolution, resolution, 3)
         mask_shape = (resolution, resolution, 1)
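A side note on the idiom in the hunk above: reading 'pretrain' with a default and then popping it when disabled keeps the flag out of the options dict that is persisted with the model, so normal models never store the key. A minimal self-contained sketch, with a bare dict standing in for self.options:

# Stand-in for self.options, the dict serialized alongside the model.
options = {}

# Read the flag with a default, as the hunk above does...
pretrain = options['pretrain'] = options.get('pretrain', False)

# ...then drop it again when disabled, so a non-pretrain model never
# persists the key and older saved models keep loading unchanged.
if not pretrain:
    options.pop('pretrain')

print(options)  # {} -> 'pretrain' is only persisted while pretraining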
@@ -123,16 +132,9 @@ class SAEModel(ModelBase):
         target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
         target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]

-        padding = 'zero'
-        norm = ''
-
-        if '-s' in self.options['archi']:
-            norm = 'bn'
-
-        common_flow_kwargs = { 'padding': padding,
-                               'norm': norm,
-                               'act':'' }
+        common_flow_kwargs = { 'padding': 'zero',
+                               'norm': 'norm',
+                               'act':'' }
         models_list = []
         weights_to_load = []
         if 'liae' in self.options['archi']:
@@ -216,14 +218,17 @@ class SAEModel(ModelBase):
                 pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                 pred_src_dstm = self.decoder_srcm(warped_dst_code)

-        if self.is_first_run() and self.options.get('ca_weights',False):
-            conv_weights_list = []
-            for model in models_list:
-                for layer in model.layers:
-                    if type(layer) == keras.layers.Conv2D:
-                        conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
-            CAInitializerMP ( conv_weights_list )
+        if self.is_first_run():
+            if self.options.get('ca_weights',False):
+                conv_weights_list = []
+                for model in models_list:
+                    for layer in model.layers:
+                        if type(layer) == keras.layers.Conv2D:
+                            conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
+                CAInitializerMP ( conv_weights_list )
+        else:
+            self.load_weights_safe(weights_to_load)

         pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]

         if self.options['learn_mask']:
@@ -259,7 +264,7 @@ class SAEModel(ModelBase):
             psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
             psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]

         if self.is_training_mode:
             self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
             self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
@@ -323,9 +328,8 @@ class SAEModel(ModelBase):
             else:
                 self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )
-            self.load_weights_safe(weights_to_load)#, [ [self.src_dst_opt, 'src_dst_opt'], [self.src_dst_mask_opt, 'src_dst_mask_opt']])
         else:
-            self.load_weights_safe(weights_to_load)
             if self.options['learn_mask']:
                 self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ])
             else:
@@ -339,17 +343,28 @@ class SAEModel(ModelBase):
             t = SampleProcessor.Types
             face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF

-            output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':resolution} ]
-            output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution': resolution // (2**i) } for i in range(ms_count)]
+            t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE
+
+            output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ]
+            output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i) } for i in range(ms_count)]
             output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution': resolution // (2**i) } for i in range(ms_count)]

+            training_data_src_path = self.training_data_src_path
+            training_data_dst_path = self.training_data_dst_path
+            sort_by_yaw = self.sort_by_yaw
+
+            if self.pretrain:
+                training_data_src_path = self.pretraining_data_path
+                training_data_dst_path = self.pretraining_data_path
+                sort_by_yaw = False
+
             self.set_training_data_generators ([
-                    SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
+                    SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
                                         debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
                                         output_sample_types=output_sample_types ),
-                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
+                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
                                         output_sample_types=output_sample_types )
                    ])
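The hunk above also swaps MODE_BGR for MODE_BGR_SHUFFLE while pretraining. The mode name suggests a random permutation of the three color channels as an extra augmentation; a minimal sketch of that idea, where NumPy and the function name are my assumptions rather than DFL code:

import numpy as np

def bgr_shuffle(face: np.ndarray) -> np.ndarray:
    # Randomly permute the B, G, R channels of an HWC face crop, so a model
    # pretrained on many identities cannot latch onto channel-order color cues.
    return face[..., np.random.permutation(3)]

face = np.random.rand(128, 128, 3).astype(np.float32)  # dummy aligned face
print(bgr_shuffle(face).shape)  # (128, 128, 3)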
@@ -362,20 +377,30 @@ class SAEModel(ModelBase):
         ar = []
         if 'liae' in self.options['archi']:
             ar += [[self.encoder, 'encoder.h5'],
                    [self.inter_B, 'inter_B.h5'],
-                   [self.inter_AB, 'inter_AB.h5'],
                    [self.decoder, 'decoder.h5']
                   ]
+
+            if not self.pretrain or self.iter == 0:
+                ar += [ [self.inter_AB, 'inter_AB.h5'],
+                      ]
+
             if self.options['learn_mask']:
                 ar += [ [self.decoderm, 'decoderm.h5'] ]
         elif 'df' in self.options['archi']:
-            ar += [[self.encoder, 'encoder.h5'],
-                   [self.decoder_src, 'decoder_src.h5'],
-                   [self.decoder_dst, 'decoder_dst.h5']
-                  ]
-            if self.options['learn_mask']:
-                ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
-                        [self.decoder_dstm, 'decoder_dstm.h5'] ]
+            ar += [ [self.encoder, 'encoder.h5'],
+                  ]
+
+            if not self.pretrain or self.iter == 0:
+                ar += [ [self.decoder_src, 'decoder_src.h5'],
+                        [self.decoder_dst, 'decoder_dst.h5']
+                      ]
+
+            if self.options['learn_mask']:
+                if not self.pretrain or self.iter == 0:
+                    ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
+                            [self.decoder_dstm, 'decoder_dstm.h5'] ]

         self.save_weights_safe(ar)
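Condensing the onSave rule above into a standalone sketch (file names follow the diff; the helper function itself is hypothetical): during pretraining, the morph-prone weights are written only at iteration 0, so they stay at their initialized state on disk.

def files_to_save(archi, pretrain, iteration, learn_mask=False):
    # Mirror of onSave's selection logic in this commit: while pretraining,
    # LIAE's inter_AB and DF's decoders are saved only at iteration 0.
    ar = ['encoder.h5']
    if archi == 'liae':
        ar += ['inter_B.h5', 'decoder.h5']
        if not pretrain or iteration == 0:
            ar += ['inter_AB.h5']
        if learn_mask:
            ar += ['decoderm.h5']
    elif archi == 'df':
        if not pretrain or iteration == 0:
            ar += ['decoder_src.h5', 'decoder_dst.h5']
            if learn_mask:
                ar += ['decoder_srcm.h5', 'decoder_dstm.h5']
    return ar

print(files_to_save('liae', pretrain=True, iteration=5000))
# ['encoder.h5', 'inter_B.h5', 'decoder.h5'] -> inter_AB.h5 is skipped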