diff --git a/models/ModelBase.py b/models/ModelBase.py
index b043eda..d0b0909 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -191,6 +191,7 @@ class ModelBase(object):
         self.random_flip = self.options.get('random_flip',True)
         self.random_src_flip = self.options.get('random_src_flip', False)
         self.random_dst_flip = self.options.get('random_dst_flip', True)
+        self.retraining_samples = self.options.get('retraining_samples', False)
 
         self.on_initialize()
 
         self.options['batch_size'] = self.batch_size
@@ -330,6 +331,10 @@ class ModelBase(object):
 
         self.options['batch_size'] = self.batch_size = batch_size
 
+    def ask_retraining_samples(self, default_value=False):
+        default_retraining_samples = self.load_or_def_option('retraining_samples', default_value)
+        self.options['retraining_samples'] = io.input_bool("Retrain high loss samples", default_retraining_samples, help_message="Periodically retrains last 16 \"high-loss\" samples")
+
     #overridable
     def on_initialize_options(self):
diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py
index 8a88bd3..2e95504 100644
--- a/models/Model_AMP/Model.py
+++ b/models/Model_AMP/Model.py
@@ -16,6 +16,8 @@ class AMPModel(ModelBase):
 
     #override
     def on_initialize_options(self):
+        default_retraining_samples = self.options['retraining_samples'] = self.load_or_def_option('retraining_samples', False)
+        default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
         default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
         default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
@@ -54,9 +56,12 @@ class AMPModel(ModelBase):
         self.ask_autobackup_hour()
         self.ask_write_preview_history()
         self.ask_target_iter()
+        self.ask_retraining_samples()
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(8)
+        self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
+
         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
@@ -154,10 +159,10 @@ class AMPModel(ModelBase):
         if ct_mode == 'none':
             ct_mode = None
 
-        use_fp16 = False
-        #if self.is_exporting:
-        #use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
-
+        use_fp16 = self.options['use_fp16']
+        if self.is_exporting:
+            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
+
         conv_dtype = tf.float16 if use_fp16 else tf.float32
 
         class Downscale(nn.ModelBase):
@@ -624,8 +629,9 @@ class AMPModel(ModelBase):
                         generators_count=dst_generators_count )
                              ])
 
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+            if self.options['retraining_samples']:
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
     def export_dfm (self):
         output_path=self.get_strpath_storage_for_file('model.dfm')
@@ -696,25 +702,26 @@ class AMPModel(ModelBase):
 
         src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
-        for i in range(bs):
-            self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
-            self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
+        if self.options['retraining_samples']:
+            for i in range(bs):
+                self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
+                self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
 
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
+            if len(self.last_src_samples_loss) >= bs*16:
+                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
+                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
 
-            target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
+                target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+                target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+                target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
 
-            target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
+                target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+                target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+                target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
 
-            src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+                src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
         if self.gan_power != 0:
             self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index 4ade959..a1b71f7 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -26,7 +26,6 @@ class SAEHDModel(ModelBase):
         else:
             suggest_batch_size = 4
 
-        yn_str = {True:'y',False:'n'}
         min_res = 64
         max_res = 640
 
@@ -77,6 +76,7 @@ class SAEHDModel(ModelBase):
         self.ask_maximum_n_backups()
         self.ask_write_preview_history()
         self.ask_target_iter()
+        self.ask_retraining_samples()
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(suggest_batch_size)
@@ -113,7 +113,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                 if archi_opts is not None:
                     if len(archi_opts) == 0:
                         continue
-                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
+                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
                         continue
 
                     if 'd' in archi_opts:
@@ -253,7 +253,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
         adabelief = self.options['adabelief']
 
-        use_fp16 = False
+        use_fp16 = self.options['use_fp16']
         if self.is_exporting:
             use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
 
@@ -832,8 +832,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                         generators_count=dst_generators_count )
                              ])
 
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+            if self.options['retraining_samples']:
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
             if self.pretrain_just_disabled:
                 self.update_sample_for_preview(force_new=True)
@@ -909,8 +910,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
             io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
 
-        bs = self.get_batch_size()
-
         ( (warped_src, target_src, target_srcm, target_srcm_em), \
           (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
 
@@ -920,21 +919,24 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
             self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
 
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
+        if self.options['retraining_samples']:
+            bs = self.get_batch_size()
 
-            target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
-            target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+            if len(self.last_src_samples_loss) >= bs*16:
+                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
+                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
 
-            target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
-            target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+                target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
+                target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+                target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
 
-            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+                target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
+                target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+                target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+
+                src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)
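
Note: the following is an illustrative, standalone sketch (not part of the patch) of the high-loss retraining behavior that the hunks above gate behind the new 'retraining_samples' option: per-sample losses are buffered every iteration, and once 16 batches' worth have accumulated, the highest-loss samples are fed through an extra training pass and the buffer is cleared. The retrain_high_loss helper, train_step callback, and toy data below are hypothetical stand-ins for the models' train / src_dst_train calls.

import operator
import numpy as np

def retrain_high_loss(buffer, batch_size, train_step):
    # buffer holds (loss, sample) tuples collected over previous iterations
    if len(buffer) < batch_size * 16:
        return buffer  # not enough history yet; keep accumulating
    # sort by loss, highest first, and rebuild a batch from the worst samples
    worst = sorted(buffer, key=operator.itemgetter(0), reverse=True)[:batch_size]
    batch = np.stack([sample for _, sample in worst])
    train_step(batch)  # extra training pass on the hardest samples
    return []          # reset the buffer, as the patch does

# toy usage: random losses and 8x8 "samples", batch size 4
rng = np.random.default_rng(0)
buffer = [(float(rng.random()), rng.random((8, 8))) for _ in range(4 * 16)]
buffer = retrain_high_loss(buffer, 4, lambda b: print("retraining on batch of shape", b.shape))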