Merge branch 'master' into feat/random-ct-only-on-face-mask

# Conflicts:
#	models/Model_SAE/Model.py
This commit is contained in:
Jeremy Hummel 2019-08-14 13:47:26 -07:00
commit 2c28c6eb3f
4 changed files with 316 additions and 176 deletions

17
main.py
View file

@ -134,16 +134,23 @@ if __name__ == "__main__":
Trainer.main(args, device_args)
p = subparsers.add_parser( "train", help="Trainer")
p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.")
p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.")
p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None, help="Optional dir of extracted faceset that will be used in pretraining mode.")
p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir",
help="Dir of extracted SRC faceset.")
p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir",
help="Dir of extracted DST faceset.")
p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None,
help="Optional dir of extracted faceset that will be used in pretraining mode.")
p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.")
p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'),
help="Type of model")
p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False,
help="Disable preview window.")
p.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
p.add_argument('--force-gpu-idx', type=int, dest="force_gpu_idx", default=-1, help="Force to choose this GPU idx.")
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
# Boolean switch, declared like the other on/off flags of this sub-parser
# (--no-preview, --debug, --cpu-only). Without action="store_true" argparse
# expects a value after the flag and stores it as a string, so even
# "--pingpong False" would evaluate as truthy.
p.add_argument('--pingpong', action="store_true", dest="ping_pong", default=False,
               help="Cycle between a batch size of 1 and the chosen batch size")
p.set_defaults (func=process_train)
def process_convert(arguments):

View file

@ -104,14 +104,14 @@ def trainerThread (s2c, c2s, args, device_args):
print("Unable to execute program: %s" % (prog) )
if not is_reached_goal:
iter, iter_time = model.train_one_iter()
iter, iter_time, batch_size = model.train_one_iter()
loss_history = model.get_loss_history()
time_str = time.strftime("[%H:%M:%S]")
if iter_time >= 10:
loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) )
loss_string = "{0}[#{1:06d}][{2:.5s}s][bs: {3}]".format ( time_str, iter, '{:0.4f}'.format(iter_time), batch_size )
else:
loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) )
loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size)
if shared_state['after_save']:
shared_state['after_save'] = False
@ -186,6 +186,7 @@ def main(args, device_args):
no_preview = args.get('no_preview', False)
s2c = queue.Queue()
c2s = queue.Queue()
@ -216,6 +217,7 @@ def main(args, device_args):
is_waiting_preview = False
show_last_history_iters_count = 0
iter = 0
batch_size = 1
while True:
if not c2s.empty():
input = c2s.get()
@ -225,6 +227,7 @@ def main(args, device_args):
loss_history = input['loss_history'] if 'loss_history' in input.keys() else None
previews = input['previews'] if 'previews' in input.keys() else None
iter = input['iter'] if 'iter' in input.keys() else 0
#batch_size = input['batch_size'] if 'iter' in input.keys() else 1
if previews is not None:
max_w = 0
max_h = 0
@ -280,7 +283,7 @@ def main(args, device_args):
else:
loss_history_to_show = loss_history[-show_last_history_iters_count:]
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c)
lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c)
final = np.concatenate ( [final, lh_img], axis=0 )
final = np.concatenate ( [final, selected_preview_rgb], axis=0 )

View file

