mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Merge pull request #30 from faceshiftlabs/build/commits-from-upstream
Build/commits from upstream
This commit is contained in:
commit
35a17511c3
5 changed files with 210 additions and 186 deletions
14
README.md
14
README.md
|
@ -6,8 +6,6 @@
|
|||
|
||||
## **DeepFaceLab** is a tool that utilizes machine learning to replace faces in videos.
|
||||
|
||||
If you like this software, please consider a donation.
|
||||
|
||||
GOAL: next DeepFacelab update.
|
||||
|
||||
- ### [Gallery](doc/gallery/doc_gallery.md)
|
||||
|
@ -16,8 +14,6 @@ GOAL: next DeepFacelab update.
|
|||
|
||||
[English (google translated)](doc/manual_en_google_translated.pdf)
|
||||
|
||||
[На русском](doc/manual_ru.pdf)
|
||||
|
||||
- ### [Prebuilt windows app](doc/doc_prebuilt_windows_app.md)
|
||||
|
||||
- ### Forks
|
||||
|
@ -29,13 +25,3 @@ GOAL: next DeepFacelab update.
|
|||
- ### [Ready to work facesets](doc/doc_ready_to_work_facesets.md)
|
||||
|
||||
- ### [Build and repository info](doc/doc_build_and_repository_info.md)
|
||||
|
||||
- ### Communication groups:
|
||||
|
||||
(Chinese) QQ group 951138799 for ML/AI experts
|
||||
|
||||
[deepfakes (Chinese)](https://deepfakescn.com)
|
||||
|
||||
[deepfakes (Chinese) (outdated) ](https://deepfakes.com.cn/)
|
||||
|
||||
[reddit (English)](https://www.reddit.com/r/GifFakes/new/)
|
||||
|
|
|
@ -33,6 +33,7 @@ class InteractBase(object):
|
|||
self.key_events = {}
|
||||
self.pg_bar = None
|
||||
self.focus_wnd_name = None
|
||||
self.error_log_line_prefix = '/!\\ '
|
||||
|
||||
def is_support_windows(self):
|
||||
return False
|
||||
|
@ -65,10 +66,22 @@ class InteractBase(object):
|
|||
raise NotImplemented
|
||||
|
||||
def log_info(self, msg, end='\n'):
|
||||
if self.pg_bar is not None:
|
||||
try: # Attempt print before the pb
|
||||
tqdm.write(msg)
|
||||
return
|
||||
except:
|
||||
pass #Fallback to normal print upon failure
|
||||
print (msg, end=end)
|
||||
|
||||
def log_err(self, msg, end='\n'):
|
||||
print (msg, end=end)
|
||||
if self.pg_bar is not None:
|
||||
try: # Attempt print before the pb
|
||||
tqdm.write(f'{self.error_log_line_prefix}{msg}')
|
||||
return
|
||||
except:
|
||||
pass #Fallback to normal print upon failure
|
||||
print (f'{self.error_log_line_prefix}{msg}', end=end)
|
||||
|
||||
def named_window(self, wnd_name):
|
||||
if wnd_name not in self.named_windows:
|
||||
|
@ -150,9 +163,12 @@ class InteractBase(object):
|
|||
else: print("progress_bar not set.")
|
||||
|
||||
def progress_bar_generator(self, data, desc, leave=True):
|
||||
for x in tqdm( data, desc=desc, leave=leave, ascii=True ):
|
||||
self.pg_bar = tqdm( data, desc=desc, leave=leave, ascii=True )
|
||||
for x in self.pg_bar:
|
||||
yield x
|
||||
|
||||
self.pg_bar.close()
|
||||
self.pg_bar = None
|
||||
|
||||
def process_messages(self, sleep_time=0):
|
||||
self.on_process_messages(sleep_time)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import cv2
|
|||
import models
|
||||
from interact import interact as io
|
||||
|
||||
def trainerThread (s2c, c2s, args, device_args):
|
||||
def trainerThread (s2c, c2s, e, args, device_args):
|
||||
while True:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
@ -66,6 +66,7 @@ def trainerThread (s2c, c2s, args, device_args):
|
|||
else:
|
||||
previews = [( 'debug, press update for new', model.debug_one_iter())]
|
||||
c2s.put ( {'op':'show', 'previews': previews} )
|
||||
e.set() #Set the GUI Thread as Ready
|
||||
|
||||
|
||||
if model.is_first_run():
|
||||
|
@ -190,9 +191,12 @@ def main(args, device_args):
|
|||
s2c = queue.Queue()
|
||||
c2s = queue.Queue()
|
||||
|
||||
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, args, device_args) )
|
||||
e = threading.Event()
|
||||
thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e, args, device_args) )
|
||||
thread.start()
|
||||
|
||||
e.wait() #Wait for inital load to occur.
|
||||
|
||||
if no_preview:
|
||||
while True:
|
||||
if not c2s.empty():
|
||||
|
@ -294,7 +298,7 @@ def main(args, device_args):
|
|||
|
||||
key_events = io.get_key_events(wnd_name)
|
||||
key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False)
|
||||
|
||||
|
||||
if key == ord('\n') or key == ord('\r'):
|
||||
s2c.put ( {'op': 'close'} )
|
||||
elif key == ord('s'):
|
||||
|
@ -324,4 +328,4 @@ def main(args, device_args):
|
|||
except KeyboardInterrupt:
|
||||
s2c.put ( {'op': 'close'} )
|
||||
|
||||
io.destroy_all_windows()
|
||||
io.destroy_all_windows()
|
|
@ -26,22 +26,22 @@ class ModelBase(object):
|
|||
|
||||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
|
||||
ask_enable_autobackup=True,
|
||||
ask_write_preview_history=True,
|
||||
ask_target_iter=True,
|
||||
ask_batch_size=True,
|
||||
ask_write_preview_history=True,
|
||||
ask_target_iter=True,
|
||||
ask_batch_size=True,
|
||||
ask_sort_by_yaw=True,
|
||||
ask_random_flip=True,
|
||||
ask_random_flip=True,
|
||||
ask_src_scale_mod=True):
|
||||
|
||||
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx', -1)
|
||||
device_args['cpu_only'] = device_args.get('cpu_only', False)
|
||||
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
|
||||
device_args['cpu_only'] = device_args.get('cpu_only',False)
|
||||
|
||||
if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
|
||||
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
|
||||
if len(idxs_names_list) > 1:
|
||||
io.log_info("You have multi GPUs in a system: ")
|
||||
io.log_info ("You have multi GPUs in a system: ")
|
||||
for idx, name in idxs_names_list:
|
||||
io.log_info("[%d] : %s" % (idx, name))
|
||||
io.log_info ("[%d] : %s" % (idx, name) )
|
||||
|
||||
device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1,
|
||||
[x[0] for x in idxs_names_list])
|
||||
|
@ -49,15 +49,15 @@ class ModelBase(object):
|
|||
|
||||
self.device_config = nnlib.DeviceConfig(allow_growth=True, **self.device_args)
|
||||
|
||||
io.log_info("Loading model...")
|
||||
io.log_info ("Loading model...")
|
||||
|
||||
self.model_path = model_path
|
||||
self.model_data_path = Path(self.get_strpath_storage_for_file('data.dat'))
|
||||
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
|
||||
|
||||
self.training_data_src_path = training_data_src_path
|
||||
self.training_data_dst_path = training_data_dst_path
|
||||
self.pretraining_data_path = pretraining_data_path
|
||||
|
||||
|
||||
self.src_images_paths = None
|
||||
self.dst_images_paths = None
|
||||
self.src_yaw_images_paths = None
|
||||
|
@ -76,8 +76,8 @@ class ModelBase(object):
|
|||
|
||||
model_data = {}
|
||||
if self.model_data_path.exists():
|
||||
model_data = pickle.loads(self.model_data_path.read_bytes())
|
||||
self.iter = max(model_data.get('iter', 0), model_data.get('epoch', 0))
|
||||
model_data = pickle.loads ( self.model_data_path.read_bytes() )
|
||||
self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
|
||||
if 'epoch' in self.options:
|
||||
self.options.pop('epoch')
|
||||
if self.iter != 0:
|
||||
|
@ -89,13 +89,13 @@ class ModelBase(object):
|
|||
" override model settings.",
|
||||
5 if io.is_colab() else 2)
|
||||
|
||||
yn_str = {True: 'y', False: 'n'}
|
||||
yn_str = {True:'y',False:'n'}
|
||||
|
||||
if self.iter == 0:
|
||||
io.log_info("\nModel first run. Enter model options as default for each run.")
|
||||
io.log_info ("\nModel first run. Enter model options as default for each run.")
|
||||
|
||||
if ask_enable_autobackup and (self.iter == 0 or ask_override):
|
||||
default_autobackup = False if self.iter == 0 else self.options.get('autobackup', False)
|
||||
default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False)
|
||||
self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " %
|
||||
(yn_str[default_autobackup]), default_autobackup,
|
||||
help_message="Autobackup model files with preview every hour for"
|
||||
|
@ -128,17 +128,17 @@ class ModelBase(object):
|
|||
" new person")
|
||||
else:
|
||||
choose_preview_history = False
|
||||
|
||||
|
||||
if ask_target_iter:
|
||||
if (self.iter == 0 or ask_override):
|
||||
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
|
||||
else:
|
||||
self.options['target_iter'] = max(model_data.get('target_iter', 0), self.options.get('target_epoch', 0))
|
||||
self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
|
||||
if 'target_epoch' in self.options:
|
||||
self.options.pop('target_epoch')
|
||||
|
||||
if ask_batch_size and (self.iter == 0 or ask_override):
|
||||
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0)
|
||||
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
|
||||
self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % self.options.get('batch_cap', 16),self.options.get('batch_cap', 16),
|
||||
help_message="Larger batch size is better for NN's"
|
||||
" generalization, but it can cause Out of"
|
||||
|
@ -157,7 +157,7 @@ class ModelBase(object):
|
|||
self.options['ping_pong'] = self.options.get('ping_pong', False)
|
||||
self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)
|
||||
|
||||
if ask_sort_by_yaw:
|
||||
if ask_sort_by_yaw:
|
||||
if (self.iter == 0 or ask_override):
|
||||
default_sort_by_yaw = self.options.get('sort_by_yaw', False)
|
||||
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):"
|
||||
|
@ -182,7 +182,7 @@ class ModelBase(object):
|
|||
" get a better result."), -30, 30)
|
||||
else:
|
||||
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
|
||||
|
||||
|
||||
self.autobackup = self.options.get('autobackup', False)
|
||||
if not self.autobackup and 'autobackup' in self.options:
|
||||
self.options.pop('autobackup')
|
||||
|
@ -191,20 +191,20 @@ class ModelBase(object):
|
|||
if not self.write_preview_history and 'write_preview_history' in self.options:
|
||||
self.options.pop('write_preview_history')
|
||||
|
||||
self.target_iter = self.options.get('target_iter', 0)
|
||||
self.target_iter = self.options.get('target_iter',0)
|
||||
if self.target_iter == 0 and 'target_iter' in self.options:
|
||||
self.options.pop('target_iter')
|
||||
|
||||
self.batch_size = self.options.get('batch_size', 8)
|
||||
self.batch_cap = self.options.get('batch_cap',16)
|
||||
self.ping_pong_iter = self.options.get('ping_pong_iter',1000)
|
||||
self.sort_by_yaw = self.options.get('sort_by_yaw', False)
|
||||
self.random_flip = self.options.get('random_flip', True)
|
||||
self.sort_by_yaw = self.options.get('sort_by_yaw',False)
|
||||
self.random_flip = self.options.get('random_flip',True)
|
||||
if self.batch_cap == 0:
|
||||
self.options['batch_cap'] = self.batch_size
|
||||
self.batch_cap = self.options.get('batch_cap',16)
|
||||
|
||||
self.src_scale_mod = self.options.get('src_scale_mod', 0)
|
||||
self.src_scale_mod = self.options.get('src_scale_mod',0)
|
||||
if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
|
||||
self.options.pop('src_scale_mod')
|
||||
|
||||
|
@ -224,17 +224,17 @@ class ModelBase(object):
|
|||
|
||||
if self.is_training_mode:
|
||||
if self.device_args['force_gpu_idx'] == -1:
|
||||
self.preview_history_path = self.model_path / ('%s_history' % (self.get_model_name()))
|
||||
self.autobackups_path = self.model_path / ('%s_autobackups' % (self.get_model_name()))
|
||||
self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
|
||||
self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) )
|
||||
else:
|
||||
self.preview_history_path = self.model_path / ('%d_%s_history' % (self.device_args['force_gpu_idx'],
|
||||
self.get_model_name()))
|
||||
self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'],
|
||||
self.get_model_name()))
|
||||
|
||||
|
||||
if self.autobackup:
|
||||
self.autobackup_current_hour = time.localtime().tm_hour
|
||||
|
||||
|
||||
if not self.autobackups_path.exists():
|
||||
self.autobackups_path.mkdir(exist_ok=True)
|
||||
|
||||
|
@ -247,13 +247,13 @@ class ModelBase(object):
|
|||
Path(filename).unlink()
|
||||
|
||||
if self.generator_list is None:
|
||||
raise ValueError('You didnt set_training_data_generators()')
|
||||
raise ValueError( 'You didnt set_training_data_generators()')
|
||||
else:
|
||||
for i, generator in enumerate(self.generator_list):
|
||||
if not isinstance(generator, SampleGeneratorBase):
|
||||
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
||||
|
||||
if self.sample_for_preview is None or choose_preview_history:
|
||||
if self.sample_for_preview is None or choose_preview_history:
|
||||
if choose_preview_history and io.is_support_windows():
|
||||
wnd_name = "[p] - next. [enter] - confirm."
|
||||
io.named_window(wnd_name)
|
||||
|
@ -262,7 +262,7 @@ class ModelBase(object):
|
|||
while not choosed:
|
||||
self.sample_for_preview = self.generate_next_sample()
|
||||
preview = self.get_static_preview()
|
||||
io.show_image(wnd_name, (preview * 255).astype(np.uint8))
|
||||
io.show_image( wnd_name, (preview*255).astype(np.uint8) )
|
||||
|
||||
while True:
|
||||
key_events = io.get_key_events(wnd_name)
|
||||
|
@ -273,54 +273,72 @@ class ModelBase(object):
|
|||
break
|
||||
elif key == ord('p'):
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
io.process_messages(0.1)
|
||||
except KeyboardInterrupt:
|
||||
choosed = True
|
||||
|
||||
|
||||
io.destroy_window(wnd_name)
|
||||
else:
|
||||
self.sample_for_preview = self.generate_next_sample()
|
||||
else:
|
||||
self.sample_for_preview = self.generate_next_sample()
|
||||
self.last_sample = self.sample_for_preview
|
||||
|
||||
###Generate text summary of model hyperparameters
|
||||
#Find the longest key name and value string. Used as column widths.
|
||||
width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
|
||||
width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.iter)), len(self.get_model_name())]) + 1 # Single space buffer to right edge
|
||||
if not self.device_config.cpu_only: #Check length of GPU names
|
||||
width_value = max([len(nnlib.device.getDeviceName(idx))+1 for idx in self.device_config.gpu_idxs] + [width_value])
|
||||
width_total = width_name + width_value + 2 #Plus 2 for ": "
|
||||
|
||||
model_summary_text = []
|
||||
|
||||
model_summary_text += ["===== Model summary ====="]
|
||||
model_summary_text += ["== Model name: " + self.get_model_name()]
|
||||
model_summary_text += ["=="]
|
||||
model_summary_text += ["== Current iteration: " + str(self.iter)]
|
||||
model_summary_text += ["=="]
|
||||
model_summary_text += ["== Model options:"]
|
||||
model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
model_summary_text += [f'=={"Model name": >{width_name}}: {self.get_model_name(): <{width_value}}=='] # Name
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
model_summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='] # Iter
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
|
||||
model_summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
for key in self.options.keys():
|
||||
model_summary_text += ["== |== %s : %s" % (key, self.options[key])]
|
||||
|
||||
model_summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
|
||||
model_summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
if self.device_config.multi_gpu:
|
||||
model_summary_text += ["== |== multi_gpu : True "]
|
||||
|
||||
model_summary_text += ["== Running on:"]
|
||||
model_summary_text += [f'=={"Using multi_gpu": >{width_name}}: {"True": <{width_value}}=='] # multi_gpu
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
if self.device_config.cpu_only:
|
||||
model_summary_text += ["== |== [CPU]"]
|
||||
model_summary_text += [f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='] # cpu_only
|
||||
else:
|
||||
for idx in self.device_config.gpu_idxs:
|
||||
model_summary_text += ["== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))]
|
||||
|
||||
if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] == 2:
|
||||
model_summary_text += ["=="]
|
||||
model_summary_text += ["== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."]
|
||||
model_summary_text += ["== If training does not start, close all programs and try again."]
|
||||
model_summary_text += ["== Also you can disable Windows Aero Desktop to get extra free VRAM."]
|
||||
model_summary_text += ["=="]
|
||||
|
||||
model_summary_text += ["========================="]
|
||||
model_summary_text = "\r\n".join(model_summary_text)
|
||||
model_summary_text += [f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='] # GPU hardware device index
|
||||
model_summary_text += [f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='] # GPU name
|
||||
vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB' # GPU VRAM - Formated as #.## (or ##.##)
|
||||
model_summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}==']
|
||||
model_summary_text += [f'=={" "*width_total}==']
|
||||
model_summary_text += [f'=={"="*width_total}==']
|
||||
|
||||
if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] <= 2: # Low VRAM warning
|
||||
model_summary_text += ["/!\\"]
|
||||
model_summary_text += ["/!\\ WARNING:"]
|
||||
model_summary_text += ["/!\\ You are using a GPU with 2GB or less VRAM. This may significantly reduce the quality of your result!"]
|
||||
model_summary_text += ["/!\\ If training does not start, close all programs and try again."]
|
||||
model_summary_text += ["/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."]
|
||||
model_summary_text += ["/!\\"]
|
||||
|
||||
model_summary_text = "\n".join (model_summary_text)
|
||||
self.model_summary_text = model_summary_text
|
||||
io.log_info(model_summary_text)
|
||||
|
||||
# overridable
|
||||
#overridable
|
||||
def onInitializeOptions(self, is_first_run, ask_override):
|
||||
pass
|
||||
|
||||
# overridable
|
||||
#overridable
|
||||
def onInitialize(self):
|
||||
'''
|
||||
initialize your keras models
|
||||
|
@ -331,36 +349,36 @@ class ModelBase(object):
|
|||
'''
|
||||
pass
|
||||
|
||||
# overridable
|
||||
#overridable
|
||||
def onSave(self):
|
||||
# save your keras models here
|
||||
#save your keras models here
|
||||
pass
|
||||
|
||||
# overridable
|
||||
#overridable
|
||||
def onTrainOneIter(self, sample, generator_list):
|
||||
# train your keras models here
|
||||
#train your keras models here
|
||||
|
||||
# return array of losses
|
||||
return (('loss_src', 0), ('loss_dst', 0))
|
||||
#return array of losses
|
||||
return ( ('loss_src', 0), ('loss_dst', 0) )
|
||||
|
||||
# overridable
|
||||
#overridable
|
||||
def onGetPreview(self, sample):
|
||||
# you can return multiple previews
|
||||
# return [ ('preview_name',preview_rgb), ... ]
|
||||
#you can return multiple previews
|
||||
#return [ ('preview_name',preview_rgb), ... ]
|
||||
return []
|
||||
|
||||
# overridable if you want model name differs from folder name
|
||||
#overridable if you want model name differs from folder name
|
||||
def get_model_name(self):
|
||||
return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]
|
||||
|
||||
# overridable , return [ [model, filename],... ] list
|
||||
#overridable , return [ [model, filename],... ] list
|
||||
def get_model_filename_list(self):
|
||||
return []
|
||||
|
||||
# overridable
|
||||
#overridable
|
||||
def get_converter(self):
|
||||
raise NotImplementedError
|
||||
# return existing or your own converter which derived from base
|
||||
#return existing or your own converter which derived from base
|
||||
|
||||
def get_target_iter(self):
|
||||
return self.target_iter
|
||||
|
@ -368,8 +386,8 @@ class ModelBase(object):
|
|||
def is_reached_iter_goal(self):
|
||||
return self.target_iter != 0 and self.iter >= self.target_iter
|
||||
|
||||
# multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
|
||||
# def to_multi_gpu_model_if_possible (self, models_list):
|
||||
#multi gpu in keras actually is fake and doesn't work for training https://github.com/keras-team/keras/issues/11976
|
||||
#def to_multi_gpu_model_if_possible (self, models_list):
|
||||
# if len(self.device_config.gpu_idxs) > 1:
|
||||
# #make batch_size to divide on GPU count without remainder
|
||||
# self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
|
||||
|
@ -388,65 +406,65 @@ class ModelBase(object):
|
|||
# return models_list
|
||||
|
||||
def get_previews(self):
|
||||
return self.onGetPreview(self.last_sample)
|
||||
return self.onGetPreview ( self.last_sample )
|
||||
|
||||
def get_static_preview(self):
|
||||
return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr
|
||||
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
|
||||
|
||||
def save(self):
|
||||
self.options['batch_size'] = self.batch_size
|
||||
self.options['paddle'] = self.paddle
|
||||
summary_path = self.get_strpath_storage_for_file('summary.txt')
|
||||
Path(summary_path).write_text(self.model_summary_text)
|
||||
Path( summary_path ).write_text(self.model_summary_text)
|
||||
self.onSave()
|
||||
|
||||
model_data = {
|
||||
'iter': self.iter,
|
||||
'options': self.options,
|
||||
'loss_history': self.loss_history,
|
||||
'sample_for_preview': self.sample_for_preview
|
||||
'sample_for_preview' : self.sample_for_preview
|
||||
}
|
||||
self.model_data_path.write_bytes(pickle.dumps(model_data))
|
||||
self.model_data_path.write_bytes( pickle.dumps(model_data) )
|
||||
|
||||
bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in
|
||||
self.get_model_filename_list()]
|
||||
bckp_filename_list += [str(summary_path), str(self.model_data_path)]
|
||||
|
||||
bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]
|
||||
|
||||
if self.autobackup:
|
||||
current_hour = time.localtime().tm_hour
|
||||
if self.autobackup_current_hour != current_hour:
|
||||
self.autobackup_current_hour = current_hour
|
||||
|
||||
for i in range(15, 0, -1):
|
||||
for i in range(15,0,-1):
|
||||
idx_str = '%.2d' % i
|
||||
next_idx_str = '%.2d' % (i + 1)
|
||||
|
||||
next_idx_str = '%.2d' % (i+1)
|
||||
|
||||
idx_backup_path = self.autobackups_path / idx_str
|
||||
next_idx_packup_path = self.autobackups_path / next_idx_str
|
||||
|
||||
|
||||
if idx_backup_path.exists():
|
||||
if i == 15:
|
||||
if i == 15:
|
||||
Path_utils.delete_all_files(idx_backup_path)
|
||||
else:
|
||||
next_idx_packup_path.mkdir(exist_ok=True)
|
||||
Path_utils.move_all_files(idx_backup_path, next_idx_packup_path)
|
||||
|
||||
Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)
|
||||
|
||||
if i == 1:
|
||||
idx_backup_path.mkdir(exist_ok=True)
|
||||
for filename in bckp_filename_list:
|
||||
shutil.copy(str(filename), str(idx_backup_path / Path(filename).name))
|
||||
idx_backup_path.mkdir(exist_ok=True)
|
||||
for filename in bckp_filename_list:
|
||||
shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )
|
||||
|
||||
previews = self.get_previews()
|
||||
plist = []
|
||||
for i in range(len(previews)):
|
||||
name, bgr = previews[i]
|
||||
plist += [(bgr, idx_backup_path / (('preview_%s.jpg') % (name)))]
|
||||
plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
|
||||
|
||||
for preview, filepath in plist:
|
||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size,
|
||||
preview.shape[1], preview.shape[2])
|
||||
img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
|
||||
cv2_imwrite(filepath, img)
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
|
||||
for model, filename in model_filename_list:
|
||||
|
@ -469,15 +487,15 @@ class ModelBase(object):
|
|||
opt.set_weights(weights)
|
||||
print("set ok")
|
||||
except Exception as e:
|
||||
print("Unable to load ", opt_filename)
|
||||
print ("Unable to load ", opt_filename)
|
||||
|
||||
def save_weights_safe(self, model_filename_list):
|
||||
for model, filename in model_filename_list:
|
||||
filename = self.get_strpath_storage_for_file(filename)
|
||||
model.save_weights(filename + '.tmp')
|
||||
model.save_weights( filename + '.tmp' )
|
||||
|
||||
rename_list = model_filename_list
|
||||
|
||||
|
||||
"""
|
||||
#unused
|
||||
, optimizer_filename_list=[]
|
||||
|
@ -501,24 +519,24 @@ class ModelBase(object):
|
|||
except Exception as e:
|
||||
print ("Unable to save ", opt_filename)
|
||||
"""
|
||||
|
||||
|
||||
for _, filename in rename_list:
|
||||
filename = self.get_strpath_storage_for_file(filename)
|
||||
source_filename = Path(filename + '.tmp')
|
||||
source_filename = Path(filename+'.tmp')
|
||||
if source_filename.exists():
|
||||
target_filename = Path(filename)
|
||||
if target_filename.exists():
|
||||
target_filename.unlink()
|
||||
source_filename.rename(str(target_filename))
|
||||
|
||||
source_filename.rename ( str(target_filename) )
|
||||
|
||||
def debug_one_iter(self):
|
||||
images = []
|
||||
for generator in self.generator_list:
|
||||
for i, batch in enumerate(next(generator)):
|
||||
for i,batch in enumerate(next(generator)):
|
||||
if len(batch.shape) == 4:
|
||||
images.append(batch[0])
|
||||
images.append( batch[0] )
|
||||
|
||||
return imagelib.equalize_and_stack_square(images)
|
||||
return imagelib.equalize_and_stack_square (images)
|
||||
|
||||
def generate_next_sample(self):
|
||||
return [ generator.generate_next() for generator in self.generator_list]
|
||||
|
@ -536,7 +554,7 @@ class ModelBase(object):
|
|||
iter_time = time.time() - iter_time
|
||||
self.last_sample = sample
|
||||
|
||||
self.loss_history.append([float(loss[1]) for loss in losses])
|
||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||
|
||||
if self.iter % 10 == 0:
|
||||
plist = []
|
||||
|
@ -545,16 +563,16 @@ class ModelBase(object):
|
|||
previews = self.get_previews()
|
||||
for i in range(len(previews)):
|
||||
name, bgr = previews[i]
|
||||
plist += [(bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name)))]
|
||||
plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
|
||||
|
||||
if self.write_preview_history:
|
||||
plist += [(self.get_static_preview(), str(self.preview_history_path / ('%.6d.jpg' % (self.iter))))]
|
||||
plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ]
|
||||
|
||||
for preview, filepath in plist:
|
||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter,self.batch_size, preview.shape[1],
|
||||
preview.shape[2])
|
||||
img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
|
||||
cv2_imwrite(filepath, img)
|
||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||
cv2_imwrite (filepath, img )
|
||||
|
||||
if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
|
||||
if self.batch_size == self.batch_cap:
|
||||
|
@ -599,10 +617,10 @@ class ModelBase(object):
|
|||
def get_loss_history(self):
|
||||
return self.loss_history
|
||||
|
||||
def set_training_data_generators(self, generator_list):
|
||||
def set_training_data_generators (self, generator_list):
|
||||
self.generator_list = generator_list
|
||||
|
||||
def get_training_data_generators(self):
|
||||
def get_training_data_generators (self):
|
||||
return self.generator_list
|
||||
|
||||
def get_model_root_path(self):
|
||||
|
@ -610,13 +628,13 @@ class ModelBase(object):
|
|||
|
||||
def get_strpath_storage_for_file(self, filename):
|
||||
if self.device_args['force_gpu_idx'] == -1:
|
||||
return str(self.model_path / (self.get_model_name() + '_' + filename))
|
||||
return str( self.model_path / ( self.get_model_name() + '_' + filename) )
|
||||
else:
|
||||
return str(self.model_path / (
|
||||
str(self.device_args['force_gpu_idx']) + '_' + self.get_model_name() + '_' + filename))
|
||||
|
||||
def set_vram_batch_requirements(self, d):
|
||||
# example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
|
||||
def set_vram_batch_requirements (self, d):
|
||||
#example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
|
||||
keys = [x for x in d.keys()]
|
||||
|
||||
if self.device_config.cpu_only:
|
||||
|
@ -630,63 +648,63 @@ class ModelBase(object):
|
|||
break
|
||||
|
||||
if self.batch_size == 0:
|
||||
self.batch_size = d[keys[-1]]
|
||||
self.batch_size = d[ keys[-1] ]
|
||||
|
||||
@staticmethod
|
||||
def get_loss_history_preview(loss_history, iter,batch_size, w, c):
|
||||
loss_history = np.array(loss_history.copy())
|
||||
loss_history = np.array (loss_history.copy())
|
||||
|
||||
lh_height = 100
|
||||
lh_img = np.ones((lh_height, w, c)) * 0.1
|
||||
|
||||
if len(loss_history) != 0:
|
||||
lh_img = np.ones ( (lh_height,w,c) ) * 0.1
|
||||
|
||||
if len(loss_history) != 0:
|
||||
loss_count = len(loss_history[0])
|
||||
lh_len = len(loss_history)
|
||||
|
||||
l_per_col = lh_len / w
|
||||
plist_max = [[max(0.0, loss_history[int(col * l_per_col)][p],
|
||||
*[loss_history[i_ab][p]
|
||||
for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col))
|
||||
]
|
||||
)
|
||||
for p in range(loss_count)
|
||||
]
|
||||
for col in range(w)
|
||||
]
|
||||
plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
|
||||
*[ loss_history[i_ab][p]
|
||||
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
|
||||
]
|
||||
)
|
||||
for p in range(loss_count)
|
||||
]
|
||||
for col in range(w)
|
||||
]
|
||||
|
||||
plist_min = [[min(plist_max[col][p], loss_history[int(col * l_per_col)][p],
|
||||
*[loss_history[i_ab][p]
|
||||
for i_ab in range(int(col * l_per_col), int((col + 1) * l_per_col))
|
||||
]
|
||||
)
|
||||
for p in range(loss_count)
|
||||
]
|
||||
for col in range(w)
|
||||
]
|
||||
plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
|
||||
*[ loss_history[i_ab][p]
|
||||
for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
|
||||
]
|
||||
)
|
||||
for p in range(loss_count)
|
||||
]
|
||||
for col in range(w)
|
||||
]
|
||||
|
||||
plist_abs_max = np.mean(loss_history[len(loss_history) // 5:]) * 2
|
||||
plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2
|
||||
|
||||
for col in range(0, w):
|
||||
for p in range(0, loss_count):
|
||||
point_color = [1.0] * c
|
||||
point_color[0:3] = colorsys.hsv_to_rgb(p * (1.0 / loss_count), 1.0, 1.0)
|
||||
for p in range(0,loss_count):
|
||||
point_color = [1.0]*c
|
||||
point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )
|
||||
|
||||
ph_max = int((plist_max[col][p] / plist_abs_max) * (lh_height - 1))
|
||||
ph_max = np.clip(ph_max, 0, lh_height - 1)
|
||||
ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
|
||||
ph_max = np.clip( ph_max, 0, lh_height-1 )
|
||||
|
||||
ph_min = int((plist_min[col][p] / plist_abs_max) * (lh_height - 1))
|
||||
ph_min = np.clip(ph_min, 0, lh_height - 1)
|
||||
ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
|
||||
ph_min = np.clip( ph_min, 0, lh_height-1 )
|
||||
|
||||
for ph in range(ph_min, ph_max + 1):
|
||||
lh_img[(lh_height - ph - 1), col] = point_color
|
||||
for ph in range(ph_min, ph_max+1):
|
||||
lh_img[ (lh_height-ph-1), col ] = point_color
|
||||
|
||||
lh_lines = 5
|
||||
lh_line_height = (lh_height - 1) / lh_lines
|
||||
for i in range(0, lh_lines + 1):
|
||||
lh_img[int(i * lh_line_height), :] = (0.8,) * c
|
||||
lh_line_height = (lh_height-1)/lh_lines
|
||||
for i in range(0,lh_lines+1):
|
||||
lh_img[ int(i*lh_line_height), : ] = (0.8,)*c
|
||||
|
||||
last_line_t = int((lh_lines - 1) * lh_line_height)
|
||||
last_line_b = int(lh_lines * lh_line_height)
|
||||
last_line_t = int((lh_lines-1)*lh_line_height)
|
||||
last_line_b = int(lh_lines*lh_line_height)
|
||||
|
||||
lh_text = 'Iter: %d' % iter if iter != 0 else ''
|
||||
bs_text = 'BS: %d' % batch_size if batch_size is not None else '1'
|
||||
|
|
|
@ -7,7 +7,7 @@ IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff")
|
|||
|
||||
|
||||
def get_image_paths(dir_path: str, image_extensions: List[str] = IMAGE_EXTENSIONS) -> List[str]:
|
||||
dir_path = Path(dir_path)
|
||||
dir_path = Path (dir_path)
|
||||
|
||||
result = []
|
||||
if dir_path.exists():
|
||||
|
@ -26,41 +26,41 @@ def get_image_unique_filestem_paths(dir_path: str, verbose_print_func: Optional[
|
|||
if f_stem in result_dup:
|
||||
result.remove(f)
|
||||
if verbose_print_func is not None:
|
||||
verbose_print_func("Duplicate filenames are not allowed, skipping: %s" % Path(f).name)
|
||||
verbose_print_func ("Duplicate filenames are not allowed, skipping: %s" % Path(f).name )
|
||||
continue
|
||||
result_dup.add(f_stem)
|
||||
|
||||
return sorted(result)
|
||||
|
||||
|
||||
|
||||
def get_file_paths(dir_path: str) -> List[str]:
|
||||
dir_path = Path(dir_path)
|
||||
if dir_path.exists():
|
||||
return [x.path for x in scandir(str(dir_path)) if x.is_file()]
|
||||
return sorted([x.path for x in scandir(str(dir_path)) if x.is_file()])
|
||||
return []
|
||||
|
||||
|
||||
def get_all_dir_names(dir_path: str) -> List[str]:
|
||||
dir_path = Path(dir_path)
|
||||
if dir_path.exists():
|
||||
return [x.name for x in scandir(str(dir_path)) if x.is_dir()]
|
||||
return sorted([x.name for x in scandir(str(dir_path)) if x.is_dir()])
|
||||
return []
|
||||
|
||||
|
||||
def get_all_dir_names_startswith(dir_path: str, startswith: str) -> List[str]:
|
||||
dir_path = Path(dir_path)
|
||||
dir_path = Path (dir_path)
|
||||
startswith = startswith.lower()
|
||||
|
||||
result = []
|
||||
if dir_path.exists():
|
||||
for x in scandir(str(dir_path)):
|
||||
if x.name.lower().startswith(startswith):
|
||||
result.append(x.name[len(startswith):])
|
||||
result.append ( x.name[len(startswith):] )
|
||||
return sorted(result)
|
||||
|
||||
|
||||
def get_first_file_by_stem(dir_path: str, stem: str, exts: List[str] = None) -> Optional[Path]:
|
||||
dir_path = Path(dir_path)
|
||||
dir_path = Path (dir_path)
|
||||
stem = stem.lower()
|
||||
|
||||
if dir_path.exists():
|
||||
|
@ -78,8 +78,8 @@ def move_all_files(src_dir_path: str, dst_dir_path: str) -> None:
|
|||
paths = get_file_paths(src_dir_path)
|
||||
for p in paths:
|
||||
p = Path(p)
|
||||
p.rename(Path(dst_dir_path) / p.name)
|
||||
|
||||
p.rename ( Path(dst_dir_path) / p.name )
|
||||
|
||||
|
||||
def delete_all_files(dir_path: str) -> None:
|
||||
paths = get_file_paths(dir_path)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue