Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-07-07 13:32:09 -07:00
fix update preview samples after disable pretrain
parent 7386a9d6fd
commit 9c6ca24642
3 changed files with 61 additions and 52 deletions
@@ -200,45 +200,48 @@ class ModelBase(object):
         if not isinstance(generator, SampleGeneratorBase):
             raise ValueError('training data generator is not subclass of SampleGeneratorBase')

-        if self.choose_preview_history and io.is_support_windows():
-            io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
-            wnd_name = "[p] - next. [enter] - confirm."
-            io.named_window(wnd_name)
-            io.capture_keys(wnd_name)
-            choosed = False
-            while not choosed:
-                self.sample_for_preview = self.generate_next_samples()
-                preview = self.get_static_preview()
-                io.show_image( wnd_name, (preview*255).astype(np.uint8) )
-
-                while True:
-                    key_events = io.get_key_events(wnd_name)
-                    key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
-                    if key == ord('\n') or key == ord('\r'):
-                        choosed = True
-                        break
-                    elif key == ord('p'):
-                        break
-
-                    try:
-                        io.process_messages(0.1)
-                    except KeyboardInterrupt:
-                        choosed = True
-
-            io.destroy_window(wnd_name)
-        else:
-            self.sample_for_preview = self.generate_next_samples()
-
-        try:
-            self.get_static_preview()
-        except:
-            self.sample_for_preview = self.generate_next_samples()
-
-        self.last_sample = self.sample_for_preview
+        if self.sample_for_preview is None or self.choose_preview_history:
+            self.update_sample_for_preview(choose_preview_history=self.choose_preview_history)

         io.log_info( self.get_summary_text() )

+    def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
+        if self.sample_for_preview is None or choose_preview_history or force_new:
+            if choose_preview_history and io.is_support_windows():
+                io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
+                wnd_name = "[p] - next. [enter] - confirm."
+                io.named_window(wnd_name)
+                io.capture_keys(wnd_name)
+                choosed = False
+                while not choosed:
+                    self.sample_for_preview = self.generate_next_samples()
+                    preview = self.get_static_preview()
+                    io.show_image( wnd_name, (preview*255).astype(np.uint8) )
+
+                    while True:
+                        key_events = io.get_key_events(wnd_name)
+                        key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
+                        if key == ord('\n') or key == ord('\r'):
+                            choosed = True
+                            break
+                        elif key == ord('p'):
+                            break
+
+                        try:
+                            io.process_messages(0.1)
+                        except KeyboardInterrupt:
+                            choosed = True
+
+                io.destroy_window(wnd_name)
+            else:
+                self.sample_for_preview = self.generate_next_samples()
+
+        try:
+            self.get_static_preview()
+        except:
+            self.sample_for_preview = self.generate_next_samples()
+
+        self.last_sample = self.sample_for_preview

     def load_or_def_option(self, name, def_value):
         options_val = self.options.get(name, None)
         if options_val is not None:
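Note: this hunk folds the one-shot preview-selection code out of __init__ into a reusable update_sample_for_preview() helper, guarded so it only regenerates when needed. A minimal sketch of that guard behavior, using a mock object rather than the real ModelBase:

    class PreviewMock:
        def __init__(self):
            self.sample_for_preview = None
            self.last_sample = None

        def generate_next_samples(self):
            return "fresh batch"  # stand-in for a real sample batch

        def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
            # Regenerate only when no sample exists yet, the user wants to
            # choose one interactively, or the caller forces a refresh.
            if self.sample_for_preview is None or choose_preview_history or force_new:
                self.sample_for_preview = self.generate_next_samples()
            self.last_sample = self.sample_for_preview

    m = PreviewMock()
    m.update_sample_for_preview()                # generates: no sample yet
    m.sample_for_preview = "cached batch"
    m.update_sample_for_preview()                # keeps the cached sample
    assert m.sample_for_preview == "cached batch"
    m.update_sample_for_preview(force_new=True)  # forced refresh
    assert m.sample_for_preview == "fresh batch"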
@@ -337,11 +337,12 @@ class QModel(ModelBase):

         # Loading/initializing all models/optimizers weights
         for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
-            do_init = self.is_first_run()
-
             if self.pretrain_just_disabled:
+                do_init = False
                 if model == self.inter:
                     do_init = True
+            else:
+                do_init = self.is_first_run()

             if not do_init:
                 do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
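Note: after this hunk, the pretrain-disable path fully determines do_init, falling back to is_first_run() only otherwise; in that path everything except the inter network explicitly keeps its pretrained weights. A condensed restatement (hypothetical helper, not in the commit):

    def decide_init(is_first_run, pretrain_just_disabled, model_is_inter):
        if pretrain_just_disabled:
            return model_is_inter   # re-init only the inter bottleneck
        return is_first_run         # otherwise, unchanged behavior

    assert decide_init(False, True,  True)  is True   # inter is rebuilt
    assert decide_init(False, True,  False) is False  # others keep weights
    assert decide_init(True,  False, False) is True   # first run inits all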
@@ -34,13 +34,6 @@ class SAEHDModel(ModelBase):
         default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
         default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)

-        default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
-        default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
-
-        default_d_mask_dims = default_d_dims // 3
-        default_d_mask_dims += default_d_mask_dims % 2
-        default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
-
         default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False)
         default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True)
         default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
@@ -61,18 +54,28 @@ class SAEHDModel(ModelBase):
         self.ask_random_flip()
         self.ask_batch_size(suggest_batch_size)

         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
             resolution = np.clip ( (resolution // 16) * 16, 64, 256)
             self.options['resolution'] = resolution
             self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower()

             self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is heavyweight version for the best quality.").lower() #-s version is slower, but has decreased change to collapse.

+        default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
+        default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
+
+        default_d_mask_dims = default_d_dims // 3
+        default_d_mask_dims += default_d_mask_dims % 2
+        default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
+
+        if self.is_first_run():
             self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )

             e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
             self.options['e_dims'] = e_dims + e_dims % 2

             d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
             self.options['d_dims'] = d_dims + d_dims % 2
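Note: these two SAEHD hunks move the d_dims/d_mask_dims defaults below the architecture prompt, since the default depends on the 'archi' answer, then re-open an is_first_run() block for the remaining dimension questions. The arithmetic, standalone (illustrative values):

    for archi in ('dfhd', 'liae'):
        default_d_dims = 48 if archi == 'dfhd' else 64
        default_d_mask_dims = default_d_dims // 3
        default_d_mask_dims += default_d_mask_dims % 2  # round up to even
        print(archi, default_d_dims, default_d_mask_dims)
    # dfhd 48 16
    # liae 64 22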
@@ -108,10 +111,6 @@ class SAEHDModel(ModelBase):

         self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)

-        if self.pretrain_just_disabled:
-            self.set_iter(1)
-
-
     #override
     def on_initialize(self):
         device_config = nn.getCurrentDeviceConfig()
@@ -343,6 +342,8 @@ class SAEHDModel(ModelBase):
         d_dims = self.options['d_dims']
         d_mask_dims = self.options['d_mask_dims']
         self.pretrain = self.options['pretrain']
+        if self.pretrain_just_disabled:
+            self.set_iter(0)

         self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0

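Note: together with the previous hunk, the iteration reset moves from option-reading time into on_initialize, and now resets to 0 instead of 1. A small mock of the effect:

    class IterMock:
        def __init__(self, iter_=5000):
            self.iter = iter_          # iterations accumulated while pretraining
        def set_iter(self, n):
            self.iter = n

    state = IterMock()
    pretrain_just_disabled = True      # the pretrain option flipped off this run
    if pretrain_just_disabled:
        state.set_iter(0)              # was set_iter(1), and ran earlier in the flow
    print(state.iter)  # 0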
@@ -679,15 +680,16 @@ class SAEHDModel(ModelBase):

         # Loading/initializing all models/optimizers weights
         for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
-            do_init = self.is_first_run()
-
             if self.pretrain_just_disabled:
+                do_init = False
                 if 'df' in archi:
                     if model == self.inter:
                         do_init = True
                 elif 'liae' in archi:
                     if model == self.inter_AB:
                         do_init = True
+            else:
+                do_init = self.is_first_run()

             if not do_init:
                 do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
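Note: this is the SAEHD counterpart of the QModel hunk, extended to both architecture families; only the bottleneck network (inter for df*, inter_AB for liae*) is re-initialized after pretraining ends. A flattened restatement (hypothetical helper, not in the commit):

    def decide_init(model, archi, is_first_run, pretrain_just_disabled,
                    inter=None, inter_AB=None):
        if pretrain_just_disabled:
            do_init = False
            if 'df' in archi and model is inter:
                do_init = True          # df*: rebuild the single inter net
            elif 'liae' in archi and model is inter_AB:
                do_init = True          # liae*: rebuild inter_AB
        else:
            do_init = is_first_run
        return do_init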
@@ -734,6 +736,9 @@ class SAEHDModel(ModelBase):
                             generators_count=dst_generators_count )
                     ])

+        if self.pretrain_just_disabled:
+            self.update_sample_for_preview(force_new=True)
+
     #override
     def get_model_filename_list(self):
         return self.model_filename_list
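Note: this final hunk is the fix named in the commit title; once the generators exist, a pretrain-to-train switch forces a fresh preview batch so the preview history continues from the real training set rather than stale pretrain samples. End-to-end, in miniature (stand-in values, not the real model):

    sample_for_preview = "pretrain batch"      # stale preview from pretraining

    def update_sample_for_preview(force_new=False):
        global sample_for_preview
        if sample_for_preview is None or force_new:
            sample_for_preview = "training batch"

    pretrain_just_disabled = True
    if pretrain_just_disabled:
        update_sample_for_preview(force_new=True)
    print(sample_for_preview)  # training batch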