@ -23,6 +23,7 @@ You can implement your own model. Check examples.
class ModelBase(object):
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
ask_enable_autobackup=True,
ask_write_preview_history=True,
@ -42,7 +43,8 @@ class ModelBase(object):
for idx, name in idxs_names_list:
io.log_info("[%d] : %s" % (idx, name))
device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1,
[x[0] for x in idxs_names_list])
self.device_args = device_args
self.device_config = nnlib.DeviceConfig(allow_growth=True, **self.device_args)
@ -65,6 +67,8 @@ class ModelBase(object):
self.debug = debug
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
self.paddle = 'pong'
self.iter = 0
self.options = {}
self.loss_history = []
@ -81,7 +85,9 @@ class ModelBase(object):
self.loss_history = model_data.get('loss_history', [])
self.sample_for_preview = model_data.get('sample_for_preview', None)
ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 5 if io.is_colab() else 2 )
ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time("Press enter in 2 seconds to"
" override model settings.",
5 if io.is_colab() else 2)
yn_str = {True: 'y', False: 'n'}
@ -90,20 +96,36 @@ class ModelBase(object):
if ask_enable_autobackup and (self.iter == 0 or ask_override):
default_autobackup = False if self.iter == 0 else self.options.get('autobackup', False)
self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]) , default_autobackup, help_message="Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01")
self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " %
(yn_str[default_autobackup]), default_autobackup,
help_message="Autobackup model files with preview every hour for"
" last 15 hours. Latest backup located in model/<>"
"_autobackups/01")
else:
self.options['autobackup'] = self.options.get('autobackup', False)
if ask_write_preview_history and (self.iter == 0 or ask_override):
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',
False)
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " %
(yn_str[default_write_preview_history]),
default_write_preview_history,
help_message="Preview history will be writed to"
" <ModelName>_history folder.")
else:
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
if (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_support_windows():
choose_preview_history = io.input_bool("Choose image for the preview history? (y/n skip:%s) : " % (yn_str[False]) , False)
choose_preview_history = io.input_bool("Choose image for the preview history?"
" (y/n skip:%s) : " % (yn_str[False]), False)
elif (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_colab():
choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s) : " % (yn_str[False]), False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person")
choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s)"
": " % (yn_str[False]), False,
help_message="Preview image history will stay stuck with old faces"
" if you reuse the same model on different celebs."
" Choose no unless you are changing src/dst to a"
" new person")
else:
choose_preview_history = False
@ -117,23 +139,48 @@ class ModelBase(object):
if ask_batch_size and (self.iter == 0 or ask_override):
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0)
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % 0,
0,
help_message="Larger batch size is better for NN's"
" generalization, but it can cause Out of"
" Memory error. Tune this value for your"
" videocard manually."))
self.options['ping_pong'] = io.input_bool(
"Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]),
True,
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
self.options['paddle'] = self.options.get('paddle','ping')
if self.options.get('ping_pong',True):
self.options['ping_pong_iter'] = max(0, io.input_int("Ping-pong iteration (skip:1000/default) : ", 1000))
else:
self.options['batch_size'] = self.options.get('batch_size', 0)
self.options['batch_cap'] = self.options.get('batch_cap', 16)
self.options['ping_pong'] = self.options.get('ping_pong', True)
self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)
if ask_sort_by_yaw:
if (self.iter == 0 or ask_override):
default_sort_by_yaw = self.options.get('sort_by_yaw', False)
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw, help_message="NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw." )
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):"
" " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw,
help_message="NN will not learn src face directions that"
" don't match dst face directions. Do not use "
"if the dst face has hair that covers the jaw")
else:
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
if ask_override:
if self.iter == 0 or ask_override:
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
if ask_src_scale_mod:
if (self.iter == 0):
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
if self.iter == 0:
self.options['src_scale_mod'] = np.clip(io.input_int("Src face scale modifier %"
" ( -30...30, ?:help skip:0) : ", 0,
help_message="If src face shape is wider than"
" dst, try to decrease this value to"
" get a better result."), -30, 30)
else:
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
@ -149,9 +196,14 @@ class ModelBase(object):
if self.target_iter == 0 and 'target_iter' in self.options:
self.options.pop('target_iter')
self.batch_size = self.options.get('batch_size',0)
self.batch_size = self.options.get('batch_size', 8)
self.batch_cap = self.options.get('batch_cap',16)
self.ping_pong_iter = self.options.get('ping_pong_iter',1000)
self.sort_by_yaw = self.options.get('sort_by_yaw', False)
self.random_flip = self.options.get('random_flip', True)
if self.batch_cap == 0:
self.options['batch_cap'] = self.batch_size
self.batch_cap = self.options.get('batch_cap',16)
self.src_scale_mod = self.options.get('src_scale_mod', 0)
if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
@ -166,6 +218,7 @@ class ModelBase(object):
self.onInitialize()
self.options['batch_size'] = self.batch_size
self.paddle = self.options.get('paddle', 'ping')
if self.debug or self.batch_size == 0:
self.batch_size = 1
@ -175,8 +228,10 @@ class ModelBase(object):
self.preview_history_path = self.model_path / ('%s_history' % (self.get_model_name()))
self.autobackups_path = self.model_path / ('%s_autobackups' % (self.get_model_name()))
else:
self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
self.autobackups_path = self.model_path / ( '%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
self.preview_history_path = self.model_path / ('%d_%s_history' % (self.device_args['force_gpu_idx'],
self.get_model_name()))
self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'],
self.get_model_name()))
if self.autobackup:
self.autobackup_current_hour = time.localtime().tm_hour
@ -212,7 +267,8 @@ class ModelBase(object):
while True:
key_events = io.get_key_events(wnd_name)
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(
key_events) > 0 else (0, 0, False, False, False)
if key == ord('\n') or key == ord('\r'):
choosed = True
break
@ -339,6 +395,8 @@ class ModelBase(object):
return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr
def save(self):
self.options['batch_size'] = self.batch_size
self.options['paddle'] = self.paddle
summary_path = self.get_strpath_storage_for_file('summary.txt')
Path(summary_path).write_text(self.model_summary_text)
self.onSave()
@ -351,7 +409,8 @@ class ModelBase(object):
}
self.model_data_path.write_bytes(pickle.dumps(model_data))
bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in
self.get_model_filename_list()]
bckp_filename_list += [str(summary_path), str(self.model_data_path)]
if self.autobackup:
@ -385,7 +444,8 @@ class ModelBase(object):
plist += [(bgr, idx_backup_path / (('preview_%s.jpg') % (name)))]
for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size,
preview.shape[1], preview.shape[2])
img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
cv2_imwrite(filepath, img)
@ -412,7 +472,6 @@ class ModelBase(object):
except Exception as e:
print("Unable to load ", opt_filename)
def save_weights_safe(self, model_filename_list):
for model, filename in model_filename_list:
filename = self.get_strpath_storage_for_file(filename)
@ -466,6 +525,12 @@ class ModelBase(object):
return [next(generator) for generator in self.generator_list]
def train_one_iter(self):
if self.iter == 1 and self.options.get('ping_pong', True):
self.set_batch_size(1)
self.paddle = 'ping'
elif not self.options.get('ping_pong', True) and self.batch_cap != self.batch_size:
self.set_batch_size(self.batch_cap)
sample = self.generate_next_sample()
iter_time = time.time()
losses = self.onTrainOneIter(sample, self.generator_list)
@ -487,14 +552,27 @@ class ModelBase(object):
plist += [(self.get_static_preview(), str(self.preview_history_path / ('%.6d.jpg' % (self.iter))))]
for preview, filepath in plist:
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, preview.shape[1],
preview.shape[2])
img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
cv2_imwrite(filepath, img)
if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', True):
if self.batch_size == self.batch_cap:
self.paddle = 'pong'
if self.batch_size > self.batch_cap:
self.set_batch_size(self.batch_cap)
self.paddle = 'pong'
if self.batch_size == 1:
self.paddle = 'ping'
if self.paddle == 'ping':
self.set_batch_size(self.batch_size + 1)
else:
self.set_batch_size(self.batch_size - 1)
self.iter += 1
return self.iter, iter_time
return self.iter, iter_time, self.batch_size
def pass_one_iter(self):
self.last_sample = self.generate_next_sample()
@ -533,7 +611,8 @@ class ModelBase(object):
if self.device_args['force_gpu_idx'] == -1:
return str(self.model_path / (self.get_model_name() + '_' + filename))
else:
return str( self.model_path / ( str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename) )
return str(self.model_path / (
str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename))
def set_vram_batch_requirements(self, d):
# example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
@ -553,7 +632,7 @@ class ModelBase(object):
self.batch_size = d[keys[-1]]
@staticmethod
def get_loss_history_preview(loss_history, iter, w, c):
def get_loss_history_preview(loss_history, iter,batch_size, w, c):
loss_history = np.array(loss_history.copy())
lh_height = 100
@ -608,7 +687,9 @@ class ModelBase(object):
last_line_t = int((lh_lines - 1) * lh_line_height)
last_line_b = int(lh_lines * lh_line_height)
lh_text = 'Iter: %d' % (iter) if iter != 0 else ''
lh_text = 'Iter: %d' % iter if iter != 0 else ''
bs_text = 'BS: %d' % batch_size if batch_size is not None else '1'
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c )
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), lh_text,
color=[0.8] * c)
return lh_img

