mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
fix update preview samples after disable pretrain
This commit is contained in:
parent
7386a9d6fd
commit
9c6ca24642
3 changed files with 61 additions and 52 deletions
|
@ -200,8 +200,13 @@ class ModelBase(object):
|
|||
if not isinstance(generator, SampleGeneratorBase):
|
||||
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
||||
|
||||
if self.sample_for_preview is None or self.choose_preview_history:
|
||||
if self.choose_preview_history and io.is_support_windows():
|
||||
self.update_sample_for_preview(choose_preview_history=self.choose_preview_history)
|
||||
|
||||
io.log_info( self.get_summary_text() )
|
||||
|
||||
def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
|
||||
if self.sample_for_preview is None or choose_preview_history or force_new:
|
||||
if choose_preview_history and io.is_support_windows():
|
||||
io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
|
||||
wnd_name = "[p] - next. [enter] - confirm."
|
||||
io.named_window(wnd_name)
|
||||
|
@ -237,8 +242,6 @@ class ModelBase(object):
|
|||
|
||||
self.last_sample = self.sample_for_preview
|
||||
|
||||
io.log_info( self.get_summary_text() )
|
||||
|
||||
def load_or_def_option(self, name, def_value):
|
||||
options_val = self.options.get(name, None)
|
||||
if options_val is not None:
|
||||
|
|
|
@ -337,11 +337,12 @@ class QModel(ModelBase):
|
|||
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
|
||||
do_init = self.is_first_run()
|
||||
|
||||
if self.pretrain_just_disabled:
|
||||
do_init = False
|
||||
if model == self.inter:
|
||||
do_init = True
|
||||
else:
|
||||
do_init = self.is_first_run()
|
||||
|
||||
if not do_init:
|
||||
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
|
||||
|
|
|
@ -34,13 +34,6 @@ class SAEHDModel(ModelBase):
|
|||
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
|
||||
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
|
||||
|
||||
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
|
||||
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
|
||||
|
||||
default_d_mask_dims = default_d_dims // 3
|
||||
default_d_mask_dims += default_d_mask_dims % 2
|
||||
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
|
||||
|
||||
default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False)
|
||||
default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True)
|
||||
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
|
||||
|
@ -61,18 +54,28 @@ class SAEHDModel(ModelBase):
|
|||
self.ask_random_flip()
|
||||
self.ask_batch_size(suggest_batch_size)
|
||||
|
||||
if self.is_first_run():
|
||||
|
||||
resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
|
||||
resolution = np.clip ( (resolution // 16) * 16, 64, 256)
|
||||
self.options['resolution'] = resolution
|
||||
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower()
|
||||
|
||||
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is heavyweight version for the best quality.").lower() #-s version is slower, but has decreased change to collapse.
|
||||
|
||||
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
|
||||
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
|
||||
|
||||
default_d_mask_dims = default_d_dims // 3
|
||||
default_d_mask_dims += default_d_mask_dims % 2
|
||||
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
|
||||
|
||||
if self.is_first_run():
|
||||
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
|
||||
|
||||
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
|
||||
self.options['e_dims'] = e_dims + e_dims % 2
|
||||
|
||||
|
||||
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
|
||||
self.options['d_dims'] = d_dims + d_dims % 2
|
||||
|
||||
|
@ -108,10 +111,6 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
|
||||
|
||||
if self.pretrain_just_disabled:
|
||||
self.set_iter(1)
|
||||
|
||||
|
||||
#override
|
||||
def on_initialize(self):
|
||||
device_config = nn.getCurrentDeviceConfig()
|
||||
|
@ -343,6 +342,8 @@ class SAEHDModel(ModelBase):
|
|||
d_dims = self.options['d_dims']
|
||||
d_mask_dims = self.options['d_mask_dims']
|
||||
self.pretrain = self.options['pretrain']
|
||||
if self.pretrain_just_disabled:
|
||||
self.set_iter(0)
|
||||
|
||||
self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0
|
||||
|
||||
|
@ -679,15 +680,16 @@ class SAEHDModel(ModelBase):
|
|||
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
|
||||
do_init = self.is_first_run()
|
||||
|
||||
if self.pretrain_just_disabled:
|
||||
do_init = False
|
||||
if 'df' in archi:
|
||||
if model == self.inter:
|
||||
do_init = True
|
||||
elif 'liae' in archi:
|
||||
if model == self.inter_AB:
|
||||
do_init = True
|
||||
else:
|
||||
do_init = self.is_first_run()
|
||||
|
||||
if not do_init:
|
||||
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
|
||||
|
@ -734,6 +736,9 @@ class SAEHDModel(ModelBase):
|
|||
generators_count=dst_generators_count )
|
||||
])
|
||||
|
||||
if self.pretrain_just_disabled:
|
||||
self.update_sample_for_preview(force_new=True)
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
return self.model_filename_list
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue