New features and bug fixes

- Added a "retraining_samples" option: when enabled, high-loss samples
  are periodically retrained (see the sketch after this list).
- Added a new 'c' architecture option.
- Merged updates from the main repository.
- Fixed fp16 handling: use_fp16 is now a saved, user-configurable option
  instead of being hard-coded to False.
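
Both models implement the retraining option the same way: each iteration's
per-sample losses are buffered, and once 16 batches' worth have accumulated,
the highest-loss samples are fed through the trainer again. A minimal
standalone sketch of the idea (buffer handling only; accumulate_and_retrain
and train_fn are illustrative names, not from the repo):

    import numpy as np

    def accumulate_and_retrain(buffer, samples, losses, bs, train_fn):
        # buffer one (loss, sample) entry per sample of the current batch
        buffer.extend(zip(losses, samples))
        if len(buffer) >= bs * 16:
            # sort highest-loss first and retrain on the worst bs samples
            worst = sorted(buffer, key=lambda e: e[0], reverse=True)[:bs]
            train_fn(np.stack([sample for _, sample in worst]))
            buffer.clear()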
commit 33f8d12d2b
Cioscos, 2021-10-02 16:37:28 +02:00
3 changed files with 54 additions and 40 deletions

models/ModelBase.py

@@ -191,6 +191,7 @@ class ModelBase(object):
         self.random_flip = self.options.get('random_flip',True)
         self.random_src_flip = self.options.get('random_src_flip', False)
         self.random_dst_flip = self.options.get('random_dst_flip', True)
+        self.retraining_samples = self.options.get('retraining_samples', False)
 
         self.on_initialize()
 
         self.options['batch_size'] = self.batch_size
@@ -330,6 +331,10 @@ class ModelBase(object):
         self.options['batch_size'] = self.batch_size = batch_size
 
+    def ask_retraining_samples(self, default_value=False):
+        default_retraining_samples = self.load_or_def_option('retraining_samples', default_value)
+        self.options['retraining_samples'] = io.input_bool("Retrain high loss samples", default_retraining_samples, help_message="Periodically retrains last 16 \"high-loss\" samples")
+
     #overridable
     def on_initialize_options(self):
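
ask_retraining_samples follows the same pattern as the other ask_* helpers
in ModelBase: load_or_def_option returns the value saved with the model (or
the given default), io.input_bool prompts the user once, and the answer is
written back into self.options, which is persisted. A rough standalone
sketch of that flow (ask_bool_option and saved_options are simplified
stand-ins, not the repo's API):

    # simplified stand-in for the ModelBase option flow
    saved_options = {}  # the real code persists this alongside the model

    def load_or_def_option(name, default):
        return saved_options.get(name, default)

    def ask_bool_option(name, prompt, default=False):
        previous = load_or_def_option(name, default)
        answer = input(f"{prompt} (y/n, default {'y' if previous else 'n'}): ")
        answer = answer.strip().lower()
        saved_options[name] = (answer == 'y') if answer else previous
        return saved_options[name]

    ask_bool_option('retraining_samples', "Retrain high loss samples")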

models/Model_AMP/Model.py

@@ -16,6 +16,8 @@ class AMPModel(ModelBase):
     #override
     def on_initialize_options(self):
+        default_retraining_samples = self.options['retraining_samples'] = self.load_or_def_option('retraining_samples', False)
+        default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
         default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
         default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
@@ -54,9 +56,12 @@ class AMPModel(ModelBase):
         self.ask_autobackup_hour()
         self.ask_write_preview_history()
         self.ask_target_iter()
+        self.ask_retraining_samples()
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(8)
+        self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
 
         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32.")
@@ -154,9 +159,9 @@ class AMPModel(ModelBase):
         if ct_mode == 'none':
             ct_mode = None
 
-        use_fp16 = False
-        #if self.is_exporting:
-        #    use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
+        use_fp16 = self.options['use_fp16']
+        if self.is_exporting:
+            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
 
         conv_dtype = tf.float16 if use_fp16 else tf.float32
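
The new "Use fp16" prompt warns that the model may crash; the usual culprit
is the narrow dynamic range of half precision, which is presumably why the
help text suggests enabling it only after the first 1-5k iterations, once
losses have settled. A quick NumPy illustration of the failure mode:

    import numpy as np

    # float16 tops out near 65504, so large activations/gradients overflow
    print(np.float16(60000.0) * np.float16(2.0))   # inf  (overflow)
    print(np.float32(60000.0) * np.float32(2.0))   # 120000.0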
@@ -624,8 +629,9 @@ class AMPModel(ModelBase):
                                       generators_count=dst_generators_count )
                              ])
 
-        self.last_src_samples_loss = []
-        self.last_dst_samples_loss = []
+        if self.options['retraining_samples']:
+            self.last_src_samples_loss = []
+            self.last_dst_samples_loss = []
 
     def export_dfm (self):
         output_path=self.get_strpath_storage_for_file('model.dfm')
@@ -696,25 +702,26 @@
         src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
-        for i in range(bs):
-            self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
-            self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
-
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
-            target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
-            target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
-            src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+        if self.options['retraining_samples']:
+            for i in range(bs):
+                self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
+                self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
+
+            if len(self.last_src_samples_loss) >= bs*16:
+                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
+                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
+                target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+                target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+                target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
+                target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+                target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+                target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
+                src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
         if self.gan_power != 0:
             self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
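
The selection step above hinges on operator.itemgetter over the buffered
tuples: the AMP model stores the loss at index 0, while SAEHD (below) stores
it at index 3, hence itemgetter(3) there. A tiny standalone demonstration of
the sort-and-stack pattern (the data is made up):

    import operator
    import numpy as np

    bs = 2
    buffered = [(0.9, np.zeros(4)), (0.1, np.ones(4)), (0.5, np.full(4, 2.0))]
    # highest loss first, then stack the bs hardest samples into one batch
    worst_first = sorted(buffered, key=operator.itemgetter(0), reverse=True)
    batch = np.stack([x[1] for x in worst_first[:bs]])   # shape (2, 4)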

models/Model_SAEHD/Model.py

@@ -26,7 +26,6 @@ class SAEHDModel(ModelBase):
         else:
             suggest_batch_size = 4
 
-        yn_str = {True:'y',False:'n'}
 
         min_res = 64
         max_res = 640
@@ -77,6 +76,7 @@ class SAEHDModel(ModelBase):
         self.ask_maximum_n_backups()
         self.ask_write_preview_history()
         self.ask_target_iter()
+        self.ask_retraining_samples()
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(suggest_batch_size)
@@ -113,7 +113,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                 if archi_opts is not None:
                     if len(archi_opts) == 0:
                         continue
 
-                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
+                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
                         continue
 
                 if 'd' in archi_opts:
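
The check above accepts an architecture options string only when every
option letter is in the supported set, which this commit extends with 'c'.
An equivalent standalone check (the function name is ours, not the repo's):

    def archi_opts_valid(archi_opts):
        # true only if every option letter is supported ('c' is new here)
        return len([1 for opt in archi_opts if opt not in ['u', 'd', 't', 'c']]) == 0

    assert archi_opts_valid('ud')
    assert archi_opts_valid('udc')     # the new 'c' variant passes
    assert not archi_opts_valid('x')   # unknown letters are rejected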
@ -253,7 +253,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
adabelief = self.options['adabelief'] adabelief = self.options['adabelief']
use_fp16 = False use_fp16 = self.options['use_fp16']
if self.is_exporting: if self.is_exporting:
use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
@@ -832,8 +832,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                                       generators_count=dst_generators_count )
                              ])
 
-        self.last_src_samples_loss = []
-        self.last_dst_samples_loss = []
+        if self.options['retraining_samples']:
+            self.last_src_samples_loss = []
+            self.last_dst_samples_loss = []
 
         if self.pretrain_just_disabled:
             self.update_sample_for_preview(force_new=True)
@@ -909,8 +910,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
             io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
 
-        bs = self.get_batch_size()
-
         ( (warped_src, target_src, target_srcm, target_srcm_em), \
           (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
@@ -920,21 +919,24 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
-        for i in range(bs):
-            self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
-            self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
-
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
-            target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
-            target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-            target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
-            target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+        if self.options['retraining_samples']:
+            bs = self.get_batch_size()
+
+            for i in range(bs):
+                self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
+                self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
+
+            if len(self.last_src_samples_loss) >= bs*16:
+                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
+                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
+                target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
+                target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+                target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+                target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
+                target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+                target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+                src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)