Mirror of https://github.com/iperov/DeepFaceLab.git
Merge branch 'master' into feat/random-ct-only-on-face-mask

# Conflicts:
#	models/Model_SAE/Model.py

Commit 2c28c6eb3f

4 changed files with 316 additions and 176 deletions
main.py (17 changes)

@@ -134,16 +134,23 @@ if __name__ == "__main__":
Trainer.main(args, device_args)

p = subparsers.add_parser( "train", help="Trainer")
- p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.")
- p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.")
- p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None, help="Optional dir of extracted faceset that will be used in pretraining mode.")
+ p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir",
+                help="Dir of extracted SRC faceset.")
+ p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir",
+                help="Dir of extracted DST faceset.")
+ p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None,
+                help="Optional dir of extracted faceset that will be used in pretraining mode.")
p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
- p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
- p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.")
+ p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'),
+                help="Type of model")
+ p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False,
+                help="Disable preview window.")
p.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
p.add_argument('--force-gpu-idx', type=int, dest="force_gpu_idx", default=-1, help="Force to choose this GPU idx.")
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
+ p.add_argument('--pingpong', dest="ping_pong", default=False,
+                help="Cycle between a batch size of 1 and the chosen batch size")
p.set_defaults (func=process_train)

def process_convert(arguments):
@@ -104,14 +104,14 @@ def trainerThread (s2c, c2s, args, device_args):
print("Unable to execute program: %s" % (prog) )

if not is_reached_goal:
- iter, iter_time = model.train_one_iter()
+ iter, iter_time, batch_size = model.train_one_iter()

loss_history = model.get_loss_history()
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
- loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
+ loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size )
else:
- loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )
+ loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size)

if shared_state['after_save']:
shared_state['after_save'] = False
@@ -186,6 +186,7 @@ def main(args, device_args):

no_preview = args.get('no_preview', False)

+
s2c = queue.Queue()
c2s = queue.Queue()

@@ -216,6 +217,7 @@ def main(args, device_args):
is_waiting_preview = False
show_last_history_iters_count = 0
iter = 0
+ batch_size = 1
while True:
if not c2s.empty():
input = c2s.get()
@@ -225,6 +227,7 @@ def main(args, device_args):
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
previews = input['previews'] if 'previews' in input.keys() else None
iter = input['iter'] if 'iter' in input.keys() else 0
+ #batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
max_w = 0
max_h = 0
@@ -280,7 +283,7 @@ def main(args, device_args):
else:
loss_history_to_show = loss_history[-show_last_history_iters_count:]

- lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c)
+ lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c)
final = np.concatenate ( [final, lh_img], axis=0 )

final = np.concatenate ( [final, selected_preview_rgb], axis=0 )
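For reference, a minimal sketch of the status line the updated trainer thread now prints, using the format strings from the hunk above; the iteration, timing, and batch-size values below are made up, not from a real run:

import time

iter, iter_time, batch_size = 12345, 0.214, 6   # hypothetical sample values
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
    loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format(time_str, iter, '{:0.4f}'.format(iter_time), batch_size)
else:
    loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format(time_str, iter, int(iter_time * 1000), batch_size)
print(loss_string)  # e.g. [14:03:27][#012345][0214ms][bs: 6]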
@@ -23,39 +23,41 @@ You can implement your own model. Check examples.
class ModelBase(object):



def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
ask_enable_autobackup=True,
ask_write_preview_history=True,
ask_target_iter=True,
ask_batch_size=True,
ask_sort_by_yaw=True,
ask_random_flip=True,
ask_src_scale_mod=True):

- device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
- device_args['cpu_only'] = device_args.get('cpu_only',False)
+ device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
+ device_args['cpu_only'] = device_args.get('cpu_only', False)

if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
if len(idxs_names_list) > 1:
- io.log_info ("You have multi GPUs in a system: ")
+ io.log_info("You have multi GPUs in a system: ")
for idx, name in idxs_names_list:
- io.log_info ("[%d] : %s" % (idx, name) )
+ io.log_info("[%d] : %s" % (idx, name))

- device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
+ device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1,
+                                             [x[0] for x in idxs_names_list])
self.device_args = device_args

self.device_config = nnlib.DeviceConfig(allow_growth=True, **self.device_args)

- io.log_info ("Loading model...")
+ io.log_info("Loading model...")

self.model_path = model_path
- self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
+ self.model_data_path = Path(self.get_strpath_storage_for_file('data.dat'))

self.training_data_src_path = training_data_src_path
self.training_data_dst_path = training_data_dst_path
self.pretraining_data_path = pretraining_data_path

self.src_images_paths = None
self.dst_images_paths = None
self.src_yaw_images_paths = None
@@ -65,6 +67,8 @@ class ModelBase(object):
self.debug = debug
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)

+ self.paddle = 'pong'
+
self.iter = 0
self.options = {}
self.loss_history = []
@@ -72,8 +76,8 @@ class ModelBase(object):

model_data = {}
if self.model_data_path.exists():
- model_data = pickle.loads ( self.model_data_path.read_bytes() )
- self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
+ model_data = pickle.loads(self.model_data_path.read_bytes())
+ self.iter = max(model_data.get('iter', 0), model_data.get('epoch', 0))
if 'epoch' in self.options:
self.options.pop('epoch')
if self.iter != 0:
@@ -81,62 +85,105 @@ class ModelBase(object):
self.loss_history = model_data.get('loss_history', [])
self.sample_for_preview = model_data.get('sample_for_preview', None)

- ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 5 if io.is_colab() else 2 )
+ ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time("Press enter in 2 seconds to"
+                                                                              " override model settings.",
+                                                                              5 if io.is_colab() else 2)

- yn_str = {True:'y',False:'n'}
+ yn_str = {True: 'y', False: 'n'}

if self.iter == 0:
- io.log_info ("\nModel first run. Enter model options as default for each run.")
+ io.log_info("\nModel first run. Enter model options as default for each run.")

if ask_enable_autobackup and (self.iter == 0 or ask_override):
- default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False)
- self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]) , default_autobackup, help_message="Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01")
+ default_autobackup = False if self.iter == 0 else self.options.get('autobackup', False)
+ self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " %
+                                            (yn_str[default_autobackup]), default_autobackup,
+                                            help_message="Autobackup model files with preview every hour for"
+                                                         " last 15 hours. Latest backup located in model/<>"
+                                                         "_autobackups/01")
else:
self.options['autobackup'] = self.options.get('autobackup', False)

if ask_write_preview_history and (self.iter == 0 or ask_override):
- default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
- self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
+ default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',
+                                                                               False)
+ self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " %
+                                                       (yn_str[default_write_preview_history]),
+                                                       default_write_preview_history,
+                                                       help_message="Preview history will be writed to"
+                                                                    " <ModelName>_history folder.")
else:
self.options['write_preview_history'] = self.options.get('write_preview_history', False)

if (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_support_windows():
- choose_preview_history = io.input_bool("Choose image for the preview history? (y/n skip:%s) : " % (yn_str[False]) , False)
+ choose_preview_history = io.input_bool("Choose image for the preview history?"
+                                        " (y/n skip:%s) : " % (yn_str[False]), False)

elif (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_colab():
- choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s) : " % (yn_str[False]), False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person")
+ choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s)"
+                                        ": " % (yn_str[False]), False,
+                                        help_message="Preview image history will stay stuck with old faces"
+                                                     " if you reuse the same model on different celebs."
+                                                     " Choose no unless you are changing src/dst to a"
+                                                     " new person")
else:
choose_preview_history = False

if ask_target_iter:
if (self.iter == 0 or ask_override):
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
else:
- self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
+ self.options['target_iter'] = max(model_data.get('target_iter', 0), self.options.get('target_epoch', 0))
if 'target_epoch' in self.options:
self.options.pop('target_epoch')

if ask_batch_size and (self.iter == 0 or ask_override):
- default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
- self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
+ default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0)
+ self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % 0,
+                                                 0,
+                                                 help_message="Larger batch size is better for NN's"
+                                                              " generalization, but it can cause Out of"
+                                                              " Memory error. Tune this value for your"
+                                                              " videocard manually."))
+ self.options['ping_pong'] = io.input_bool(
+     "Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]),
+     True,
+     help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
+ self.options['paddle'] = self.options.get('paddle','ping')
+ if self.options.get('ping_pong',True):
+     self.options['ping_pong_iter'] = max(0, io.input_int("Ping-pong iteration (skip:1000/default) : ", 1000))

else:
- self.options['batch_size'] = self.options.get('batch_size', 0)
+ self.options['batch_cap'] = self.options.get('batch_cap', 16)
+ self.options['ping_pong'] = self.options.get('ping_pong', True)
+ self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)

if ask_sort_by_yaw:
if (self.iter == 0 or ask_override):
default_sort_by_yaw = self.options.get('sort_by_yaw', False)
- self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw, help_message="NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw." )
+ self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):"
+                                             " " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw,
+                                             help_message="NN will not learn src face directions that"
+                                                          " don't match dst face directions. Do not use "
+                                                          "if the dst face has hair that covers the jaw")
else:
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)

- if ask_override:
+ if self.iter == 0 or ask_override:
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")


if ask_src_scale_mod:
- if (self.iter == 0):
+ if self.iter == 0:
- self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
+ self.options['src_scale_mod'] = np.clip(io.input_int("Src face scale modifier %"
+                                                       " ( -30...30, ?:help skip:0) : ", 0,
+                                                       help_message="If src face shape is wider than"
+                                                                    " dst, try to decrease this value to"
+                                                                    " get a better result."), -30, 30)
else:
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)

self.autobackup = self.options.get('autobackup', False)
if not self.autobackup and 'autobackup' in self.options:
self.options.pop('autobackup')
@@ -145,15 +192,20 @@ class ModelBase(object):
if not self.write_preview_history and 'write_preview_history' in self.options:
self.options.pop('write_preview_history')

- self.target_iter = self.options.get('target_iter',0)
+ self.target_iter = self.options.get('target_iter', 0)
if self.target_iter == 0 and 'target_iter' in self.options:
self.options.pop('target_iter')

- self.batch_size = self.options.get('batch_size',0)
- self.sort_by_yaw = self.options.get('sort_by_yaw',False)
- self.random_flip = self.options.get('random_flip',True)
+ self.batch_size = self.options.get('batch_size', 8)
+ self.batch_cap = self.options.get('batch_cap',16)
+ self.ping_pong_iter = self.options.get('ping_pong_iter',1000)
+ self.sort_by_yaw = self.options.get('sort_by_yaw', False)
+ self.random_flip = self.options.get('random_flip', True)
+ if self.batch_cap == 0:
+     self.options['batch_cap'] = self.batch_size
+     self.batch_cap = self.options.get('batch_cap',16)

- self.src_scale_mod = self.options.get('src_scale_mod',0)
+ self.src_scale_mod = self.options.get('src_scale_mod', 0)
if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
self.options.pop('src_scale_mod')
@@ -166,21 +218,24 @@ class ModelBase(object):
self.onInitialize()

self.options['batch_size'] = self.batch_size
+ self.paddle = self.options.get('paddle', 'ping')

if self.debug or self.batch_size == 0:
self.batch_size = 1

if self.is_training_mode:
if self.device_args['force_gpu_idx'] == -1:
- self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
- self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) )
+ self.preview_history_path = self.model_path / ('%s_history' % (self.get_model_name()))
+ self.autobackups_path = self.model_path / ('%s_autobackups' % (self.get_model_name()))
else:
- self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
- self.autobackups_path = self.model_path / ( '%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
+ self.preview_history_path = self.model_path / ('%d_%s_history' % (self.device_args['force_gpu_idx'],
+                                                                   self.get_model_name()))
+ self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'],
+                                                                   self.get_model_name()))

if self.autobackup:
self.autobackup_current_hour = time.localtime().tm_hour

if not self.autobackups_path.exists():
self.autobackups_path.mkdir(exist_ok=True)

@@ -193,13 +248,13 @@ class ModelBase(object):
Path(filename).unlink()

if self.generator_list is None:
- raise ValueError( 'You didnt set_training_data_generators()')
+ raise ValueError('You didnt set_training_data_generators()')
else:
for i, generator in enumerate(self.generator_list):
if not isinstance(generator, SampleGeneratorBase):
raise ValueError('training data generator is not subclass of SampleGeneratorBase')

if self.sample_for_preview is None or choose_preview_history:
if choose_preview_history and io.is_support_windows():
wnd_name = "[p] - next. [enter] - confirm."
io.named_window(wnd_name)
@@ -208,25 +263,26 @@ class ModelBase(object):
while not choosed:
self.sample_for_preview = self.generate_next_sample()
preview = self.get_static_preview()
- io.show_image( wnd_name, (preview*255).astype(np.uint8) )
+ io.show_image(wnd_name, (preview * 255).astype(np.uint8))

while True:
key_events = io.get_key_events(wnd_name)
- key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
+ key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(
+     key_events) > 0 else (0, 0, False, False, False)
if key == ord('\n') or key == ord('\r'):
choosed = True
break
elif key == ord('p'):
break

try:
io.process_messages(0.1)
except KeyboardInterrupt:
choosed = True

io.destroy_window(wnd_name)
else:
self.sample_for_preview = self.generate_next_sample()
self.last_sample = self.sample_for_preview
model_summary_text = []

@@ -257,15 +313,15 @@ class ModelBase(object):
model_summary_text += ["=="]

model_summary_text += ["========================="]
- model_summary_text = "\r\n".join (model_summary_text)
+ model_summary_text = "\r\n".join(model_summary_text)
self.model_summary_text = model_summary_text
io.log_info(model_summary_text)

- #overridable
+ # overridable
def onInitializeOptions(self, is_first_run, ask_override):
pass

- #overridable
+ # overridable
def onInitialize(self):
'''
initialize your keras models
@@ -276,36 +332,36 @@ class ModelBase(object):
'''
pass

- #overridable
+ # overridable
def onSave(self):
- #save your keras models here
+ # save your keras models here
pass

- #overridable
+ # overridable
def onTrainOneIter(self, sample, generator_list):
- #train your keras models here
+ # train your keras models here

- #return array of losses
+ # return array of losses
- return ( ('loss_src', 0), ('loss_dst', 0) )
+ return (('loss_src', 0), ('loss_dst', 0))

- #overridable
+ # overridable
def onGetPreview(self, sample):
- #you can return multiple previews
+ # you can return multiple previews
- #return [ ('preview_name',preview_rgb), ... ]
+ # return [ ('preview_name',preview_rgb), ... ]
return []

- #overridable if you want model name differs from folder name
+ # overridable if you want model name differs from folder name
def get_model_name(self):
return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]

- #overridable , return [ [model, filename],... ] list
+ # overridable , return [ [model, filename],... ] list
def get_model_filename_list(self):
return []

- #overridable
+ # overridable
def get_converter(self):
raise NotImplementedError
- #return existing or your own converter which derived from base
+ # return existing or your own converter which derived from base

def get_target_iter(self):
return self.target_iter
@@ -313,8 +369,8 @@ class ModelBase(object):
def is_reached_iter_goal(self):
return self.target_iter != 0 and self.iter >= self.target_iter

- #multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
+ # multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
- #def to_multi_gpu_model_if_possible (self, models_list):
+ # def to_multi_gpu_model_if_possible (self, models_list):
#    if len(self.device_config.gpu_idxs) > 1:
#        #make batch_size to divide on GPU count without remainder
#        self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
@@ -333,61 +389,65 @@ class ModelBase(object):
#    return models_list

def get_previews(self):
- return self.onGetPreview ( self.last_sample )
+ return self.onGetPreview(self.last_sample)

def get_static_preview(self):
- return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
+ return self.onGetPreview(self.sample_for_preview)[0][1]  # first preview, and bgr

def save(self):
+ self.options['batch_size'] = self.batch_size
+ self.options['paddle'] = self.paddle
summary_path = self.get_strpath_storage_for_file('summary.txt')
- Path( summary_path ).write_text(self.model_summary_text)
+ Path(summary_path).write_text(self.model_summary_text)
self.onSave()

model_data = {
'iter': self.iter,
'options': self.options,
'loss_history': self.loss_history,
- 'sample_for_preview' : self.sample_for_preview
+ 'sample_for_preview': self.sample_for_preview
}
- self.model_data_path.write_bytes( pickle.dumps(model_data) )
+ self.model_data_path.write_bytes(pickle.dumps(model_data))

+ bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in
+                       self.get_model_filename_list()]
+ bckp_filename_list += [str(summary_path), str(self.model_data_path)]

- bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
- bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]

if self.autobackup:
current_hour = time.localtime().tm_hour
if self.autobackup_current_hour != current_hour:
self.autobackup_current_hour = current_hour

- for i in range(15,0,-1):
+ for i in range(15, 0, -1):
idx_str = '%.2d' % i
- next_idx_str = '%.2d' % (i+1)
+ next_idx_str = '%.2d' % (i + 1)

idx_backup_path = self.autobackups_path / idx_str
next_idx_packup_path = self.autobackups_path / next_idx_str

if idx_backup_path.exists():
if i == 15:
Path_utils.delete_all_files(idx_backup_path)
else:
next_idx_packup_path.mkdir(exist_ok=True)
- Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)
+ Path_utils.move_all_files(idx_backup_path, next_idx_packup_path)

if i == 1:
idx_backup_path.mkdir(exist_ok=True)
for filename in bckp_filename_list:
- shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )
+ shutil.copy(str(filename), str(idx_backup_path / Path(filename).name))

previews = self.get_previews()
plist = []
for i in range(len(previews)):
name, bgr = previews[i]
- plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
+ plist += [(bgr, idx_backup_path / (('preview_%s.jpg') % (name)))]

for preview, filepath in plist:
- preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
- img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
- cv2_imwrite (filepath, img )
+ preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size,
+                                                 preview.shape[1], preview.shape[2])
+ img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
+ cv2_imwrite(filepath, img)

def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
for model, filename in model_filename_list:
@@ -410,16 +470,15 @@ class ModelBase(object):
opt.set_weights(weights)
print("set ok")
except Exception as e:
- print ("Unable to load ", opt_filename)
+ print("Unable to load ", opt_filename)
-

def save_weights_safe(self, model_filename_list):
for model, filename in model_filename_list:
filename = self.get_strpath_storage_for_file(filename)
- model.save_weights( filename + '.tmp' )
+ model.save_weights(filename + '.tmp')

rename_list = model_filename_list

"""
#unused
, optimizer_filename_list=[]
@@ -443,36 +502,42 @@ class ModelBase(object):
except Exception as e:
print ("Unable to save ", opt_filename)
"""

for _, filename in rename_list:
filename = self.get_strpath_storage_for_file(filename)
- source_filename = Path(filename+'.tmp')
+ source_filename = Path(filename + '.tmp')
if source_filename.exists():
target_filename = Path(filename)
if target_filename.exists():
target_filename.unlink()
- source_filename.rename ( str(target_filename) )
+ source_filename.rename(str(target_filename))

def debug_one_iter(self):
images = []
for generator in self.generator_list:
- for i,batch in enumerate(next(generator)):
+ for i, batch in enumerate(next(generator)):
if len(batch.shape) == 4:
- images.append( batch[0] )
+ images.append(batch[0])

- return imagelib.equalize_and_stack_square (images)
+ return imagelib.equalize_and_stack_square(images)

def generate_next_sample(self):
return [next(generator) for generator in self.generator_list]

def train_one_iter(self):

+ if self.iter == 1 and self.options.get('ping_pong', True):
+     self.set_batch_size(1)
+     self.paddle = 'ping'
+ elif not self.options.get('ping_pong', True) and self.batch_cap != self.batch_size:
+     self.set_batch_size(self.batch_cap)
sample = self.generate_next_sample()
iter_time = time.time()
losses = self.onTrainOneIter(sample, self.generator_list)
iter_time = time.time() - iter_time
self.last_sample = sample

- self.loss_history.append ( [float(loss[1]) for loss in losses] )
+ self.loss_history.append([float(loss[1]) for loss in losses])

if self.iter % 10 == 0:
plist = []
@@ -481,20 +546,33 @@ class ModelBase(object):
previews = self.get_previews()
for i in range(len(previews)):
name, bgr = previews[i]
- plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
+ plist += [(bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name)))]

if self.write_preview_history:
- plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ]
+ plist += [(self.get_static_preview(), str(self.preview_history_path / ('%.6d.jpg' % (self.iter))))]

for preview, filepath in plist:
- preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
- img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
- cv2_imwrite (filepath, img )
+ preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, preview.shape[1],
+                                                 preview.shape[2])
+ img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
+ cv2_imwrite(filepath, img)

+ if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', True):
+     if self.batch_size == self.batch_cap:
+         self.paddle = 'pong'
+     if self.batch_size > self.batch_cap:
+         self.set_batch_size(self.batch_cap)
+         self.paddle = 'pong'
+     if self.batch_size == 1:
+         self.paddle = 'ping'
+     if self.paddle == 'ping':
+         self.set_batch_size(self.batch_size + 1)
+     else:
+         self.set_batch_size(self.batch_size - 1)

self.iter += 1

- return self.iter, iter_time
+ return self.iter, iter_time, self.batch_size

def pass_one_iter(self):
self.last_sample = self.generate_next_sample()
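As a rough illustration of the ping-pong schedule added above, here is a simplified, self-contained sketch (not the actual ModelBase code); batch_cap and ping_pong_iter are shrunk to hypothetical small values purely to keep the trace short:

def next_batch_size(iteration, batch_size, paddle, batch_cap=4, ping_pong_iter=3):
    # Mirrors the added branch: flip the paddle at the extremes, then step the
    # batch size one unit in the paddle's direction every ping_pong_iter iterations.
    if iteration % ping_pong_iter == 0 and iteration != 0:
        if batch_size == batch_cap:
            paddle = 'pong'
        if batch_size == 1:
            paddle = 'ping'
        batch_size = batch_size + 1 if paddle == 'ping' else batch_size - 1
    return batch_size, paddle

bs, paddle = 1, 'ping'
trace = []
for it in range(1, 25):
    bs, paddle = next_batch_size(it, bs, paddle)
    trace.append(bs)
print(trace)  # the batch size climbs 1 -> 4, walks back down to 1, and repeats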
@@ -520,10 +598,10 @@ class ModelBase(object):
def get_loss_history(self):
return self.loss_history

- def set_training_data_generators (self, generator_list):
+ def set_training_data_generators(self, generator_list):
self.generator_list = generator_list

- def get_training_data_generators (self):
+ def get_training_data_generators(self):
return self.generator_list

def get_model_root_path(self):
@@ -531,12 +609,13 @@ class ModelBase(object):

def get_strpath_storage_for_file(self, filename):
if self.device_args['force_gpu_idx'] == -1:
- return str( self.model_path / ( self.get_model_name() + '_' + filename) )
+ return str(self.model_path / (self.get_model_name() + '_' + filename))
else:
- return str( self.model_path / ( str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename) )
+ return str(self.model_path / (
+     str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename))

- def set_vram_batch_requirements (self, d):
+ def set_vram_batch_requirements(self, d):
- #example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
+ # example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
keys = [x for x in d.keys()]

if self.device_config.cpu_only:
@@ -550,65 +629,67 @@ class ModelBase(object):
break

if self.batch_size == 0:
- self.batch_size = d[ keys[-1] ]
+ self.batch_size = d[keys[-1]]

@staticmethod
- def get_loss_history_preview(loss_history, iter, w, c):
- loss_history = np.array (loss_history.copy())
+ def get_loss_history_preview(loss_history, iter,batch_size, w, c):
+ loss_history = np.array(loss_history.copy())

lh_height = 100
- lh_img = np.ones ( (lh_height,w,c) ) * 0.1
+ lh_img = np.ones((lh_height, w, c)) * 0.1

if len(loss_history) != 0:
loss_count = len(loss_history[0])
lh_len = len(loss_history)

l_per_col = lh_len / w
- plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
- *[ loss_history[i_ab][p]
- for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
+ plist_max = [[max(0.0, loss_history[int(col * l_per_col)][p],
+                   *[loss_history[i_ab][p]
+                     for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col))
]
)
for p in range(loss_count)
]
for col in range(w)
]

- plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
- *[ loss_history[i_ab][p]
- for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
+ plist_min = [[min(plist_max[col][p], loss_history[int(col * l_per_col)][p],
+                   *[loss_history[i_ab][p]
+                     for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col))
]
)
for p in range(loss_count)
]
for col in range(w)
]

- plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2
+ plist_abs_max = np.mean(loss_history[len(loss_history) // 5:]) * 2

for col in range(0, w):
- for p in range(0,loss_count):
- point_color = [1.0]*c
- point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )
+ for p in range(0, loss_count):
+ point_color = [1.0] * c
+ point_color[0:3] = colorsys.hsv_to_rgb(p * (1.0 / loss_count), 1.0, 1.0)

- ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
- ph_max = np.clip( ph_max, 0, lh_height-1 )
+ ph_max = int((plist_max[col][p] / plist_abs_max) * (lh_height - 1))
+ ph_max = np.clip(ph_max, 0, lh_height - 1)

- ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
- ph_min = np.clip( ph_min, 0, lh_height-1 )
+ ph_min = int((plist_min[col][p] / plist_abs_max) * (lh_height - 1))
+ ph_min = np.clip(ph_min, 0, lh_height - 1)

- for ph in range(ph_min, ph_max+1):
- lh_img[ (lh_height-ph-1), col ] = point_color
+ for ph in range(ph_min, ph_max + 1):
+ lh_img[(lh_height - ph - 1), col] = point_color

lh_lines = 5
- lh_line_height = (lh_height-1)/lh_lines
- for i in range(0,lh_lines+1):
- lh_img[ int(i*lh_line_height), : ] = (0.8,)*c
+ lh_line_height = (lh_height - 1) / lh_lines
+ for i in range(0, lh_lines + 1):
+ lh_img[int(i * lh_line_height), :] = (0.8,) * c

- last_line_t = int((lh_lines-1)*lh_line_height)
- last_line_b = int(lh_lines*lh_line_height)
+ last_line_t = int((lh_lines - 1) * lh_line_height)
+ last_line_b = int(lh_lines * lh_line_height)

- lh_text = 'Iter: %d' % (iter) if iter != 0 else ''
+ lh_text = 'Iter: %d' % iter if iter != 0 else ''
+ bs_text = 'BS: %d' % batch_size if batch_size is not None else '1'

- lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
+ lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), lh_text,
+                                                                 color=[0.8] * c)
return lh_img
@@ -9,6 +9,8 @@ from interact import interact as io

from samplelib.SampleProcessor import ColorTransferMode

+
+
# SAE - Styled AutoEncoder


@@ -160,7 +162,8 @@ class SAEModel(ModelBase):
SAEModel.initialize_nn_functions()
self.set_vram_batch_requirements({1.5: 4})

- resolution = self.options['resolution']
+ global resolution
+ resolution= self.options['resolution']
ae_dims = self.options['ae_dims']
e_ch_dims = self.options['e_ch_dims']
d_ch_dims = self.options['d_ch_dims']
@@ -171,9 +174,10 @@ class SAEModel(ModelBase):
d_residual_blocks = True
bgr_shape = (resolution, resolution, 3)
mask_shape = (resolution, resolution, 1)
+ global ms_count
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1

+ global apply_random_ct
apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
masked_training = True

@@ -444,18 +448,24 @@ class SAEModel(ModelBase):
self.src_sample_losses = []
self.dst_sample_losses = []

+ global t
t = SampleProcessor.Types
+ global face_type
face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF

+ global t_mode_bgr
t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE

+ global training_data_src_path
training_data_src_path = self.training_data_src_path
- training_data_dst_path = self.training_data_dst_path
+ global training_data_dst_path
+ training_data_dst_path= self.training_data_dst_path
+ global sort_by_yaw
sort_by_yaw = self.sort_by_yaw

if self.pretrain and self.pretraining_data_path is not None:
training_data_src_path = self.pretraining_data_path
- training_data_dst_path = self.pretraining_data_path
+ training_data_dst_path= self.pretraining_data_path
sort_by_yaw = False

self.set_training_data_generators([
@@ -466,8 +476,9 @@ class SAEModel(ModelBase):
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
                                               scale_range=np.array([-0.05,
                                                                     0.05]) + self.src_scale_mod / 100.0),
- output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
-                       'resolution': resolution, 'apply_ct': apply_random_ct}] + \
+ output_sample_types=[{'types': (
+     t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+     'resolution': resolution, 'apply_ct': apply_random_ct}] + \
                     [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
                       'resolution': resolution // (2 ** i),
                       'apply_ct': apply_random_ct} for i in range(ms_count)] + \
@@ -478,8 +489,9 @@ class SAEModel(ModelBase):

SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                    sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
-                   output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
-                                         'resolution': resolution}] + \
+                   output_sample_types=[{'types': (
+                       t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                       'resolution': resolution}] + \
                                        [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
                                          'resolution': resolution // (2 ** i)} for i in
                                         range(ms_count)] + \
@@ -522,6 +534,43 @@ class SAEModel(ModelBase):
def onSave(self):
self.save_weights_safe(self.get_model_filename_list())

+ # override
+ def set_batch_size(self, batch_size):
+     self.batch_size = batch_size
+     self.set_training_data_generators([
+         SampleGeneratorFace(training_data_src_path,
+                             sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
+                             random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
+                             debug=self.is_debug(), batch_size=self.batch_size,
+                             sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
+                                                                            scale_range=np.array([-0.05,
+                                                                                                  0.05]) + self.src_scale_mod / 100.0),
+                             output_sample_types=[{'types': (
+                                 t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                                 'resolution': resolution, 'apply_ct': apply_random_ct}] + \
+                                                 [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
+                                                   'resolution': resolution // (2 ** i),
+                                                   'apply_ct': apply_random_ct} for i in range(ms_count)] + \
+                                                 [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
+                                                   'resolution': resolution // (2 ** i)} for i in
+                                                  range(ms_count)]
+                             ),
+
+         SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
+                             sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
+                             output_sample_types=[{'types': (
+                                 t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                                 'resolution': resolution}] + \
+                                                 [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
+                                                   'resolution': resolution // (2 ** i)} for i in
+                                                  range(ms_count)] + \
+                                                 [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
+                                                   'resolution': resolution // (2 ** i)} for i in
+                                                  range(ms_count)])
+     ])

# override
def onTrainOneIter(self, generators_samples, generators_list):
src_samples = generators_samples[0]
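The new set_batch_size override boils down to: store the new batch size, then rebuild both sample generators with it. A minimal sketch of that pattern, with hypothetical names standing in for SampleGeneratorFace and the module-level globals:

class ModelSketch:
    def __init__(self, make_generators):
        # make_generators is a callable that returns fresh generators for a given batch size.
        self.make_generators = make_generators
        self.set_batch_size(1)

    def set_batch_size(self, batch_size):
        self.batch_size = batch_size
        self.generator_list = self.make_generators(batch_size)

m = ModelSketch(lambda bs: ["src_gen(bs=%d)" % bs, "dst_gen(bs=%d)" % bs])
m.set_batch_size(8)
print(m.generator_list)  # ['src_gen(bs=8)', 'dst_gen(bs=8)']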