mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-08 05:51:40 -07:00
SAEHD: lr_dropout now as an option, and disabled by default
parent 951942821d
commit 021bb6d128

1 changed file with 18 additions and 14 deletions
@@ -65,6 +65,9 @@ class SAEHDModel(ModelBase):
         default_bg_style_power = self.options.get('bg_style_power', 0.0)
 
         if is_first_run or ask_override:
+            default_lr_dropout = self.options.get('lr_dropout', False)
+            self.options['lr_dropout'] = io.input_bool ( f"Use learning rate dropout? (y/n, ?:help skip:{yn_str[default_lr_dropout]} ) : ", default_lr_dropout, help_message="When the face is trained enough, you can enable this option to get extra sharpness for less amount of iterations.")
+
             default_random_warp = self.options.get('random_warp', True)
             self.options['random_warp'] = io.input_bool (f"Enable random warp of samples? ( y/n, ?:help skip:{yn_str[default_random_warp]}) : ", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.")
 
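Note: the added prompt reuses the io.input_bool(message, default, help_message=...) pattern of the surrounding options. A minimal sketch of how such a y/n prompt with a skip default behaves; ask_bool is a hypothetical stand-in, not DeepFaceLab's io module:

# Minimal sketch of a y/n prompt with a skip default, mirroring the
# io.input_bool(message, default, help_message=...) calls in this hunk.
# ask_bool is a hypothetical stand-in, not DeepFaceLab's io module.
def ask_bool(message, default, help_message=None):
    while True:
        answer = input(message).strip().lower()
        if answer == '':                      # empty input: skip, keep default
            return default
        if answer == '?' and help_message:
            print(help_message)               # show help, then re-ask
            continue
        if answer in ('y', 'n'):
            return answer == 'y'

# Usage mirroring the new option (disabled by default):
# use_lr_dropout = ask_bool("Use learning rate dropout? (y/n ?:help skip:n) : ", False)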
@@ -84,8 +87,8 @@ class SAEHDModel(ModelBase):
                 self.options['clipgrad'] = io.input_bool (f"Enable gradient clipping? (y/n, ?:help skip:{yn_str[default_clipgrad]}) : ", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
             else:
                 self.options['clipgrad'] = False
 
         else:
-
+            self.options['lr_dropout'] = self.options.get('lr_dropout', default_lr_dropout)
             self.options['random_warp'] = self.options.get('random_warp', True)
             self.options['true_face_training'] = self.options.get('true_face_training', default_true_face_training)
             self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
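Note: on later runs without ask_override, this else branch re-reads each option from the dict saved with the model, falling back to the default when the key is absent. That is why the added line matters for existing models: one saved before this commit has no 'lr_dropout' key and loads with the feature disabled. A small sketch of the pattern, with opts standing in for self.options:

# Sketch of the options fallback in this else branch; 'opts' stands in
# for self.options as saved with the model.
opts = {'random_warp': False}          # saved before lr_dropout existed
default_lr_dropout = False             # the new default

opts['lr_dropout']  = opts.get('lr_dropout', default_lr_dropout)
opts['random_warp'] = opts.get('random_warp', True)

print(opts)   # {'random_warp': False, 'lr_dropout': False}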
@@ -452,9 +455,10 @@ class SAEHDModel(ModelBase):
         psd_target_dst_anti_masked = self.model.pred_src_dst*(1.0 - target_dstm)
 
         if self.is_training_mode:
-            self.src_dst_opt = RMSprop(lr=5e-5, lr_dropout=0.3, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
-            self.src_dst_mask_opt = RMSprop(lr=5e-5, lr_dropout=0.3, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
-            self.D_opt = RMSprop(lr=5e-5, lr_dropout=0.3, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
+            lr_dropout = 0.3 if self.options['lr_dropout'] else 0.0
+            self.src_dst_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.src_dst_mask_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.D_opt = RMSprop(lr=5e-5, lr_dropout=lr_dropout, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
 
             src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) )
             src_loss += K.mean ( 10*K.square( target_src_masked_opt - pred_src_src_masked_opt ) )
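Note: computing lr_dropout once and passing it to all three optimizers replaces the previously hardcoded 0.3, so the option toggles every optimizer at once and 0.0 turns the feature off entirely. A conceptual NumPy sketch of the idea, an assumption about what the patched RMSprop does with lr_dropout rather than its actual code: each iteration, only a random subset of weights receives its update.

import numpy as np

# Conceptual sketch of learning-rate dropout (assumed behavior of the
# patched RMSprop, not its actual code): with lr_dropout=0.3, each weight
# is updated with probability 0.3 on a given step; 0.0 disables the mask.
def apply_update(weights, update, lr=5e-5, lr_dropout=0.3):
    if lr_dropout != 0.0:
        mask = (np.random.rand(*weights.shape) < lr_dropout).astype(weights.dtype)
        update = update * mask        # drop most per-weight updates this step
    return weights - lr * update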
@@ -535,7 +539,7 @@ class SAEHDModel(ModelBase):
             t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
 
             self.set_training_data_generators ([
-                    SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, use_caching=False,
+                    SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
                                         random_ct_samples_path=training_data_dst_path if self.options['ct_mode'] != 'none' else None,
                                         debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
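Note: t_img_warped ties the sample generators to the random_warp option set in the first hunk: with warping disabled, the network trains on the same transformed images it targets, trading generalization for sharpness. A minimal sketch of that selection, where the string constants stand in for SampleProcessor.Types members:

# Sketch of the t_img_warped selection; strings stand in for
# SampleProcessor.Types members.
def input_sample_type(options):
    return 'IMG_WARPED_TRANSFORMED' if options['random_warp'] else 'IMG_TRANSFORMED'

assert input_sample_type({'random_warp': True})  == 'IMG_WARPED_TRANSFORMED'
assert input_sample_type({'random_warp': False}) == 'IMG_TRANSFORMED'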
@@ -544,7 +548,7 @@ class SAEHDModel(ModelBase):
                                                               {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
                         ),
 
-                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, use_caching=False,
+                    SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
                                         output_sample_types = [ {'types' : (t_img_warped, face_type, t_mode_bgr), 'resolution':resolution},
                                                                 {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution},