diff --git a/models/ModelBase.py b/models/ModelBase.py index 855f823..ee06f00 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -200,45 +200,48 @@ class ModelBase(object): if not isinstance(generator, SampleGeneratorBase): raise ValueError('training data generator is not subclass of SampleGeneratorBase') - if self.sample_for_preview is None or self.choose_preview_history: - if self.choose_preview_history and io.is_support_windows(): - io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.") - wnd_name = "[p] - next. [enter] - confirm." - io.named_window(wnd_name) - io.capture_keys(wnd_name) - choosed = False - while not choosed: - self.sample_for_preview = self.generate_next_samples() - preview = self.get_static_preview() - io.show_image( wnd_name, (preview*255).astype(np.uint8) ) - - while True: - key_events = io.get_key_events(wnd_name) - key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) - if key == ord('\n') or key == ord('\r'): - choosed = True - break - elif key == ord('p'): - break - - try: - io.process_messages(0.1) - except KeyboardInterrupt: - choosed = True - - io.destroy_window(wnd_name) - else: - self.sample_for_preview = self.generate_next_samples() - - try: - self.get_static_preview() - except: - self.sample_for_preview = self.generate_next_samples() - - self.last_sample = self.sample_for_preview + self.update_sample_for_preview(choose_preview_history=self.choose_preview_history) io.log_info( self.get_summary_text() ) + + def update_sample_for_preview(self, choose_preview_history=False, force_new=False): + if self.sample_for_preview is None or choose_preview_history or force_new: + if choose_preview_history and io.is_support_windows(): + io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.") + wnd_name = "[p] - next. [enter] - confirm." + io.named_window(wnd_name) + io.capture_keys(wnd_name) + choosed = False + while not choosed: + self.sample_for_preview = self.generate_next_samples() + preview = self.get_static_preview() + io.show_image( wnd_name, (preview*255).astype(np.uint8) ) + while True: + key_events = io.get_key_events(wnd_name) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + if key == ord('\n') or key == ord('\r'): + choosed = True + break + elif key == ord('p'): + break + + try: + io.process_messages(0.1) + except KeyboardInterrupt: + choosed = True + + io.destroy_window(wnd_name) + else: + self.sample_for_preview = self.generate_next_samples() + + try: + self.get_static_preview() + except: + self.sample_for_preview = self.generate_next_samples() + + self.last_sample = self.sample_for_preview + def load_or_def_option(self, name, def_value): options_val = self.options.get(name, None) if options_val is not None: diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index 91d874b..3be6c53 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -337,11 +337,12 @@ class QModel(ModelBase): # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): - do_init = self.is_first_run() - if self.pretrain_just_disabled: + do_init = False if model == self.inter: do_init = True + else: + do_init = self.is_first_run() if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 412588b..3494aa2 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -34,13 +34,6 @@ class SAEHDModel(ModelBase): default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) - default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64 - default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims) - - default_d_mask_dims = default_d_dims // 3 - default_d_mask_dims += default_d_mask_dims % 2 - default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) - default_use_float16 = self.options['use_float16'] = self.load_or_def_option('use_float16', False) default_learn_mask = self.options['learn_mask'] = self.load_or_def_option('learn_mask', True) default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', False) @@ -61,18 +54,28 @@ class SAEHDModel(ModelBase): self.ask_random_flip() self.ask_batch_size(suggest_batch_size) - if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-256", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16.") resolution = np.clip ( (resolution // 16) * 16, 64, 256) self.options['resolution'] = resolution self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f'], help_message="Half / mid face / full face. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face.").lower() self.options['archi'] = io.input_str ("AE architecture", default_archi, ['dfhd','liaehd','df','liae'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'hd' is heavyweight version for the best quality.").lower() #-s version is slower, but has decreased change to collapse. + + default_d_dims = 48 if self.options['archi'] == 'dfhd' else 64 + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', default_d_dims) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) self.options['e_dims'] = e_dims + e_dims % 2 + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) self.options['d_dims'] = d_dims + d_dims % 2 @@ -108,10 +111,6 @@ class SAEHDModel(ModelBase): self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) - if self.pretrain_just_disabled: - self.set_iter(1) - - #override def on_initialize(self): device_config = nn.getCurrentDeviceConfig() @@ -343,6 +342,8 @@ class SAEHDModel(ModelBase): d_dims = self.options['d_dims'] d_mask_dims = self.options['d_mask_dims'] self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) self.gan_power = gan_power = self.options['gan_power'] if not self.pretrain else 0.0 @@ -679,15 +680,16 @@ class SAEHDModel(ModelBase): # Loading/initializing all models/optimizers weights for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): - do_init = self.is_first_run() - if self.pretrain_just_disabled: + do_init = False if 'df' in archi: if model == self.inter: do_init = True elif 'liae' in archi: if model == self.inter_AB: do_init = True + else: + do_init = self.is_first_run() if not do_init: do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) @@ -734,6 +736,9 @@ class SAEHDModel(ModelBase): generators_count=dst_generators_count ) ]) + if self.pretrain_just_disabled: + self.update_sample_for_preview(force_new=True) + #override def get_model_filename_list(self): return self.model_filename_list