View file

@ -9,6 +9,8 @@ from interact import interact as io
from samplelib.SampleProcessor import ColorTransferMode
# SAE - Styled AutoEncoder
@ -160,6 +162,7 @@ class SAEModel(ModelBase):
SAEModel.initialize_nn_functions()
self.set_vram_batch_requirements({1.5: 4})
global resolution
resolution= self.options['resolution']
ae_dims = self.options['ae_dims']
e_ch_dims = self.options['e_ch_dims']
@ -171,9 +174,10 @@ class SAEModel(ModelBase):
d_residual_blocks = True
bgr_shape = (resolution, resolution, 3)
mask_shape = (resolution, resolution, 1)
global ms_count
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
global apply_random_ct
apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
masked_training = True
@ -444,13 +448,19 @@ class SAEModel(ModelBase):
self.src_sample_losses = []
self.dst_sample_losses = []
global t
t = SampleProcessor.Types
global face_type
face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF
global t_mode_bgr
t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE
global training_data_src_path
training_data_src_path = self.training_data_src_path
global training_data_dst_path
training_data_dst_path= self.training_data_dst_path
global sort_by_yaw
sort_by_yaw = self.sort_by_yaw
if self.pretrain and self.pretraining_data_path is not None:
@ -466,7 +476,8 @@ class SAEModel(ModelBase):
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
0.05]) + self.src_scale_mod / 100.0),
output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i),
@ -478,7 +489,8 @@ class SAEModel(ModelBase):
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i)} for i in
@ -522,6 +534,43 @@ class SAEModel(ModelBase):
def onSave(self):
self.save_weights_safe(self.get_model_filename_list())
# override
def set_batch_size(self, batch_size):
    """Store a new batch size and rebuild the SRC and DST sample generators.

    Both generators are recreated with ``batch_size`` and handed to
    ``self.set_training_data_generators``.

    NOTE: this method reads the module-level names ``training_data_src_path``,
    ``training_data_dst_path``, ``sort_by_yaw``, ``apply_random_ct``, ``t``,
    ``face_type``, ``t_mode_bgr``, ``resolution`` and ``ms_count``. They are
    declared ``global`` elsewhere in this file and are expected to have been
    assigned before this method is called.

    :param batch_size: number of samples each generator should yield per batch.
    """
    self.batch_size = batch_size

    # One resolution per decoder scale: full size, then halved for each step.
    scaled_resolutions = [resolution // (2 ** i) for i in range(ms_count)]

    warped_bgr = (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr)
    target_bgr = (t.IMG_TRANSFORMED, face_type, t_mode_bgr)
    target_mask = (t.IMG_TRANSFORMED, face_type, t.MODE_M)

    # SRC samples carry the colour-transfer setting on the BGR outputs only.
    src_sample_types = [{'types': warped_bgr,
                         'resolution': resolution,
                         'apply_ct': apply_random_ct}]
    src_sample_types = src_sample_types + [{'types': target_bgr,
                                            'resolution': res,
                                            'apply_ct': apply_random_ct}
                                           for res in scaled_resolutions]
    src_sample_types = src_sample_types + [{'types': target_mask,
                                            'resolution': res}
                                           for res in scaled_resolutions]

    src_generator = SampleGeneratorFace(
        training_data_src_path,
        sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
        random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
        debug=self.is_debug(),
        batch_size=self.batch_size,
        sample_process_options=SampleProcessor.Options(
            random_flip=self.random_flip,
            scale_range=np.array([-0.05, 0.05]) + self.src_scale_mod / 100.0),
        output_sample_types=src_sample_types)

    # DST samples use the same layout, without any 'apply_ct' entry.
    dst_sample_types = [{'types': warped_bgr, 'resolution': resolution}]
    dst_sample_types = dst_sample_types + [{'types': target_bgr, 'resolution': res}
                                           for res in scaled_resolutions]
    dst_sample_types = dst_sample_types + [{'types': target_mask, 'resolution': res}
                                           for res in scaled_resolutions]

    dst_generator = SampleGeneratorFace(
        training_data_dst_path,
        debug=self.is_debug(),
        batch_size=self.batch_size,
        sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
        output_sample_types=dst_sample_types)

    self.set_training_data_generators([src_generator, dst_generator])
# override
def onTrainOneIter(self, generators_samples, generators_list):
src_samples = generators_samples[0]