diff --git a/README.md b/README.md index 84d4db3..4133ad8 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,6 @@ ## **DeepFaceLab** is a tool that utilizes machine learning to replace faces in videos. -If you like this software, please consider a donation. - GOAL: next DeepFacelab update. - ### [Gallery](doc/gallery/doc_gallery.md) @@ -16,8 +14,6 @@ GOAL: next DeepFacelab update. [English (google translated)](doc/manual_en_google_translated.pdf) -[На русском](doc/manual_ru.pdf) - - ### [Prebuilt windows app](doc/doc_prebuilt_windows_app.md) - ### Forks @@ -29,13 +25,3 @@ GOAL: next DeepFacelab update. - ### [Ready to work facesets](doc/doc_ready_to_work_facesets.md) - ### [Build and repository info](doc/doc_build_and_repository_info.md) - -- ### Communication groups: - -(Chinese) QQ group 951138799 for ML/AI experts - -[deepfakes (Chinese)](https://deepfakescn.com) - -[deepfakes (Chinese) (outdated) ](https://deepfakes.com.cn/) - -[reddit (English)](https://www.reddit.com/r/GifFakes/new/) diff --git a/interact/interact.py b/interact/interact.py index c9b22f1..bda74af 100644 --- a/interact/interact.py +++ b/interact/interact.py @@ -33,6 +33,7 @@ class InteractBase(object): self.key_events = {} self.pg_bar = None self.focus_wnd_name = None + self.error_log_line_prefix = '/!\\ ' def is_support_windows(self): return False @@ -65,10 +66,22 @@ class InteractBase(object): raise NotImplemented def log_info(self, msg, end='\n'): + if self.pg_bar is not None: + try: # Attempt print before the pb + tqdm.write(msg) + return + except: + pass #Fallback to normal print upon failure print (msg, end=end) def log_err(self, msg, end='\n'): - print (msg, end=end) + if self.pg_bar is not None: + try: # Attempt print before the pb + tqdm.write(f'{self.error_log_line_prefix}{msg}') + return + except: + pass #Fallback to normal print upon failure + print (f'{self.error_log_line_prefix}{msg}', end=end) def named_window(self, wnd_name): if wnd_name not in self.named_windows: @@ -150,9 +163,12 @@ class InteractBase(object): else: print("progress_bar not set.") def progress_bar_generator(self, data, desc, leave=True): - for x in tqdm( data, desc=desc, leave=leave, ascii=True ): + self.pg_bar = tqdm( data, desc=desc, leave=leave, ascii=True ) + for x in self.pg_bar: yield x - + self.pg_bar.close() + self.pg_bar = None + def process_messages(self, sleep_time=0): self.on_process_messages(sleep_time) diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 624d44f..fa152ba 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -12,7 +12,7 @@ import cv2 import models from interact import interact as io -def trainerThread (s2c, c2s, args, device_args): +def trainerThread (s2c, c2s, e, args, device_args): while True: try: start_time = time.time() @@ -66,6 +66,7 @@ def trainerThread (s2c, c2s, args, device_args): else: previews = [( 'debug, press update for new', model.debug_one_iter())] c2s.put ( {'op':'show', 'previews': previews} ) + e.set() #Set the GUI Thread as Ready if model.is_first_run(): @@ -190,9 +191,12 @@ def main(args, device_args): s2c = queue.Queue() c2s = queue.Queue() - thread = threading.Thread(target=trainerThread, args=(s2c, c2s, args, device_args) ) + e = threading.Event() + thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) ) thread.start() + e.wait() #Wait for inital load to occur. + if no_preview: while True: if not c2s.empty(): @@ -294,7 +298,7 @@ def main(args, device_args): key_events = io.get_key_events(wnd_name) key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) - + if key == ord('\n') or key == ord('\r'): s2c.put ( {'op': 'close'} ) elif key == ord('s'): @@ -324,4 +328,4 @@ def main(args, device_args): except KeyboardInterrupt: s2c.put ( {'op': 'close'} ) - io.destroy_all_windows() + io.destroy_all_windows() \ No newline at end of file diff --git a/models/ModelBase.py b/models/ModelBase.py index 27f51d5..a14acf4 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -26,22 +26,22 @@ class ModelBase(object): def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None, ask_enable_autobackup=True, - ask_write_preview_history=True, - ask_target_iter=True, - ask_batch_size=True, + ask_write_preview_history=True, + ask_target_iter=True, + ask_batch_size=True, ask_sort_by_yaw=True, - ask_random_flip=True, + ask_random_flip=True, ask_src_scale_mod=True): - device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1) - device_args['cpu_only'] = device_args.get('cpu_only', False) + device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1) + device_args['cpu_only'] = device_args.get('cpu_only',False) if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']: idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList() if len(idxs_names_list) > 1: - io.log_info("You have multi GPUs in a system: ") + io.log_info ("You have multi GPUs in a system: ") for idx, name in idxs_names_list: - io.log_info("[%d] : %s" % (idx, name)) + io.log_info ("[%d] : %s" % (idx, name) ) device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [x[0] for x in idxs_names_list]) @@ -49,15 +49,15 @@ class ModelBase(object): self.device_config = nnlib.DeviceConfig(allow_growth=True, **self.device_args) - io.log_info("Loading model...") + io.log_info ("Loading model...") self.model_path = model_path - self.model_data_path = Path(self.get_strpath_storage_for_file('data.dat')) + self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') ) self.training_data_src_path = training_data_src_path self.training_data_dst_path = training_data_dst_path self.pretraining_data_path = pretraining_data_path - + self.src_images_paths = None self.dst_images_paths = None self.src_yaw_images_paths = None @@ -76,8 +76,8 @@ class ModelBase(object): model_data = {} if self.model_data_path.exists(): - model_data = pickle.loads(self.model_data_path.read_bytes()) - self.iter = max(model_data.get('iter', 0), model_data.get('epoch', 0)) + model_data = pickle.loads ( self.model_data_path.read_bytes() ) + self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) ) if 'epoch' in self.options: self.options.pop('epoch') if self.iter != 0: @@ -89,13 +89,13 @@ class ModelBase(object): " override model settings.", 5 if io.is_colab() else 2) - yn_str = {True: 'y', False: 'n'} + yn_str = {True:'y',False:'n'} if self.iter == 0: - io.log_info("\nModel first run. Enter model options as default for each run.") + io.log_info ("\nModel first run. Enter model options as default for each run.") if ask_enable_autobackup and (self.iter == 0 or ask_override): - default_autobackup = False if self.iter == 0 else self.options.get('autobackup', False) + default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False) self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]), default_autobackup, help_message="Autobackup model files with preview every hour for" @@ -128,17 +128,17 @@ class ModelBase(object): " new person") else: choose_preview_history = False - + if ask_target_iter: if (self.iter == 0 or ask_override): self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0)) else: - self.options['target_iter'] = max(model_data.get('target_iter', 0), self.options.get('target_epoch', 0)) + self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0)) if 'target_epoch' in self.options: self.options.pop('target_epoch') if ask_batch_size and (self.iter == 0 or ask_override): - default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0) + default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0) self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % self.options.get('batch_cap', 16),self.options.get('batch_cap', 16), help_message="Larger batch size is better for NN's" " generalization, but it can cause Out of" @@ -157,7 +157,7 @@ class ModelBase(object): self.options['ping_pong'] = self.options.get('ping_pong', False) self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000) - if ask_sort_by_yaw: + if ask_sort_by_yaw: if (self.iter == 0 or ask_override): default_sort_by_yaw = self.options.get('sort_by_yaw', False) self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):" @@ -182,7 +182,7 @@ class ModelBase(object): " get a better result."), -30, 30) else: self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0) - + self.autobackup = self.options.get('autobackup', False) if not self.autobackup and 'autobackup' in self.options: self.options.pop('autobackup') @@ -191,20 +191,20 @@ class ModelBase(object): if not self.write_preview_history and 'write_preview_history' in self.options: self.options.pop('write_preview_history') - self.target_iter = self.options.get('target_iter', 0) + self.target_iter = self.options.get('target_iter',0) if self.target_iter == 0 and 'target_iter' in self.options: self.options.pop('target_iter') self.batch_size = self.options.get('batch_size', 8) self.batch_cap = self.options.get('batch_cap',16) self.ping_pong_iter = self.options.get('ping_pong_iter',1000) - self.sort_by_yaw = self.options.get('sort_by_yaw', False) - self.random_flip = self.options.get('random_flip', True) + self.sort_by_yaw = self.options.get('sort_by_yaw',False) + self.random_flip = self.options.get('random_flip',True) if self.batch_cap == 0: self.options['batch_cap'] = self.batch_size self.batch_cap = self.options.get('batch_cap',16) - self.src_scale_mod = self.options.get('src_scale_mod', 0) + self.src_scale_mod = self.options.get('src_scale_mod',0) if self.src_scale_mod == 0 and 'src_scale_mod' in self.options: self.options.pop('src_scale_mod') @@ -224,17 +224,17 @@ class ModelBase(object): if self.is_training_mode: if self.device_args['force_gpu_idx'] == -1: - self.preview_history_path = self.model_path / ('%s_history' % (self.get_model_name())) - self.autobackups_path = self.model_path / ('%s_autobackups' % (self.get_model_name())) + self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) ) + self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) ) else: self.preview_history_path = self.model_path / ('%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name())) self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name())) - + if self.autobackup: self.autobackup_current_hour = time.localtime().tm_hour - + if not self.autobackups_path.exists(): self.autobackups_path.mkdir(exist_ok=True) @@ -247,13 +247,13 @@ class ModelBase(object): Path(filename).unlink() if self.generator_list is None: - raise ValueError('You didnt set_training_data_generators()') + raise ValueError( 'You didnt set_training_data_generators()') else: for i, generator in enumerate(self.generator_list): if not isinstance(generator, SampleGeneratorBase): raise ValueError('training data generator is not subclass of SampleGeneratorBase') - if self.sample_for_preview is None or choose_preview_history: + if self.sample_for_preview is None or choose_preview_history: if choose_preview_history and io.is_support_windows(): wnd_name = "[p] - next. [enter] - confirm." io.named_window(wnd_name) @@ -262,7 +262,7 @@ class ModelBase(object): while not choosed: self.sample_for_preview = self.generate_next_sample() preview = self.get_static_preview() - io.show_image(wnd_name, (preview * 255).astype(np.uint8)) + io.show_image( wnd_name, (preview*255).astype(np.uint8) ) while True: key_events = io.get_key_events(wnd_name) @@ -273,54 +273,72 @@ class ModelBase(object): break elif key == ord('p'): break - + try: io.process_messages(0.1) except KeyboardInterrupt: choosed = True - + io.destroy_window(wnd_name) - else: - self.sample_for_preview = self.generate_next_sample() + else: + self.sample_for_preview = self.generate_next_sample() self.last_sample = self.sample_for_preview + + ###Generate text summary of model hyperparameters + #Find the longest key name and value string. Used as column widths. + width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" + width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.iter)), len(self.get_model_name())]) + 1 # Single space buffer to right edge + if not self.device_config.cpu_only: #Check length of GPU names + width_value = max([len(nnlib.device.getDeviceName(idx))+1 for idx in self.device_config.gpu_idxs] + [width_value]) + width_total = width_name + width_value + 2 #Plus 2 for ": " + model_summary_text = [] - - model_summary_text += ["===== Model summary ====="] - model_summary_text += ["== Model name: " + self.get_model_name()] - model_summary_text += ["=="] - model_summary_text += ["== Current iteration: " + str(self.iter)] - model_summary_text += ["=="] - model_summary_text += ["== Model options:"] + model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary + model_summary_text += [f'=={" "*width_total}=='] + model_summary_text += [f'=={"Model name": >{width_name}}: {self.get_model_name(): <{width_value}}=='] # Name + model_summary_text += [f'=={" "*width_total}=='] + model_summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='] # Iter + model_summary_text += [f'=={" "*width_total}=='] + + model_summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options + model_summary_text += [f'=={" "*width_total}=='] for key in self.options.keys(): - model_summary_text += ["== |== %s : %s" % (key, self.options[key])] - + model_summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs + model_summary_text += [f'=={" "*width_total}=='] + + model_summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info + model_summary_text += [f'=={" "*width_total}=='] if self.device_config.multi_gpu: - model_summary_text += ["== |== multi_gpu : True "] - - model_summary_text += ["== Running on:"] + model_summary_text += [f'=={"Using multi_gpu": >{width_name}}: {"True": <{width_value}}=='] # multi_gpu + model_summary_text += [f'=={" "*width_total}=='] if self.device_config.cpu_only: - model_summary_text += ["== |== [CPU]"] + model_summary_text += [f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='] # cpu_only else: for idx in self.device_config.gpu_idxs: - model_summary_text += ["== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))] - - if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] == 2: - model_summary_text += ["=="] - model_summary_text += ["== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."] - model_summary_text += ["== If training does not start, close all programs and try again."] - model_summary_text += ["== Also you can disable Windows Aero Desktop to get extra free VRAM."] - model_summary_text += ["=="] - - model_summary_text += ["========================="] - model_summary_text = "\r\n".join(model_summary_text) + model_summary_text += [f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='] # GPU hardware device index + model_summary_text += [f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='] # GPU name + vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB' # GPU VRAM - Formated as #.## (or ##.##) + model_summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}=='] + model_summary_text += [f'=={" "*width_total}=='] + model_summary_text += [f'=={"="*width_total}=='] + + if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] <= 2: # Low VRAM warning + model_summary_text += ["/!\\"] + model_summary_text += ["/!\\ WARNING:"] + model_summary_text += ["/!\\ You are using a GPU with 2GB or less VRAM. This may significantly reduce the quality of your result!"] + model_summary_text += ["/!\\ If training does not start, close all programs and try again."] + model_summary_text += ["/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."] + model_summary_text += ["/!\\"] + + model_summary_text = "\n".join (model_summary_text) self.model_summary_text = model_summary_text io.log_info(model_summary_text) - # overridable + #overridable def onInitializeOptions(self, is_first_run, ask_override): pass - # overridable + #overridable def onInitialize(self): ''' initialize your keras models @@ -331,36 +349,36 @@ class ModelBase(object): ''' pass - # overridable + #overridable def onSave(self): - # save your keras models here + #save your keras models here pass - # overridable + #overridable def onTrainOneIter(self, sample, generator_list): - # train your keras models here + #train your keras models here - # return array of losses - return (('loss_src', 0), ('loss_dst', 0)) + #return array of losses + return ( ('loss_src', 0), ('loss_dst', 0) ) - # overridable + #overridable def onGetPreview(self, sample): - # you can return multiple previews - # return [ ('preview_name',preview_rgb), ... ] + #you can return multiple previews + #return [ ('preview_name',preview_rgb), ... ] return [] - # overridable if you want model name differs from folder name + #overridable if you want model name differs from folder name def get_model_name(self): return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1] - # overridable , return [ [model, filename],... ] list + #overridable , return [ [model, filename],... ] list def get_model_filename_list(self): return [] - # overridable + #overridable def get_converter(self): raise NotImplementedError - # return existing or your own converter which derived from base + #return existing or your own converter which derived from base def get_target_iter(self): return self.target_iter @@ -368,8 +386,8 @@ class ModelBase(object): def is_reached_iter_goal(self): return self.target_iter != 0 and self.iter >= self.target_iter - # multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976 - # def to_multi_gpu_model_if_possible (self, models_list): + #multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976 + #def to_multi_gpu_model_if_possible (self, models_list): # if len(self.device_config.gpu_idxs) > 1: # #make batch_size to divide on GPU count without remainder # self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) ) @@ -388,65 +406,65 @@ class ModelBase(object): # return models_list def get_previews(self): - return self.onGetPreview(self.last_sample) + return self.onGetPreview ( self.last_sample ) def get_static_preview(self): - return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr + return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr def save(self): self.options['batch_size'] = self.batch_size self.options['paddle'] = self.paddle summary_path = self.get_strpath_storage_for_file('summary.txt') - Path(summary_path).write_text(self.model_summary_text) + Path( summary_path ).write_text(self.model_summary_text) self.onSave() model_data = { 'iter': self.iter, 'options': self.options, 'loss_history': self.loss_history, - 'sample_for_preview': self.sample_for_preview + 'sample_for_preview' : self.sample_for_preview } - self.model_data_path.write_bytes(pickle.dumps(model_data)) + self.model_data_path.write_bytes( pickle.dumps(model_data) ) bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list()] - bckp_filename_list += [str(summary_path), str(self.model_data_path)] - + bckp_filename_list += [ str(summary_path), str(self.model_data_path) ] + if self.autobackup: current_hour = time.localtime().tm_hour if self.autobackup_current_hour != current_hour: self.autobackup_current_hour = current_hour - for i in range(15, 0, -1): + for i in range(15,0,-1): idx_str = '%.2d' % i - next_idx_str = '%.2d' % (i + 1) - + next_idx_str = '%.2d' % (i+1) + idx_backup_path = self.autobackups_path / idx_str next_idx_packup_path = self.autobackups_path / next_idx_str - + if idx_backup_path.exists(): - if i == 15: + if i == 15: Path_utils.delete_all_files(idx_backup_path) else: next_idx_packup_path.mkdir(exist_ok=True) - Path_utils.move_all_files(idx_backup_path, next_idx_packup_path) - + Path_utils.move_all_files (idx_backup_path, next_idx_packup_path) + if i == 1: - idx_backup_path.mkdir(exist_ok=True) - for filename in bckp_filename_list: - shutil.copy(str(filename), str(idx_backup_path / Path(filename).name)) + idx_backup_path.mkdir(exist_ok=True) + for filename in bckp_filename_list: + shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) ) previews = self.get_previews() plist = [] for i in range(len(previews)): name, bgr = previews[i] - plist += [(bgr, idx_backup_path / (('preview_%s.jpg') % (name)))] + plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] for preview, filepath in plist: preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, preview.shape[1], preview.shape[2]) - img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8) - cv2_imwrite(filepath, img) + img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) + cv2_imwrite (filepath, img ) def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]): for model, filename in model_filename_list: @@ -469,15 +487,15 @@ class ModelBase(object): opt.set_weights(weights) print("set ok") except Exception as e: - print("Unable to load ", opt_filename) + print ("Unable to load ", opt_filename) def save_weights_safe(self, model_filename_list): for model, filename in model_filename_list: filename = self.get_strpath_storage_for_file(filename) - model.save_weights(filename + '.tmp') + model.save_weights( filename + '.tmp' ) rename_list = model_filename_list - + """ #unused , optimizer_filename_list=[] @@ -501,24 +519,24 @@ class ModelBase(object): except Exception as e: print ("Unable to save ", opt_filename) """ - + for _, filename in rename_list: filename = self.get_strpath_storage_for_file(filename) - source_filename = Path(filename + '.tmp') + source_filename = Path(filename+'.tmp') if source_filename.exists(): target_filename = Path(filename) if target_filename.exists(): target_filename.unlink() - source_filename.rename(str(target_filename)) - + source_filename.rename ( str(target_filename) ) + def debug_one_iter(self): images = [] for generator in self.generator_list: - for i, batch in enumerate(next(generator)): + for i,batch in enumerate(next(generator)): if len(batch.shape) == 4: - images.append(batch[0]) + images.append( batch[0] ) - return imagelib.equalize_and_stack_square(images) + return imagelib.equalize_and_stack_square (images) def generate_next_sample(self): return [ generator.generate_next() for generator in self.generator_list] @@ -536,7 +554,7 @@ class ModelBase(object): iter_time = time.time() - iter_time self.last_sample = sample - self.loss_history.append([float(loss[1]) for loss in losses]) + self.loss_history.append ( [float(loss[1]) for loss in losses] ) if self.iter % 10 == 0: plist = [] @@ -545,16 +563,16 @@ class ModelBase(object): previews = self.get_previews() for i in range(len(previews)): name, bgr = previews[i] - plist += [(bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name)))] + plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ] if self.write_preview_history: - plist += [(self.get_static_preview(), str(self.preview_history_path / ('%.6d.jpg' % (self.iter))))] + plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ] for preview, filepath in plist: preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, preview.shape[1], preview.shape[2]) - img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8) - cv2_imwrite(filepath, img) + img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) + cv2_imwrite (filepath, img ) if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False): if self.batch_size == self.batch_cap: @@ -599,10 +617,10 @@ class ModelBase(object): def get_loss_history(self): return self.loss_history - def set_training_data_generators(self, generator_list): + def set_training_data_generators (self, generator_list): self.generator_list = generator_list - def get_training_data_generators(self): + def get_training_data_generators (self): return self.generator_list def get_model_root_path(self): @@ -610,13 +628,13 @@ class ModelBase(object): def get_strpath_storage_for_file(self, filename): if self.device_args['force_gpu_idx'] == -1: - return str(self.model_path / (self.get_model_name() + '_' + filename)) + return str( self.model_path / ( self.get_model_name() + '_' + filename) ) else: return str(self.model_path / ( str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename)) - def set_vram_batch_requirements(self, d): - # example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48} + def set_vram_batch_requirements (self, d): + #example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48} keys = [x for x in d.keys()] if self.device_config.cpu_only: @@ -630,63 +648,63 @@ class ModelBase(object): break if self.batch_size == 0: - self.batch_size = d[keys[-1]] + self.batch_size = d[ keys[-1] ] @staticmethod def get_loss_history_preview(loss_history, iter,batch_size, w, c): - loss_history = np.array(loss_history.copy()) + loss_history = np.array (loss_history.copy()) lh_height = 100 - lh_img = np.ones((lh_height, w, c)) * 0.1 - - if len(loss_history) != 0: + lh_img = np.ones ( (lh_height,w,c) ) * 0.1 + + if len(loss_history) != 0: loss_count = len(loss_history[0]) lh_len = len(loss_history) l_per_col = lh_len / w - plist_max = [[max(0.0, loss_history[int(col * l_per_col)][p], - *[loss_history[i_ab][p] - for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col)) - ] - ) - for p in range(loss_count) - ] - for col in range(w) - ] + plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p], + *[ loss_history[i_ab][p] + for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) ) + ] + ) + for p in range(loss_count) + ] + for col in range(w) + ] - plist_min = [[min(plist_max[col][p], loss_history[int(col * l_per_col)][p], - *[loss_history[i_ab][p] - for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col)) - ] - ) - for p in range(loss_count) - ] - for col in range(w) - ] + plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p], + *[ loss_history[i_ab][p] + for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) ) + ] + ) + for p in range(loss_count) + ] + for col in range(w) + ] - plist_abs_max = np.mean(loss_history[len(loss_history) // 5:]) * 2 + plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2 for col in range(0, w): - for p in range(0, loss_count): - point_color = [1.0] * c - point_color[0:3] = colorsys.hsv_to_rgb(p * (1.0 / loss_count), 1.0, 1.0) + for p in range(0,loss_count): + point_color = [1.0]*c + point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 ) - ph_max = int((plist_max[col][p] / plist_abs_max) * (lh_height - 1)) - ph_max = np.clip(ph_max, 0, lh_height - 1) + ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) ) + ph_max = np.clip( ph_max, 0, lh_height-1 ) - ph_min = int((plist_min[col][p] / plist_abs_max) * (lh_height - 1)) - ph_min = np.clip(ph_min, 0, lh_height - 1) + ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) ) + ph_min = np.clip( ph_min, 0, lh_height-1 ) - for ph in range(ph_min, ph_max + 1): - lh_img[(lh_height - ph - 1), col] = point_color + for ph in range(ph_min, ph_max+1): + lh_img[ (lh_height-ph-1), col ] = point_color lh_lines = 5 - lh_line_height = (lh_height - 1) / lh_lines - for i in range(0, lh_lines + 1): - lh_img[int(i * lh_line_height), :] = (0.8,) * c + lh_line_height = (lh_height-1)/lh_lines + for i in range(0,lh_lines+1): + lh_img[ int(i*lh_line_height), : ] = (0.8,)*c - last_line_t = int((lh_lines - 1) * lh_line_height) - last_line_b = int(lh_lines * lh_line_height) + last_line_t = int((lh_lines-1)*lh_line_height) + last_line_b = int(lh_lines*lh_line_height) lh_text = 'Iter: %d' % iter if iter != 0 else '' bs_text = 'BS: %d' % batch_size if batch_size is not None else '1' diff --git a/utils/Path_utils.py b/utils/Path_utils.py index 0f03aeb..5c2b43f 100644 --- a/utils/Path_utils.py +++ b/utils/Path_utils.py @@ -7,7 +7,7 @@ IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff") def get_image_paths(dir_path: str, image_extensions: List[str] = IMAGE_EXTENSIONS) -> List[str]: - dir_path = Path(dir_path) + dir_path = Path (dir_path) result = [] if dir_path.exists(): @@ -26,41 +26,41 @@ def get_image_unique_filestem_paths(dir_path: str, verbose_print_func: Optional[ if f_stem in result_dup: result.remove(f) if verbose_print_func is not None: - verbose_print_func("Duplicate filenames are not allowed, skipping: %s" % Path(f).name) + verbose_print_func ("Duplicate filenames are not allowed, skipping: %s" % Path(f).name ) continue result_dup.add(f_stem) return sorted(result) - + def get_file_paths(dir_path: str) -> List[str]: dir_path = Path(dir_path) if dir_path.exists(): - return [x.path for x in scandir(str(dir_path)) if x.is_file()] + return sorted([x.path for x in scandir(str(dir_path)) if x.is_file()]) return [] def get_all_dir_names(dir_path: str) -> List[str]: dir_path = Path(dir_path) if dir_path.exists(): - return [x.name for x in scandir(str(dir_path)) if x.is_dir()] + return sorted([x.name for x in scandir(str(dir_path)) if x.is_dir()]) return [] def get_all_dir_names_startswith(dir_path: str, startswith: str) -> List[str]: - dir_path = Path(dir_path) + dir_path = Path (dir_path) startswith = startswith.lower() result = [] if dir_path.exists(): for x in scandir(str(dir_path)): if x.name.lower().startswith(startswith): - result.append(x.name[len(startswith):]) + result.append ( x.name[len(startswith):] ) return sorted(result) def get_first_file_by_stem(dir_path: str, stem: str, exts: List[str] = None) -> Optional[Path]: - dir_path = Path(dir_path) + dir_path = Path (dir_path) stem = stem.lower() if dir_path.exists(): @@ -78,8 +78,8 @@ def move_all_files(src_dir_path: str, dst_dir_path: str) -> None: paths = get_file_paths(src_dir_path) for p in paths: p = Path(p) - p.rename(Path(dst_dir_path) / p.name) - + p.rename ( Path(dst_dir_path) / p.name ) + def delete_all_files(dir_path: str) -> None: paths = get_file_paths(dir_path)