mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
change 'epoch' to 'iter',
added timestamp prefix to training string
This commit is contained in:
parent
69174a48e0
commit
97b6fabaab
7 changed files with 93 additions and 87 deletions
|
@ -40,30 +40,33 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
debug=debug,
|
debug=debug,
|
||||||
device_args=device_args)
|
device_args=device_args)
|
||||||
|
|
||||||
is_reached_goal = model.is_reached_epoch_goal()
|
is_reached_goal = model.is_reached_iter_goal()
|
||||||
is_upd_save_time_after_train = False
|
is_upd_save_time_after_train = False
|
||||||
|
loss_string = ""
|
||||||
def model_save():
|
def model_save():
|
||||||
if not debug and not is_reached_goal:
|
if not debug and not is_reached_goal:
|
||||||
|
io.log_info ("Saving....", end='\r')
|
||||||
model.save()
|
model.save()
|
||||||
|
io.log_info(loss_string)
|
||||||
is_upd_save_time_after_train = True
|
is_upd_save_time_after_train = True
|
||||||
|
|
||||||
def send_preview():
|
def send_preview():
|
||||||
if not debug:
|
if not debug:
|
||||||
previews = model.get_previews()
|
previews = model.get_previews()
|
||||||
c2s.put ( {'op':'show', 'previews': previews, 'epoch':model.get_epoch(), 'loss_history': model.get_loss_history().copy() } )
|
c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } )
|
||||||
else:
|
else:
|
||||||
previews = [( 'debug, press update for new', model.debug_one_epoch())]
|
previews = [( 'debug, press update for new', model.debug_one_iter())]
|
||||||
c2s.put ( {'op':'show', 'previews': previews} )
|
c2s.put ( {'op':'show', 'previews': previews} )
|
||||||
|
|
||||||
|
|
||||||
if model.is_first_run():
|
if model.is_first_run():
|
||||||
model_save()
|
model_save()
|
||||||
|
|
||||||
if model.get_target_epoch() != 0:
|
if model.get_target_iter() != 0:
|
||||||
if is_reached_goal:
|
if is_reached_goal:
|
||||||
io.log_info('Model already trained to target epoch. You can use preview.')
|
io.log_info('Model already trained to target iteration. You can use preview.')
|
||||||
else:
|
else:
|
||||||
io.log_info('Starting. Target epoch: %d. Press "Enter" to stop training and save model.' % ( model.get_target_epoch() ) )
|
io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) )
|
||||||
else:
|
else:
|
||||||
io.log_info('Starting. Press "Enter" to stop training and save model.')
|
io.log_info('Starting. Press "Enter" to stop training and save model.')
|
||||||
|
|
||||||
|
@ -72,14 +75,14 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
for i in itertools.count(0,1):
|
for i in itertools.count(0,1):
|
||||||
if not debug:
|
if not debug:
|
||||||
if not is_reached_goal:
|
if not is_reached_goal:
|
||||||
loss_string = model.train_one_epoch()
|
loss_string = model.train_one_iter()
|
||||||
if is_upd_save_time_after_train:
|
if is_upd_save_time_after_train:
|
||||||
#save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
|
#save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
|
||||||
last_save_time = time.time()
|
last_save_time = time.time()
|
||||||
|
|
||||||
io.log_info (loss_string, end='\r')
|
io.log_info (loss_string, end='\r')
|
||||||
if model.get_target_epoch() != 0 and model.is_reached_epoch_goal():
|
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
||||||
io.log_info ('Reached target epoch.')
|
io.log_info ('Reached target iteration.')
|
||||||
model_save()
|
model_save()
|
||||||
is_reached_goal = True
|
is_reached_goal = True
|
||||||
io.log_info ('You can use preview now.')
|
io.log_info ('You can use preview now.')
|
||||||
|
@ -91,7 +94,7 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
|
|
||||||
if i==0:
|
if i==0:
|
||||||
if is_reached_goal:
|
if is_reached_goal:
|
||||||
model.pass_one_epoch()
|
model.pass_one_iter()
|
||||||
send_preview()
|
send_preview()
|
||||||
|
|
||||||
if debug:
|
if debug:
|
||||||
|
@ -104,7 +107,7 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
model_save()
|
model_save()
|
||||||
elif op == 'preview':
|
elif op == 'preview':
|
||||||
if is_reached_goal:
|
if is_reached_goal:
|
||||||
model.pass_one_epoch()
|
model.pass_one_iter()
|
||||||
send_preview()
|
send_preview()
|
||||||
elif op == 'close':
|
elif op == 'close':
|
||||||
model_save()
|
model_save()
|
||||||
|
@ -156,8 +159,8 @@ def main(args, device_args):
|
||||||
update_preview = False
|
update_preview = False
|
||||||
is_showing = False
|
is_showing = False
|
||||||
is_waiting_preview = False
|
is_waiting_preview = False
|
||||||
show_last_history_epochs_count = 0
|
show_last_history_iters_count = 0
|
||||||
epoch = 0
|
iter = 0
|
||||||
while True:
|
while True:
|
||||||
if not c2s.empty():
|
if not c2s.empty():
|
||||||
input = c2s.get()
|
input = c2s.get()
|
||||||
|
@ -166,7 +169,7 @@ def main(args, device_args):
|
||||||
is_waiting_preview = False
|
is_waiting_preview = False
|
||||||
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
|
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
|
||||||
previews = input['previews'] if 'previews' in input.keys() else None
|
previews = input['previews'] if 'previews' in input.keys() else None
|
||||||
epoch = input['epoch'] if 'epoch' in input.keys() else 0
|
iter = input['iter'] if 'iter' in input.keys() else 0
|
||||||
if previews is not None:
|
if previews is not None:
|
||||||
max_w = 0
|
max_w = 0
|
||||||
max_h = 0
|
max_h = 0
|
||||||
|
@ -217,12 +220,12 @@ def main(args, device_args):
|
||||||
final = head
|
final = head
|
||||||
|
|
||||||
if loss_history is not None:
|
if loss_history is not None:
|
||||||
if show_last_history_epochs_count == 0:
|
if show_last_history_iters_count == 0:
|
||||||
loss_history_to_show = loss_history
|
loss_history_to_show = loss_history
|
||||||
else:
|
else:
|
||||||
loss_history_to_show = loss_history[-show_last_history_epochs_count:]
|
loss_history_to_show = loss_history[-show_last_history_iters_count:]
|
||||||
|
|
||||||
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, epoch, w, c)
|
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c)
|
||||||
final = np.concatenate ( [final, lh_img], axis=0 )
|
final = np.concatenate ( [final, lh_img], axis=0 )
|
||||||
|
|
||||||
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
|
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
|
||||||
|
@ -243,16 +246,16 @@ def main(args, device_args):
|
||||||
is_waiting_preview = True
|
is_waiting_preview = True
|
||||||
s2c.put ( {'op': 'preview'} )
|
s2c.put ( {'op': 'preview'} )
|
||||||
elif key == ord('l'):
|
elif key == ord('l'):
|
||||||
if show_last_history_epochs_count == 0:
|
if show_last_history_iters_count == 0:
|
||||||
show_last_history_epochs_count = 5000
|
show_last_history_iters_count = 5000
|
||||||
elif show_last_history_epochs_count == 5000:
|
elif show_last_history_iters_count == 5000:
|
||||||
show_last_history_epochs_count = 10000
|
show_last_history_iters_count = 10000
|
||||||
elif show_last_history_epochs_count == 10000:
|
elif show_last_history_iters_count == 10000:
|
||||||
show_last_history_epochs_count = 50000
|
show_last_history_iters_count = 50000
|
||||||
elif show_last_history_epochs_count == 50000:
|
elif show_last_history_iters_count == 50000:
|
||||||
show_last_history_epochs_count = 100000
|
show_last_history_iters_count = 100000
|
||||||
elif show_last_history_epochs_count == 100000:
|
elif show_last_history_iters_count == 100000:
|
||||||
show_last_history_epochs_count = 0
|
show_last_history_iters_count = 0
|
||||||
update_preview = True
|
update_preview = True
|
||||||
elif key == ord(' '):
|
elif key == ord(' '):
|
||||||
selected_preview = (selected_preview + 1) % len(previews)
|
selected_preview = (selected_preview + 1) % len(previews)
|
||||||
|
|
|
@ -51,53 +51,57 @@ class ModelBase(object):
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
|
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
|
||||||
|
|
||||||
self.epoch = 0
|
self.iter = 0
|
||||||
self.options = {}
|
self.options = {}
|
||||||
self.loss_history = []
|
self.loss_history = []
|
||||||
self.sample_for_preview = None
|
self.sample_for_preview = None
|
||||||
if self.model_data_path.exists():
|
if self.model_data_path.exists():
|
||||||
model_data = pickle.loads ( self.model_data_path.read_bytes() )
|
model_data = pickle.loads ( self.model_data_path.read_bytes() )
|
||||||
self.epoch = model_data['epoch']
|
self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
|
||||||
if self.epoch != 0:
|
if 'epoch' in self.options:
|
||||||
|
self.options.pop('epoch')
|
||||||
|
if self.iter != 0:
|
||||||
self.options = model_data['options']
|
self.options = model_data['options']
|
||||||
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
|
self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
|
||||||
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
|
self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None
|
||||||
|
|
||||||
ask_override = self.is_training_mode and self.epoch != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 2)
|
ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 2)
|
||||||
|
|
||||||
yn_str = {True:'y',False:'n'}
|
yn_str = {True:'y',False:'n'}
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.iter == 0:
|
||||||
io.log_info ("\nModel first run. Enter model options as default for each run.")
|
io.log_info ("\nModel first run. Enter model options as default for each run.")
|
||||||
|
|
||||||
if self.epoch == 0 or ask_override:
|
if self.iter == 0 or ask_override:
|
||||||
default_write_preview_history = False if self.epoch == 0 else self.options.get('write_preview_history',False)
|
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
|
||||||
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
|
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
|
||||||
else:
|
else:
|
||||||
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
|
||||||
|
|
||||||
if self.epoch == 0 or ask_override:
|
if self.iter == 0 or ask_override:
|
||||||
self.options['target_epoch'] = max(0, io.input_int("Target epoch (skip:unlimited/default) : ", 0))
|
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
|
||||||
else:
|
else:
|
||||||
self.options['target_epoch'] = self.options.get('target_epoch', 0)
|
self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
|
||||||
|
if 'target_epoch' in self.options:
|
||||||
|
self.options.pop('target_epoch')
|
||||||
|
|
||||||
if self.epoch == 0 or ask_override:
|
if self.iter == 0 or ask_override:
|
||||||
default_batch_size = 0 if self.epoch == 0 else self.options.get('batch_size',0)
|
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
|
||||||
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:0/default) : ", default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:0/default) : ", default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
|
||||||
else:
|
else:
|
||||||
self.options['batch_size'] = self.options.get('batch_size', 0)
|
self.options['batch_size'] = self.options.get('batch_size', 0)
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.iter == 0:
|
||||||
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
|
||||||
else:
|
else:
|
||||||
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.iter == 0:
|
||||||
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
|
||||||
else:
|
else:
|
||||||
self.options['random_flip'] = self.options.get('random_flip', True)
|
self.options['random_flip'] = self.options.get('random_flip', True)
|
||||||
|
|
||||||
if self.epoch == 0:
|
if self.iter == 0:
|
||||||
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
|
||||||
else:
|
else:
|
||||||
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||||
|
@ -106,9 +110,9 @@ class ModelBase(object):
|
||||||
if not self.options['write_preview_history']:
|
if not self.options['write_preview_history']:
|
||||||
self.options.pop('write_preview_history')
|
self.options.pop('write_preview_history')
|
||||||
|
|
||||||
self.target_epoch = self.options['target_epoch']
|
self.target_iter = self.options['target_iter']
|
||||||
if self.options['target_epoch'] == 0:
|
if self.options['target_iter'] == 0:
|
||||||
self.options.pop('target_epoch')
|
self.options.pop('target_iter')
|
||||||
|
|
||||||
self.batch_size = self.options['batch_size']
|
self.batch_size = self.options['batch_size']
|
||||||
self.sort_by_yaw = self.options['sort_by_yaw']
|
self.sort_by_yaw = self.options['sort_by_yaw']
|
||||||
|
@ -118,7 +122,7 @@ class ModelBase(object):
|
||||||
if self.src_scale_mod == 0:
|
if self.src_scale_mod == 0:
|
||||||
self.options.pop('src_scale_mod')
|
self.options.pop('src_scale_mod')
|
||||||
|
|
||||||
self.onInitializeOptions(self.epoch == 0, ask_override)
|
self.onInitializeOptions(self.iter == 0, ask_override)
|
||||||
|
|
||||||
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **self.device_args) )
|
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **self.device_args) )
|
||||||
self.device_config = nnlib.active_DeviceConfig
|
self.device_config = nnlib.active_DeviceConfig
|
||||||
|
@ -142,7 +146,7 @@ class ModelBase(object):
|
||||||
if not self.preview_history_path.exists():
|
if not self.preview_history_path.exists():
|
||||||
self.preview_history_path.mkdir(exist_ok=True)
|
self.preview_history_path.mkdir(exist_ok=True)
|
||||||
else:
|
else:
|
||||||
if self.epoch == 0:
|
if self.iter == 0:
|
||||||
for filename in Path_utils.get_image_paths(self.preview_history_path):
|
for filename in Path_utils.get_image_paths(self.preview_history_path):
|
||||||
Path(filename).unlink()
|
Path(filename).unlink()
|
||||||
|
|
||||||
|
@ -153,7 +157,7 @@ class ModelBase(object):
|
||||||
if not isinstance(generator, SampleGeneratorBase):
|
if not isinstance(generator, SampleGeneratorBase):
|
||||||
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
||||||
|
|
||||||
if (self.sample_for_preview is None) or (self.epoch == 0):
|
if (self.sample_for_preview is None) or (self.iter == 0):
|
||||||
self.sample_for_preview = self.generate_next_sample()
|
self.sample_for_preview = self.generate_next_sample()
|
||||||
|
|
||||||
model_summary_text = []
|
model_summary_text = []
|
||||||
|
@ -161,7 +165,7 @@ class ModelBase(object):
|
||||||
model_summary_text += ["===== Model summary ====="]
|
model_summary_text += ["===== Model summary ====="]
|
||||||
model_summary_text += ["== Model name: " + self.get_model_name()]
|
model_summary_text += ["== Model name: " + self.get_model_name()]
|
||||||
model_summary_text += ["=="]
|
model_summary_text += ["=="]
|
||||||
model_summary_text += ["== Current epoch: " + str(self.epoch)]
|
model_summary_text += ["== Current iteration: " + str(self.iter)]
|
||||||
model_summary_text += ["=="]
|
model_summary_text += ["=="]
|
||||||
model_summary_text += ["== Model options:"]
|
model_summary_text += ["== Model options:"]
|
||||||
for key in self.options.keys():
|
for key in self.options.keys():
|
||||||
|
@ -210,7 +214,7 @@ class ModelBase(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#overridable
|
#overridable
|
||||||
def onTrainOneEpoch(self, sample, generator_list):
|
def onTrainOneIter(self, sample, generator_list):
|
||||||
#train your keras models here
|
#train your keras models here
|
||||||
|
|
||||||
#return array of losses
|
#return array of losses
|
||||||
|
@ -231,11 +235,11 @@ class ModelBase(object):
|
||||||
raise NotImplementeError
|
raise NotImplementeError
|
||||||
#return existing or your own converter which derived from base
|
#return existing or your own converter which derived from base
|
||||||
|
|
||||||
def get_target_epoch(self):
|
def get_target_iter(self):
|
||||||
return self.target_epoch
|
return self.target_iter
|
||||||
|
|
||||||
def is_reached_epoch_goal(self):
|
def is_reached_iter_goal(self):
|
||||||
return self.target_epoch != 0 and self.epoch >= self.target_epoch
|
return self.target_iter != 0 and self.iter >= self.target_iter
|
||||||
|
|
||||||
#multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
|
#multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
|
||||||
#def to_multi_gpu_model_if_possible (self, models_list):
|
#def to_multi_gpu_model_if_possible (self, models_list):
|
||||||
|
@ -263,13 +267,13 @@ class ModelBase(object):
|
||||||
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
|
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
io.log_info ("Saving...")
|
io.log_info ("Saving....", end='\r')
|
||||||
|
|
||||||
Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
|
Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
|
||||||
self.onSave()
|
self.onSave()
|
||||||
|
|
||||||
model_data = {
|
model_data = {
|
||||||
'epoch': self.epoch,
|
'iter': self.iter,
|
||||||
'options': self.options,
|
'options': self.options,
|
||||||
'loss_history': self.loss_history,
|
'loss_history': self.loss_history,
|
||||||
'sample_for_preview' : self.sample_for_preview
|
'sample_for_preview' : self.sample_for_preview
|
||||||
|
@ -336,7 +340,7 @@ class ModelBase(object):
|
||||||
source_filename.rename ( str(target_filename) )
|
source_filename.rename ( str(target_filename) )
|
||||||
|
|
||||||
|
|
||||||
def debug_one_epoch(self):
|
def debug_one_iter(self):
|
||||||
images = []
|
images = []
|
||||||
for generator in self.generator_list:
|
for generator in self.generator_list:
|
||||||
for i,batch in enumerate(next(generator)):
|
for i,batch in enumerate(next(generator)):
|
||||||
|
@ -348,42 +352,42 @@ class ModelBase(object):
|
||||||
def generate_next_sample(self):
|
def generate_next_sample(self):
|
||||||
return [next(generator) for generator in self.generator_list]
|
return [next(generator) for generator in self.generator_list]
|
||||||
|
|
||||||
def train_one_epoch(self):
|
def train_one_iter(self):
|
||||||
sample = self.generate_next_sample()
|
sample = self.generate_next_sample()
|
||||||
epoch_time = time.time()
|
iter_time = time.time()
|
||||||
losses = self.onTrainOneEpoch(sample, self.generator_list)
|
losses = self.onTrainOneIter(sample, self.generator_list)
|
||||||
epoch_time = time.time() - epoch_time
|
iter_time = time.time() - iter_time
|
||||||
self.last_sample = sample
|
self.last_sample = sample
|
||||||
|
|
||||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||||
|
|
||||||
if self.write_preview_history:
|
if self.write_preview_history:
|
||||||
if self.epoch % 10 == 0:
|
if self.iter % 10 == 0:
|
||||||
preview = self.get_static_preview()
|
preview = self.get_static_preview()
|
||||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.epoch, preview.shape[1], preview.shape[2])
|
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
||||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||||
cv2_imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.epoch) )), img )
|
cv2_imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.iter) )), img )
|
||||||
|
|
||||||
self.epoch += 1
|
self.iter += 1
|
||||||
|
|
||||||
if epoch_time >= 10:
|
time_str = time.strftime("[%H:%M:%S]")
|
||||||
#............."Saving...
|
if iter_time >= 10:
|
||||||
loss_string = "Training [#{0:06d}][{1:.5s}s]".format ( self.epoch, '{:0.4f}'.format(epoch_time) )
|
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(iter_time) )
|
||||||
else:
|
else:
|
||||||
loss_string = "Training [#{0:06d}][{1:04d}ms]".format ( self.epoch, int(epoch_time*1000) )
|
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(iter_time*1000) )
|
||||||
for (loss_name, loss_value) in losses:
|
for (loss_name, loss_value) in losses:
|
||||||
loss_string += " %s:%.3f" % (loss_name, loss_value)
|
loss_string += " %s:%.3f" % (loss_name, loss_value)
|
||||||
|
|
||||||
return loss_string
|
return loss_string
|
||||||
|
|
||||||
def pass_one_epoch(self):
|
def pass_one_iter(self):
|
||||||
self.last_sample = self.generate_next_sample()
|
self.last_sample = self.generate_next_sample()
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
nnlib.finalize_all()
|
nnlib.finalize_all()
|
||||||
|
|
||||||
def is_first_run(self):
|
def is_first_run(self):
|
||||||
return self.epoch == 0
|
return self.iter == 0
|
||||||
|
|
||||||
def is_debug(self):
|
def is_debug(self):
|
||||||
return self.debug
|
return self.debug
|
||||||
|
@ -394,8 +398,8 @@ class ModelBase(object):
|
||||||
def get_batch_size(self):
|
def get_batch_size(self):
|
||||||
return self.batch_size
|
return self.batch_size
|
||||||
|
|
||||||
def get_epoch(self):
|
def get_iter(self):
|
||||||
return self.epoch
|
return self.iter
|
||||||
|
|
||||||
def get_loss_history(self):
|
def get_loss_history(self):
|
||||||
return self.loss_history
|
return self.loss_history
|
||||||
|
@ -430,7 +434,7 @@ class ModelBase(object):
|
||||||
self.batch_size = d[ keys[-1] ]
|
self.batch_size = d[ keys[-1] ]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_loss_history_preview(loss_history, epoch, w, c):
|
def get_loss_history_preview(loss_history, iter, w, c):
|
||||||
loss_history = np.array (loss_history.copy())
|
loss_history = np.array (loss_history.copy())
|
||||||
|
|
||||||
lh_height = 100
|
lh_height = 100
|
||||||
|
@ -483,7 +487,7 @@ class ModelBase(object):
|
||||||
last_line_t = int((lh_lines-1)*lh_line_height)
|
last_line_t = int((lh_lines-1)*lh_line_height)
|
||||||
last_line_b = int(lh_lines*lh_line_height)
|
last_line_b = int(lh_lines*lh_line_height)
|
||||||
|
|
||||||
lh_text = 'Epoch: %d' % (epoch) if epoch != 0 else ''
|
lh_text = 'Iter: %d' % (iter) if iter != 0 else ''
|
||||||
|
|
||||||
lh_img[last_line_t:last_line_b, 0:w] += image_utils.get_text_image ( (w,last_line_b-last_line_t,c), lh_text, color=[0.8]*c )
|
lh_img[last_line_t:last_line_b, 0:w] += image_utils.get_text_image ( (w,last_line_b-last_line_t,c), lh_text, color=[0.8]*c )
|
||||||
return lh_img
|
return lh_img
|
|
@ -12,7 +12,7 @@ class Model(ModelBase):
|
||||||
def onInitializeOptions(self, is_first_run, ask_override):
|
def onInitializeOptions(self, is_first_run, ask_override):
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
def_pixel_loss = self.options.get('pixel_loss', False)
|
def_pixel_loss = self.options.get('pixel_loss', False)
|
||||||
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k epochs to enhance fine details and decrease face jitter.")
|
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k iters to enhance fine details and decrease face jitter.")
|
||||||
else:
|
else:
|
||||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class Model(ModelBase):
|
||||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, sample, generators_list):
|
def onTrainOneIter(self, sample, generators_list):
|
||||||
warped_src, target_src, target_src_mask = sample[0]
|
warped_src, target_src, target_src_mask = sample[0]
|
||||||
warped_dst, target_dst, target_dst_mask = sample[1]
|
warped_dst, target_dst, target_dst_mask = sample[1]
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class Model(ModelBase):
|
||||||
|
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
def_pixel_loss = self.options.get('pixel_loss', False)
|
def_pixel_loss = self.options.get('pixel_loss', False)
|
||||||
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k epochs to enhance fine details and decrease face jitter.")
|
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k iters to enhance fine details and decrease face jitter.")
|
||||||
else:
|
else:
|
||||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ class Model(ModelBase):
|
||||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, sample, generators_list):
|
def onTrainOneIter(self, sample, generators_list):
|
||||||
warped_src, target_src, target_src_mask = sample[0]
|
warped_src, target_src, target_src_mask = sample[0]
|
||||||
warped_dst, target_dst, target_dst_mask = sample[1]
|
warped_dst, target_dst, target_dst_mask = sample[1]
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ class Model(ModelBase):
|
||||||
|
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
def_pixel_loss = self.options.get('pixel_loss', False)
|
def_pixel_loss = self.options.get('pixel_loss', False)
|
||||||
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k epochs to enhance fine details and decrease face jitter.")
|
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k iters to enhance fine details and decrease face jitter.")
|
||||||
else:
|
else:
|
||||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ class Model(ModelBase):
|
||||||
[self.decoder_dst, 'decoder_dst.h5']] )
|
[self.decoder_dst, 'decoder_dst.h5']] )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, sample, generators_list):
|
def onTrainOneIter(self, sample, generators_list):
|
||||||
warped_src, target_src, target_src_full_mask = sample[0]
|
warped_src, target_src, target_src_full_mask = sample[0]
|
||||||
warped_dst, target_dst, target_dst_full_mask = sample[1]
|
warped_dst, target_dst, target_dst_full_mask = sample[1]
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ class Model(ModelBase):
|
||||||
def onInitializeOptions(self, is_first_run, ask_override):
|
def onInitializeOptions(self, is_first_run, ask_override):
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
def_pixel_loss = self.options.get('pixel_loss', False)
|
def_pixel_loss = self.options.get('pixel_loss', False)
|
||||||
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k epochs to enhance fine details and decrease face jitter.")
|
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 20k iters to enhance fine details and decrease face jitter.")
|
||||||
else:
|
else:
|
||||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class Model(ModelBase):
|
||||||
[self.inter_AB, 'inter_AB.h5']] )
|
[self.inter_AB, 'inter_AB.h5']] )
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, sample, generators_list):
|
def onTrainOneIter(self, sample, generators_list):
|
||||||
warped_src, target_src, target_src_mask = sample[0]
|
warped_src, target_src, target_src_mask = sample[0]
|
||||||
warped_dst, target_dst, target_dst_mask = sample[1]
|
warped_dst, target_dst, target_dst_mask = sample[1]
|
||||||
|
|
||||||
|
|
|
@ -77,11 +77,11 @@ class SAEModel(ModelBase):
|
||||||
default_bg_style_power = 0.0
|
default_bg_style_power = 0.0
|
||||||
if is_first_run or ask_override:
|
if is_first_run or ask_override:
|
||||||
def_pixel_loss = self.options.get('pixel_loss', False)
|
def_pixel_loss = self.options.get('pixel_loss', False)
|
||||||
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 15-25k epochs to enhance fine details and decrease face jitter.")
|
self.options['pixel_loss'] = io.input_bool ("Use pixel loss? (y/n, ?:help skip: %s ) : " % (yn_str[def_pixel_loss]), def_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 15-25k iters to enhance fine details and decrease face jitter.")
|
||||||
|
|
||||||
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
|
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
|
||||||
self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
|
self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
|
||||||
help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k epochs, when predicted face is clear enough to start learn style. Start from 0.1 value and check history changes."), 0.0, 100.0 )
|
help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.1 value and check history changes."), 0.0, 100.0 )
|
||||||
|
|
||||||
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
|
default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
|
||||||
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power,
|
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power,
|
||||||
|
@ -107,7 +107,6 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
masked_training = True
|
masked_training = True
|
||||||
|
|
||||||
epoch_alpha = Input( (1,) )
|
|
||||||
warped_src = Input(bgr_shape)
|
warped_src = Input(bgr_shape)
|
||||||
target_src = Input(bgr_shape)
|
target_src = Input(bgr_shape)
|
||||||
target_srcm = Input(mask_shape)
|
target_srcm = Input(mask_shape)
|
||||||
|
@ -395,7 +394,7 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def onTrainOneEpoch(self, generators_samples, generators_list):
|
def onTrainOneIter(self, generators_samples, generators_list):
|
||||||
src_samples = generators_samples[0]
|
src_samples = generators_samples[0]
|
||||||
dst_samples = generators_samples[1]
|
dst_samples = generators_samples[1]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue