Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-08-19 13:09:56 -07:00
New features and bug fixes
- You can now enable "retraining samples", which periodically retrains high-loss samples (a minimal sketch of the idea follows the change summary below).
- Added a new 'c' architecture variant.
- Implemented updates from the main repository.
- Fixed an fp16 bug.
parent c0ddc6fc5d
commit 33f8d12d2b
3 changed files with 54 additions and 40 deletions
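For orientation, here is a minimal sketch of the retraining idea used in the hunks below: per-sample losses are buffered during normal iterations and, once 16 batches' worth have accumulated, the worst-scoring samples are stacked into one extra training batch. The helper name and the train_step callable are illustrative stand-ins, not DeepFaceLab API.

    import numpy as np

    def retrain_high_loss_samples(buffer, batch_size, train_step):
        # buffer     : list of (loss, sample) tuples collected over past iterations
        # batch_size : current batch size (bs in the diff)
        # train_step : hypothetical stand-in for self.train / self.src_dst_train
        if len(buffer) < batch_size * 16:
            return buffer                      # not enough history yet, keep collecting

        # highest-loss samples first, rebuild one batch from the worst of them
        worst = sorted(buffer, key=lambda x: x[0], reverse=True)[:batch_size]
        batch = np.stack([sample for _, sample in worst])

        train_step(batch)                      # one extra pass over the hard samples
        return []                              # reset the buffer, as the real code does

    # usage with dummy data
    rng = np.random.default_rng(0)
    buf = [(float(rng.random()), rng.random((8, 8, 3))) for _ in range(4 * 16)]
    buf = retrain_high_loss_samples(buf, 4, train_step=lambda b: print("retrain on", b.shape))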
@@ -191,6 +191,7 @@ class ModelBase(object):
         self.random_flip = self.options.get('random_flip',True)
         self.random_src_flip = self.options.get('random_src_flip', False)
         self.random_dst_flip = self.options.get('random_dst_flip', True)
+        self.retraining_samples = self.options.get('retraining_samples', False)
 
         self.on_initialize()
         self.options['batch_size'] = self.batch_size
@@ -330,6 +331,10 @@ class ModelBase(object):
 
         self.options['batch_size'] = self.batch_size = batch_size
 
+    def ask_retraining_samples(self, default_value=False):
+        default_retraining_samples = self.load_or_def_option('retraining_samples', default_value)
+        self.options['retraining_samples'] = io.input_bool("Retrain high loss samples", default_retraining_samples, help_message="Periodically retrains last 16 \"high-loss\" sample")
+
     #overridable
     def on_initialize_options(self):
@@ -16,6 +16,8 @@ class AMPModel(ModelBase):
 
     #override
     def on_initialize_options(self):
+        default_retraining_samples = self.options['retraining_samples'] = self.load_or_def_option('retraining_samples', False)
+        default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
         default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
         default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
@@ -54,9 +56,12 @@ class AMPModel(ModelBase):
         self.ask_autobackup_hour()
         self.ask_write_preview_history()
         self.ask_target_iter()
+        self.ask_retraining_samples()
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(8)
+        self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
+
 
         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
@@ -154,9 +159,9 @@ class AMPModel(ModelBase):
         if ct_mode == 'none':
             ct_mode = None
 
-        use_fp16 = False
-        #if self.is_exporting:
-        #    use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
+        use_fp16 = self.options['use_fp16']
+        if self.is_exporting:
+            use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
 
         conv_dtype = tf.float16 if use_fp16 else tf.float32
 
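The fp16 fix above stops hard-coding use_fp16 = False and reads the saved option instead, which then drives the convolution dtype a line later. A rough sketch of that selection pattern, with NumPy dtypes standing in for the TensorFlow ones (the option dict and layer class here are illustrative, not the leras API):

    import numpy as np

    class TinyConv:
        # stand-in for a conv layer whose weights follow conv_dtype
        def __init__(self, in_ch, out_ch, dtype):
            self.w = np.zeros((3, 3, in_ch, out_ch), dtype=dtype)

    def build_conv(options, is_exporting=False, export_quantized=False):
        use_fp16 = options.get('use_fp16', False)       # honour the saved option (the fix)
        if is_exporting:
            use_fp16 = export_quantized                 # asked interactively in the real code
        conv_dtype = np.float16 if use_fp16 else np.float32
        return TinyConv(3, 64, conv_dtype)

    print(build_conv({'use_fp16': True}).w.dtype)       # float16
    print(build_conv({'use_fp16': False}).w.dtype)      # float32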
@@ -624,8 +629,9 @@ class AMPModel(ModelBase):
                               generators_count=dst_generators_count )
                              ])
 
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+            if self.options['retraining_samples']:
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
     def export_dfm (self):
         output_path=self.get_strpath_storage_for_file('model.dfm')
@@ -696,25 +702,26 @@ class AMPModel(ModelBase):
 
         src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
-        for i in range(bs):
-            self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
-            self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
-
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
-
-            target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
-
-            target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
-
-            src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+        if self.options['retraining_samples']:
+            for i in range(bs):
+                self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
+                self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
+
+            if len(self.last_src_samples_loss) >= bs*16:
+                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
+                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
+
+                target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+                target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+                target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
+
+                target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+                target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+                target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
+
+                src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
         if self.gan_power != 0:
             self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
@@ -26,7 +26,6 @@ class SAEHDModel(ModelBase):
         else:
             suggest_batch_size = 4
 
-        yn_str = {True:'y',False:'n'}
         min_res = 64
         max_res = 640
 
@@ -77,6 +76,7 @@ class SAEHDModel(ModelBase):
         self.ask_maximum_n_backups()
         self.ask_write_preview_history()
         self.ask_target_iter()
+        self.ask_retraining_samples()
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(suggest_batch_size)
@@ -113,7 +113,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                 if archi_opts is not None:
                     if len(archi_opts) == 0:
                         continue
-                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
+                    if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
                         continue
 
                     if 'd' in archi_opts:
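The change above only widens the accepted architecture option letters from u/d/t to u/d/t/c for the new 'c' variant. As a hedged illustration of what that validation amounts to (the real SAEHD prompt loop retries instead of raising; the names here are illustrative):

    VALID_ARCHI_OPTS = {'u', 'd', 't', 'c'}   # 'c' is the newly accepted letter

    def parse_archi(archi):
        # split e.g. 'liae-udc' into ('liae', 'udc') and reject unknown option letters
        archi_type, _, archi_opts = archi.partition('-')
        bad = [opt for opt in archi_opts if opt not in VALID_ARCHI_OPTS]
        if bad:
            raise ValueError("unknown archi options: %s" % bad)
        return archi_type, archi_opts

    print(parse_archi('liae-udc'))   # ('liae', 'udc')
    print(parse_archi('df'))         # ('df', '')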
@@ -253,7 +253,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
         adabelief = self.options['adabelief']
 
-        use_fp16 = False
+        use_fp16 = self.options['use_fp16']
         if self.is_exporting:
             use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.')
 
@@ -832,8 +832,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                               generators_count=dst_generators_count )
                              ])
 
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+            if self.options['retraining_samples']:
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
             if self.pretrain_just_disabled:
                 self.update_sample_for_preview(force_new=True)
@@ -909,8 +910,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
             io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
 
-        bs = self.get_batch_size()
-
         ( (warped_src, target_src, target_srcm, target_srcm_em), \
           (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
 
@@ -920,21 +919,24 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
             self.last_src_samples_loss.append ( (target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i] ) )
             self.last_dst_samples_loss.append ( (target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i] ) )
 
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
-
-            target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
-            target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-
-            target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
-            target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-
-            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
+        if self.options['retraining_samples']:
+            bs = self.get_batch_size()
+
+            if len(self.last_src_samples_loss) >= bs*16:
+                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
+                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
+
+                target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
+                target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
+                target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
+
+                target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
+                target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
+                target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
+
+                src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
+                self.last_src_samples_loss = []
+                self.last_dst_samples_loss = []
 
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)