From d433acc4a0e0dccf3a5943a0e4671733712b8715 Mon Sep 17 00:00:00 2001 From: Brigham Lysenko Date: Tue, 13 Aug 2019 13:13:36 -0600 Subject: [PATCH] Batch size can now be changed during training --- mainscripts/Trainer.py | 9 +- models/ModelBase.py | 370 ++++++++++++++++++++++---------------- models/Model_SAE/Model.py | 65 ++++++- 3 files changed, 276 insertions(+), 168 deletions(-) diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index f170fcc..8a41c58 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -104,14 +104,14 @@ def trainerThread (s2c, c2s, args, device_args): print("Unable to execute program: %s" % (prog) ) if not is_reached_goal: - iter, iter_time = model.train_one_iter() + iter, iter_time, batch_size = model.train_one_iter() loss_history = model.get_loss_history() time_str = time.strftime("[%H:%M:%S]") if iter_time >= 10: loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) ) else: - loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) ) + loss_string = "{0}[#{1:06d}][{2:04d}ms][bs: {3}]".format ( time_str, iter, int(iter_time*1000), batch_size) if shared_state['after_save']: shared_state['after_save'] = False @@ -186,6 +186,7 @@ def main(args, device_args): no_preview = args.get('no_preview', False) + s2c = queue.Queue() c2s = queue.Queue() @@ -216,6 +217,7 @@ def main(args, device_args): is_waiting_preview = False show_last_history_iters_count = 0 iter = 0 + batch_size = 1 while True: if not c2s.empty(): input = c2s.get() @@ -225,6 +227,7 @@ def main(args, device_args): loss_history = input['loss_history'] if 'loss_history' in input.keys() else None previews = input['previews'] if 'previews' in input.keys() else None iter = input['iter'] if 'iter' in input.keys() else 0 + #batch_size = input['batch_size'] if 'iter' in input.keys() else 1 if previews is not None: max_w = 0 max_h = 0 @@ -280,7 +283,7 @@ def main(args, device_args): else: loss_history_to_show = loss_history[-show_last_history_iters_count:] - lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c) + lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, batch_size, w, c) final = np.concatenate ( [final, lh_img], axis=0 ) final = np.concatenate ( [final, selected_preview_rgb], axis=0 ) diff --git a/models/ModelBase.py b/models/ModelBase.py index 25757b3..80e490f 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -23,39 +23,41 @@ You can implement your own model. Check examples. class ModelBase(object): + def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None, ask_enable_autobackup=True, - ask_write_preview_history=True, - ask_target_iter=True, - ask_batch_size=True, + ask_write_preview_history=True, + ask_target_iter=True, + ask_batch_size=True, ask_sort_by_yaw=True, - ask_random_flip=True, + ask_random_flip=True, ask_src_scale_mod=True): - device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1) - device_args['cpu_only'] = device_args.get('cpu_only',False) + device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1) + device_args['cpu_only'] = device_args.get('cpu_only', False) if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']: idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList() if len(idxs_names_list) > 1: - io.log_info ("You have multi GPUs in a system: ") + io.log_info("You have multi GPUs in a system: ") for idx, name in idxs_names_list: - io.log_info ("[%d] : %s" % (idx, name) ) + io.log_info("[%d] : %s" % (idx, name)) - device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] ) + device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, + [x[0] for x in idxs_names_list]) self.device_args = device_args self.device_config = nnlib.DeviceConfig(allow_growth=True, **self.device_args) - io.log_info ("Loading model...") + io.log_info("Loading model...") self.model_path = model_path - self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') ) + self.model_data_path = Path(self.get_strpath_storage_for_file('data.dat')) self.training_data_src_path = training_data_src_path self.training_data_dst_path = training_data_dst_path self.pretraining_data_path = pretraining_data_path - + self.src_images_paths = None self.dst_images_paths = None self.src_yaw_images_paths = None @@ -65,6 +67,8 @@ class ModelBase(object): self.debug = debug self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None) + self.paddle = 'pong' + self.iter = 0 self.options = {} self.loss_history = [] @@ -72,8 +76,8 @@ class ModelBase(object): model_data = {} if self.model_data_path.exists(): - model_data = pickle.loads ( self.model_data_path.read_bytes() ) - self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) ) + model_data = pickle.loads(self.model_data_path.read_bytes()) + self.iter = max(model_data.get('iter', 0), model_data.get('epoch', 0)) if 'epoch' in self.options: self.options.pop('epoch') if self.iter != 0: @@ -81,65 +85,105 @@ class ModelBase(object): self.loss_history = model_data.get('loss_history', []) self.sample_for_preview = model_data.get('sample_for_preview', None) - ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 5 if io.is_colab() else 2 ) + ask_override = self.is_training_mode and self.iter != 0 and io.input_in_time("Press enter in 2 seconds to" + " override model settings.", + 5 if io.is_colab() else 2) - yn_str = {True:'y',False:'n'} + yn_str = {True: 'y', False: 'n'} if self.iter == 0: - io.log_info ("\nModel first run. Enter model options as default for each run.") + io.log_info("\nModel first run. Enter model options as default for each run.") if ask_enable_autobackup and (self.iter == 0 or ask_override): - default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False) - self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]) , default_autobackup, help_message="Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01") + default_autobackup = False if self.iter == 0 else self.options.get('autobackup', False) + self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % + (yn_str[default_autobackup]), default_autobackup, + help_message="Autobackup model files with preview every hour for" + " last 15 hours. Latest backup located in model/<>" + "_autobackups/01") else: self.options['autobackup'] = self.options.get('autobackup', False) if ask_write_preview_history and (self.iter == 0 or ask_override): - default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False) - self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to _history folder.") + default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history', + False) + self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % + (yn_str[default_write_preview_history]), + default_write_preview_history, + help_message="Preview history will be writed to" + " _history folder.") else: self.options['write_preview_history'] = self.options.get('write_preview_history', False) if (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_support_windows(): - choose_preview_history = io.input_bool("Choose image for the preview history? (y/n skip:%s) : " % (yn_str[False]) , False) + choose_preview_history = io.input_bool("Choose image for the preview history?" + " (y/n skip:%s) : " % (yn_str[False]), False) + elif (self.iter == 0 or ask_override) and self.options['write_preview_history'] and io.is_colab(): - choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s) : " % (yn_str[False]), False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person") + choose_preview_history = io.input_bool("Randomly choose new image for preview history? (y/n ?:help skip:%s)" + ": " % (yn_str[False]), False, + help_message="Preview image history will stay stuck with old faces" + " if you reuse the same model on different celebs." + " Choose no unless you are changing src/dst to a" + " new person") else: choose_preview_history = False - + if ask_target_iter: if (self.iter == 0 or ask_override): self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0)) else: - self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0)) + self.options['target_iter'] = max(model_data.get('target_iter', 0), self.options.get('target_epoch', 0)) if 'target_epoch' in self.options: self.options.pop('target_epoch') if ask_batch_size and (self.iter == 0 or ask_override): - default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0) - self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")) + default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0) + self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % default_batch_size, + default_batch_size, + help_message="Larger batch size is better for NN's" + " generalization, but it can cause Out of" + " Memory error. Tune this value for your" + " videocard manually.")) + self.options['ping_pong'] = io.input_bool( + "Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]), + True, + help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence") + else: self.options['batch_size'] = self.options.get('batch_size', 0) + self.options['ping_pong'] = self.options.get('ping_pong', True) - if ask_sort_by_yaw: + if ask_sort_by_yaw: if (self.iter == 0 or ask_override): default_sort_by_yaw = self.options.get('sort_by_yaw', False) - self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s) : " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw, help_message="NN will not learn src face directions that don't match dst face directions. Do not use if the dst face has hair that covers the jaw." ) + self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):" + " " % (yn_str[default_sort_by_yaw]), default_sort_by_yaw, + help_message="NN will not learn src face directions that" + " don't match dst face directions. Do not use " + "if the dst face has hair that covers the jaw") else: self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False) if ask_random_flip: if (self.iter == 0): - self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") + self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, + help_message="Predicted face will look more naturally" + " without this option, but src faceset should" + " cover all face directions as dst faceset.") else: self.options['random_flip'] = self.options.get('random_flip', True) if ask_src_scale_mod: - if (self.iter == 0): - self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30) + if self.iter == 0: + self.options['src_scale_mod'] = np.clip(io.input_int("Src face scale modifier %" + " ( -30...30, ?:help skip:0) : ", 0, + help_message="If src face shape is wider than" + " dst, try to decrease this value to" + " get a better result."), -30, 30) else: self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0) - + self.autobackup = self.options.get('autobackup', False) if not self.autobackup and 'autobackup' in self.options: self.options.pop('autobackup') @@ -148,15 +192,15 @@ class ModelBase(object): if not self.write_preview_history and 'write_preview_history' in self.options: self.options.pop('write_preview_history') - self.target_iter = self.options.get('target_iter',0) + self.target_iter = self.options.get('target_iter', 0) if self.target_iter == 0 and 'target_iter' in self.options: self.options.pop('target_iter') - self.batch_size = self.options.get('batch_size',0) - self.sort_by_yaw = self.options.get('sort_by_yaw',False) - self.random_flip = self.options.get('random_flip',True) + self.batch_size = self.options.get('batch_size', 0) + self.sort_by_yaw = self.options.get('sort_by_yaw', False) + self.random_flip = self.options.get('random_flip', True) - self.src_scale_mod = self.options.get('src_scale_mod',0) + self.src_scale_mod = self.options.get('src_scale_mod', 0) if self.src_scale_mod == 0 and 'src_scale_mod' in self.options: self.options.pop('src_scale_mod') @@ -169,21 +213,24 @@ class ModelBase(object): self.onInitialize() self.options['batch_size'] = self.batch_size + self.options['paddle'] = 'ping' if self.debug or self.batch_size == 0: self.batch_size = 1 if self.is_training_mode: if self.device_args['force_gpu_idx'] == -1: - self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) ) - self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) ) + self.preview_history_path = self.model_path / ('%s_history' % (self.get_model_name())) + self.autobackups_path = self.model_path / ('%s_autobackups' % (self.get_model_name())) else: - self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) ) - self.autobackups_path = self.model_path / ( '%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name()) ) - + self.preview_history_path = self.model_path / ('%d_%s_history' % (self.device_args['force_gpu_idx'], + self.get_model_name())) + self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'], + self.get_model_name())) + if self.autobackup: self.autobackup_current_hour = time.localtime().tm_hour - + if not self.autobackups_path.exists(): self.autobackups_path.mkdir(exist_ok=True) @@ -196,13 +243,13 @@ class ModelBase(object): Path(filename).unlink() if self.generator_list is None: - raise ValueError( 'You didnt set_training_data_generators()') + raise ValueError('You didnt set_training_data_generators()') else: for i, generator in enumerate(self.generator_list): if not isinstance(generator, SampleGeneratorBase): raise ValueError('training data generator is not subclass of SampleGeneratorBase') - if self.sample_for_preview is None or choose_preview_history: + if self.sample_for_preview is None or choose_preview_history: if choose_preview_history and io.is_support_windows(): wnd_name = "[p] - next. [enter] - confirm." io.named_window(wnd_name) @@ -211,25 +258,26 @@ class ModelBase(object): while not choosed: self.sample_for_preview = self.generate_next_sample() preview = self.get_static_preview() - io.show_image( wnd_name, (preview*255).astype(np.uint8) ) + io.show_image(wnd_name, (preview * 255).astype(np.uint8)) while True: key_events = io.get_key_events(wnd_name) - key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len( + key_events) > 0 else (0, 0, False, False, False) if key == ord('\n') or key == ord('\r'): choosed = True break elif key == ord('p'): break - + try: io.process_messages(0.1) except KeyboardInterrupt: choosed = True - + io.destroy_window(wnd_name) - else: - self.sample_for_preview = self.generate_next_sample() + else: + self.sample_for_preview = self.generate_next_sample() self.last_sample = self.sample_for_preview model_summary_text = [] @@ -260,15 +308,15 @@ class ModelBase(object): model_summary_text += ["=="] model_summary_text += ["========================="] - model_summary_text = "\r\n".join (model_summary_text) + model_summary_text = "\r\n".join(model_summary_text) self.model_summary_text = model_summary_text io.log_info(model_summary_text) - #overridable + # overridable def onInitializeOptions(self, is_first_run, ask_override): pass - #overridable + # overridable def onInitialize(self): ''' initialize your keras models @@ -279,36 +327,36 @@ class ModelBase(object): ''' pass - #overridable + # overridable def onSave(self): - #save your keras models here + # save your keras models here pass - #overridable + # overridable def onTrainOneIter(self, sample, generator_list): - #train your keras models here + # train your keras models here - #return array of losses - return ( ('loss_src', 0), ('loss_dst', 0) ) + # return array of losses + return (('loss_src', 0), ('loss_dst', 0)) - #overridable + # overridable def onGetPreview(self, sample): - #you can return multiple previews - #return [ ('preview_name',preview_rgb), ... ] + # you can return multiple previews + # return [ ('preview_name',preview_rgb), ... ] return [] - #overridable if you want model name differs from folder name + # overridable if you want model name differs from folder name def get_model_name(self): return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1] - #overridable , return [ [model, filename],... ] list + # overridable , return [ [model, filename],... ] list def get_model_filename_list(self): return [] - #overridable + # overridable def get_converter(self): raise NotImplementedError - #return existing or your own converter which derived from base + # return existing or your own converter which derived from base def get_target_iter(self): return self.target_iter @@ -316,8 +364,8 @@ class ModelBase(object): def is_reached_iter_goal(self): return self.target_iter != 0 and self.iter >= self.target_iter - #multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976 - #def to_multi_gpu_model_if_possible (self, models_list): + # multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976 + # def to_multi_gpu_model_if_possible (self, models_list): # if len(self.device_config.gpu_idxs) > 1: # #make batch_size to divide on GPU count without remainder # self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) ) @@ -336,61 +384,63 @@ class ModelBase(object): # return models_list def get_previews(self): - return self.onGetPreview ( self.last_sample ) + return self.onGetPreview(self.last_sample) def get_static_preview(self): - return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr + return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr def save(self): summary_path = self.get_strpath_storage_for_file('summary.txt') - Path( summary_path ).write_text(self.model_summary_text) + Path(summary_path).write_text(self.model_summary_text) self.onSave() model_data = { 'iter': self.iter, 'options': self.options, 'loss_history': self.loss_history, - 'sample_for_preview' : self.sample_for_preview + 'sample_for_preview': self.sample_for_preview } - self.model_data_path.write_bytes( pickle.dumps(model_data) ) + self.model_data_path.write_bytes(pickle.dumps(model_data)) + + bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in + self.get_model_filename_list()] + bckp_filename_list += [str(summary_path), str(self.model_data_path)] - bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ] - bckp_filename_list += [ str(summary_path), str(self.model_data_path) ] - if self.autobackup: current_hour = time.localtime().tm_hour if self.autobackup_current_hour != current_hour: self.autobackup_current_hour = current_hour - for i in range(15,0,-1): + for i in range(15, 0, -1): idx_str = '%.2d' % i - next_idx_str = '%.2d' % (i+1) - + next_idx_str = '%.2d' % (i + 1) + idx_backup_path = self.autobackups_path / idx_str next_idx_packup_path = self.autobackups_path / next_idx_str - + if idx_backup_path.exists(): - if i == 15: + if i == 15: Path_utils.delete_all_files(idx_backup_path) else: next_idx_packup_path.mkdir(exist_ok=True) - Path_utils.move_all_files (idx_backup_path, next_idx_packup_path) - + Path_utils.move_all_files(idx_backup_path, next_idx_packup_path) + if i == 1: - idx_backup_path.mkdir(exist_ok=True) - for filename in bckp_filename_list: - shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) ) + idx_backup_path.mkdir(exist_ok=True) + for filename in bckp_filename_list: + shutil.copy(str(filename), str(idx_backup_path / Path(filename).name)) previews = self.get_previews() plist = [] for i in range(len(previews)): name, bgr = previews[i] - plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] + plist += [(bgr, idx_backup_path / (('preview_%s.jpg') % (name)))] for preview, filepath in plist: - preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) - img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite (filepath, img ) + preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, + preview.shape[1], preview.shape[2]) + img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8) + cv2_imwrite(filepath, img) def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]): for model, filename in model_filename_list: @@ -413,16 +463,15 @@ class ModelBase(object): opt.set_weights(weights) print("set ok") except Exception as e: - print ("Unable to load ", opt_filename) - + print("Unable to load ", opt_filename) def save_weights_safe(self, model_filename_list): for model, filename in model_filename_list: filename = self.get_strpath_storage_for_file(filename) - model.save_weights( filename + '.tmp' ) + model.save_weights(filename + '.tmp') rename_list = model_filename_list - + """ #unused , optimizer_filename_list=[] @@ -446,24 +495,24 @@ class ModelBase(object): except Exception as e: print ("Unable to save ", opt_filename) """ - + for _, filename in rename_list: filename = self.get_strpath_storage_for_file(filename) - source_filename = Path(filename+'.tmp') + source_filename = Path(filename + '.tmp') if source_filename.exists(): target_filename = Path(filename) if target_filename.exists(): target_filename.unlink() - source_filename.rename ( str(target_filename) ) - + source_filename.rename(str(target_filename)) + def debug_one_iter(self): images = [] for generator in self.generator_list: - for i,batch in enumerate(next(generator)): + for i, batch in enumerate(next(generator)): if len(batch.shape) == 4: - images.append( batch[0] ) + images.append(batch[0]) - return imagelib.equalize_and_stack_square (images) + return imagelib.equalize_and_stack_square(images) def generate_next_sample(self): return [next(generator) for generator in self.generator_list] @@ -475,7 +524,7 @@ class ModelBase(object): iter_time = time.time() - iter_time self.last_sample = sample - self.loss_history.append ( [float(loss[1]) for loss in losses] ) + self.loss_history.append([float(loss[1]) for loss in losses]) if self.iter % 10 == 0: plist = [] @@ -484,20 +533,24 @@ class ModelBase(object): previews = self.get_previews() for i in range(len(previews)): name, bgr = previews[i] - plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ] + plist += [(bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name)))] if self.write_preview_history: - plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ] + plist += [(self.get_static_preview(), str(self.preview_history_path / ('%.6d.jpg' % (self.iter))))] for preview, filepath in plist: - preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2]) - img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) - cv2_imwrite (filepath, img ) + preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, preview.shape[1], + preview.shape[2]) + img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8) + cv2_imwrite(filepath, img) + if self.iter % 50 == 0 and self.iter != 0: + + self.set_batch_size(self.batch_size + 1) self.iter += 1 - return self.iter, iter_time + return self.iter, iter_time, self.batch_size def pass_one_iter(self): self.last_sample = self.generate_next_sample() @@ -523,10 +576,10 @@ class ModelBase(object): def get_loss_history(self): return self.loss_history - def set_training_data_generators (self, generator_list): + def set_training_data_generators(self, generator_list): self.generator_list = generator_list - def get_training_data_generators (self): + def get_training_data_generators(self): return self.generator_list def get_model_root_path(self): @@ -534,12 +587,13 @@ class ModelBase(object): def get_strpath_storage_for_file(self, filename): if self.device_args['force_gpu_idx'] == -1: - return str( self.model_path / ( self.get_model_name() + '_' + filename) ) + return str(self.model_path / (self.get_model_name() + '_' + filename)) else: - return str( self.model_path / ( str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename) ) + return str(self.model_path / ( + str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename)) - def set_vram_batch_requirements (self, d): - #example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48} + def set_vram_batch_requirements(self, d): + # example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48} keys = [x for x in d.keys()] if self.device_config.cpu_only: @@ -553,65 +607,67 @@ class ModelBase(object): break if self.batch_size == 0: - self.batch_size = d[ keys[-1] ] + self.batch_size = d[keys[-1]] @staticmethod - def get_loss_history_preview(loss_history, iter, w, c): - loss_history = np.array (loss_history.copy()) + def get_loss_history_preview(loss_history, iter,batch_size, w, c): + loss_history = np.array(loss_history.copy()) lh_height = 100 - lh_img = np.ones ( (lh_height,w,c) ) * 0.1 - - if len(loss_history) != 0: + lh_img = np.ones((lh_height, w, c)) * 0.1 + + if len(loss_history) != 0: loss_count = len(loss_history[0]) lh_len = len(loss_history) l_per_col = lh_len / w - plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p], - *[ loss_history[i_ab][p] - for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) ) - ] - ) - for p in range(loss_count) - ] - for col in range(w) - ] + plist_max = [[max(0.0, loss_history[int(col * l_per_col)][p], + *[loss_history[i_ab][p] + for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col)) + ] + ) + for p in range(loss_count) + ] + for col in range(w) + ] - plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p], - *[ loss_history[i_ab][p] - for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) ) - ] - ) - for p in range(loss_count) - ] - for col in range(w) - ] + plist_min = [[min(plist_max[col][p], loss_history[int(col * l_per_col)][p], + *[loss_history[i_ab][p] + for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col)) + ] + ) + for p in range(loss_count) + ] + for col in range(w) + ] - plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2 + plist_abs_max = np.mean(loss_history[len(loss_history) // 5:]) * 2 for col in range(0, w): - for p in range(0,loss_count): - point_color = [1.0]*c - point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 ) + for p in range(0, loss_count): + point_color = [1.0] * c + point_color[0:3] = colorsys.hsv_to_rgb(p * (1.0 / loss_count), 1.0, 1.0) - ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) ) - ph_max = np.clip( ph_max, 0, lh_height-1 ) + ph_max = int((plist_max[col][p] / plist_abs_max) * (lh_height - 1)) + ph_max = np.clip(ph_max, 0, lh_height - 1) - ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) ) - ph_min = np.clip( ph_min, 0, lh_height-1 ) + ph_min = int((plist_min[col][p] / plist_abs_max) * (lh_height - 1)) + ph_min = np.clip(ph_min, 0, lh_height - 1) - for ph in range(ph_min, ph_max+1): - lh_img[ (lh_height-ph-1), col ] = point_color + for ph in range(ph_min, ph_max + 1): + lh_img[(lh_height - ph - 1), col] = point_color lh_lines = 5 - lh_line_height = (lh_height-1)/lh_lines - for i in range(0,lh_lines+1): - lh_img[ int(i*lh_line_height), : ] = (0.8,)*c + lh_line_height = (lh_height - 1) / lh_lines + for i in range(0, lh_lines + 1): + lh_img[int(i * lh_line_height), :] = (0.8,) * c - last_line_t = int((lh_lines-1)*lh_line_height) - last_line_b = int(lh_lines*lh_line_height) + last_line_t = int((lh_lines - 1) * lh_line_height) + last_line_b = int(lh_lines * lh_line_height) - lh_text = 'Iter: %d' % (iter) if iter != 0 else '' + lh_text = 'Iter: %d' % iter if iter != 0 else '' + bs_text = 'BS: %d' % batch_size if batch_size is not None else '1' - lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c ) + lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), bs_text, + color=[0.8] * c) return lh_img diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 3440660..b15079a 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -8,6 +8,8 @@ from samplelib import * from interact import interact as io + + # SAE - Styled AutoEncoder class SAEModel(ModelBase): encoderH5 = 'encoder.h5' @@ -152,7 +154,8 @@ class SAEModel(ModelBase): SAEModel.initialize_nn_functions() self.set_vram_batch_requirements({1.5: 4}) - resolution = self.options['resolution'] + global resolution + resolution= self.options['resolution'] ae_dims = self.options['ae_dims'] e_ch_dims = self.options['e_ch_dims'] d_ch_dims = self.options['d_ch_dims'] @@ -163,9 +166,10 @@ class SAEModel(ModelBase): d_residual_blocks = True bgr_shape = (resolution, resolution, 3) mask_shape = (resolution, resolution, 1) - + global ms_count self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1 + global apply_random_ct apply_random_ct = self.options.get('apply_random_ct', False) masked_training = True @@ -436,18 +440,24 @@ class SAEModel(ModelBase): self.src_sample_losses = [] self.dst_sample_losses = [] + global t t = SampleProcessor.Types + global face_type face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF + global t_mode_bgr t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE + global training_data_src_path training_data_src_path = self.training_data_src_path - training_data_dst_path = self.training_data_dst_path + global training_data_dst_path + training_data_dst_path= self.training_data_dst_path + global sort_by_yaw sort_by_yaw = self.sort_by_yaw if self.pretrain and self.pretraining_data_path is not None: training_data_src_path = self.pretraining_data_path - training_data_dst_path = self.pretraining_data_path + training_data_dst_path= self.pretraining_data_path sort_by_yaw = False self.set_training_data_generators([ @@ -458,8 +468,9 @@ class SAEModel(ModelBase): sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05]) + self.src_scale_mod / 100.0), - output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution, 'apply_ct': apply_random_ct}] + \ + output_sample_types=[{'types': ( + t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), + 'resolution': resolution, 'apply_ct': apply_random_ct}] + \ [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2 ** i), 'apply_ct': apply_random_ct} for i in range(ms_count)] + \ @@ -470,8 +481,9 @@ class SAEModel(ModelBase): SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), - output_sample_types=[{'types': (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), - 'resolution': resolution}] + \ + output_sample_types=[{'types': ( + t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), + 'resolution': resolution}] + \ [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2 ** i)} for i in range(ms_count)] + \ @@ -514,6 +526,43 @@ class SAEModel(ModelBase): def onSave(self): self.save_weights_safe(self.get_model_filename_list()) + # override + def set_batch_size(self, batch_size): + self.batch_size = batch_size + self.set_training_data_generators([ + SampleGeneratorFace(training_data_src_path, + sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, + random_ct_samples_path=training_data_dst_path if apply_random_ct else None, + debug=self.is_debug(), batch_size=self.batch_size, + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, + scale_range=np.array([-0.05, + 0.05]) + self.src_scale_mod / 100.0), + output_sample_types=[{'types': ( + t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), + 'resolution': resolution, 'apply_ct': apply_random_ct}] + \ + [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), + 'resolution': resolution // (2 ** i), + 'apply_ct': apply_random_ct} for i in range(ms_count)] + \ + [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), + 'resolution': resolution // (2 ** i)} for i in + range(ms_count)] + ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, + sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ), + output_sample_types=[{'types': ( + t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), + 'resolution': resolution}] + \ + [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr), + 'resolution': resolution // (2 ** i)} for i in + range(ms_count)] + \ + [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), + 'resolution': resolution // (2 ** i)} for i in + range(ms_count)]) + ]) + + + # override def onTrainOneIter(self, generators_samples, generators_list): src_samples = generators_samples[0]