fix update preview samples after disable pretrain

This commit is contained in:
Colombo 2020-01-28 13:32:01 +04:00
parent 7386a9d6fd
commit 9c6ca24642
3 changed files with 61 additions and 52 deletions

View file

@ -200,8 +200,13 @@ class ModelBase(object):
if not isinstance(generator, SampleGeneratorBase):
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
if self.sample_for_preview is None or self.choose_preview_history:
if self.choose_preview_history and io.is_support_windows():
self.update_sample_for_preview(choose_preview_history=self.choose_preview_history)
io.log_info( self.get_summary_text() )
def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
if self.sample_for_preview is None or choose_preview_history or force_new:
if choose_preview_history and io.is_support_windows():
io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
wnd_name = "[p] - next. [enter] - confirm."
io.named_window(wnd_name)
@ -237,8 +242,6 @@ class ModelBase(object):
self.last_sample = self.sample_for_preview
io.log_info( self.get_summary_text() )
def load_or_def_option(self, name, def_value):
options_val = self.options.get(name, None)
if options_val is not None:

View file

@ -337,11 +337,12 @@ class QModel(ModelBase):
# Loading/initializing all models/optimizers weights
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
do_init = self.is_first_run()
if self.pretrain_just_disabled:
do_init = False
if model == self.inter:
do_init = True
else:
do_init = self.is_first_run()
if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )

View file

@ -34,13 +34,6 @@ class SAEHDModel(ModelBase):
default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256)
default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64)
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False)
default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True)
default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False)
@ -61,18 +54,28 @@ class SAEHDModel(ModelBase):
self.ask_random_flip()
self.ask_batch_size(suggest_batch_size)
if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.")
resolution = np.clip ( (resolution // 16) * 16, 64, 256)
self.options['resolution'] = resolution
self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower()
self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is heavyweight version for the best quality.").lower() #-s version is slower, but has decreased change to collapse.
default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64
default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims)
default_d_mask_dims = default_d_dims // 3
default_d_mask_dims += default_d_mask_dims % 2
default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims)
if self.is_first_run():
self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['e_dims'] = e_dims + e_dims % 2
d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 )
self.options['d_dims'] = d_dims + d_dims % 2
@ -108,10 +111,6 @@ class SAEHDModel(ModelBase):
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
if self.pretrain_just_disabled:
self.set_iter(1)
#override
def on_initialize(self):
device_config = nn.getCurrentDeviceConfig()
@ -343,6 +342,8 @@ class SAEHDModel(ModelBase):
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
self.pretrain = self.options['pretrain']
if self.pretrain_just_disabled:
self.set_iter(0)
self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0
@ -679,15 +680,16 @@ class SAEHDModel(ModelBase):
# Loading/initializing all models/optimizers weights
for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"):
do_init = self.is_first_run()
if self.pretrain_just_disabled:
do_init = False
if 'df' in archi:
if model == self.inter:
do_init = True
elif 'liae' in archi:
if model == self.inter_AB:
do_init = True
else:
do_init = self.is_first_run()
if not do_init:
do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) )
@ -734,6 +736,9 @@ class SAEHDModel(ModelBase):
generators_count=dst_generators_count )
])
if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True)
#override
def get_model_filename_list(self):
return self.model_filename_list