mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
Trainer: fixed "Choose image for the preview history". Now you can switch between subpreviews using 'space' key.
Fixed "Write preview history". Now it writes all subpreviews in separated folders https://i.imgur.com/IszifCJ.jpg also the last preview saved as _last.jpg before the first file https://i.imgur.com/Ls1AOK4.jpg thus you can easily check the changes with the first file in photo viewer
This commit is contained in:
parent
5c315cab68
commit
82f405ed49
1 changed files with 35 additions and 23 deletions
|
@ -61,7 +61,7 @@ class ModelBase(object):
|
||||||
# sort by modified datetime
|
# sort by modified datetime
|
||||||
saved_models_names = sorted(saved_models_names, key=operator.itemgetter(1), reverse=True )
|
saved_models_names = sorted(saved_models_names, key=operator.itemgetter(1), reverse=True )
|
||||||
saved_models_names = [ x[0] for x in saved_models_names ]
|
saved_models_names = [ x[0] for x in saved_models_names ]
|
||||||
|
|
||||||
|
|
||||||
if len(saved_models_names) != 0:
|
if len(saved_models_names) != 0:
|
||||||
if silent_start:
|
if silent_start:
|
||||||
|
@ -125,7 +125,7 @@ class ModelBase(object):
|
||||||
self.model_name = self.model_name.replace('_', ' ')
|
self.model_name = self.model_name.replace('_', ' ')
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
self.model_name = self.model_name + '_' + self.model_class_name
|
self.model_name = self.model_name + '_' + self.model_class_name
|
||||||
else:
|
else:
|
||||||
self.model_name = force_model_class_name
|
self.model_name = force_model_class_name
|
||||||
|
@ -150,10 +150,10 @@ class ModelBase(object):
|
||||||
|
|
||||||
if self.is_first_run():
|
if self.is_first_run():
|
||||||
io.log_info ("\nModel first run.")
|
io.log_info ("\nModel first run.")
|
||||||
|
|
||||||
if silent_start:
|
if silent_start:
|
||||||
self.device_config = nn.DeviceConfig.BestGPU()
|
self.device_config = nn.DeviceConfig.BestGPU()
|
||||||
io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}")
|
io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}")
|
||||||
else:
|
else:
|
||||||
self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \
|
self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \
|
||||||
if not cpu_only else nn.DeviceConfig.CPU()
|
if not cpu_only else nn.DeviceConfig.CPU()
|
||||||
|
@ -208,7 +208,7 @@ class ModelBase(object):
|
||||||
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
raise ValueError('training data generator is not subclass of SampleGeneratorBase')
|
||||||
|
|
||||||
self.update_sample_for_preview(choose_preview_history=self.choose_preview_history)
|
self.update_sample_for_preview(choose_preview_history=self.choose_preview_history)
|
||||||
|
|
||||||
if self.autobackup_hour != 0:
|
if self.autobackup_hour != 0:
|
||||||
self.autobackup_start_time = time.time()
|
self.autobackup_start_time = time.time()
|
||||||
|
|
||||||
|
@ -216,19 +216,21 @@ class ModelBase(object):
|
||||||
self.autobackups_path.mkdir(exist_ok=True)
|
self.autobackups_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
io.log_info( self.get_summary_text() )
|
io.log_info( self.get_summary_text() )
|
||||||
|
|
||||||
def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
|
def update_sample_for_preview(self, choose_preview_history=False, force_new=False):
|
||||||
if self.sample_for_preview is None or choose_preview_history or force_new:
|
if self.sample_for_preview is None or choose_preview_history or force_new:
|
||||||
if choose_preview_history and io.is_support_windows():
|
if choose_preview_history and io.is_support_windows():
|
||||||
io.log_info ("Choose image for the preview history. [p] - next. [enter] - confirm.")
|
wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm."
|
||||||
wnd_name = "[p] - next. [enter] - confirm."
|
io.log_info (f"Choose image for the preview history. {wnd_name}")
|
||||||
io.named_window(wnd_name)
|
io.named_window(wnd_name)
|
||||||
io.capture_keys(wnd_name)
|
io.capture_keys(wnd_name)
|
||||||
choosed = False
|
choosed = False
|
||||||
|
preview_id_counter = 0
|
||||||
while not choosed:
|
while not choosed:
|
||||||
self.sample_for_preview = self.generate_next_samples()
|
self.sample_for_preview = self.generate_next_samples()
|
||||||
preview = self.get_static_preview()
|
previews = self.get_static_previews()
|
||||||
io.show_image( wnd_name, (preview*255).astype(np.uint8) )
|
|
||||||
|
io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) )
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
key_events = io.get_key_events(wnd_name)
|
key_events = io.get_key_events(wnd_name)
|
||||||
|
@ -236,6 +238,9 @@ class ModelBase(object):
|
||||||
if key == ord('\n') or key == ord('\r'):
|
if key == ord('\n') or key == ord('\r'):
|
||||||
choosed = True
|
choosed = True
|
||||||
break
|
break
|
||||||
|
elif key == ord(' '):
|
||||||
|
preview_id_counter += 1
|
||||||
|
break
|
||||||
elif key == ord('p'):
|
elif key == ord('p'):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -249,12 +254,12 @@ class ModelBase(object):
|
||||||
self.sample_for_preview = self.generate_next_samples()
|
self.sample_for_preview = self.generate_next_samples()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.get_static_preview()
|
self.get_static_previews()
|
||||||
except:
|
except:
|
||||||
self.sample_for_preview = self.generate_next_samples()
|
self.sample_for_preview = self.generate_next_samples()
|
||||||
|
|
||||||
self.last_sample = self.sample_for_preview
|
self.last_sample = self.sample_for_preview
|
||||||
|
|
||||||
def load_or_def_option(self, name, def_value):
|
def load_or_def_option(self, name, def_value):
|
||||||
options_val = self.options.get(name, None)
|
options_val = self.options.get(name, None)
|
||||||
if options_val is not None:
|
if options_val is not None:
|
||||||
|
@ -354,8 +359,8 @@ class ModelBase(object):
|
||||||
def get_previews(self):
|
def get_previews(self):
|
||||||
return self.onGetPreview ( self.last_sample )
|
return self.onGetPreview ( self.last_sample )
|
||||||
|
|
||||||
def get_static_preview(self):
|
def get_static_previews(self):
|
||||||
return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
|
return self.onGetPreview (self.sample_for_preview)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
|
Path( self.get_summary_path() ).write_text( self.get_summary_text() )
|
||||||
|
@ -377,16 +382,16 @@ class ModelBase(object):
|
||||||
if diff_hour > 0 and diff_hour % self.autobackup_hour == 0:
|
if diff_hour > 0 and diff_hour % self.autobackup_hour == 0:
|
||||||
self.autobackup_start_time += self.autobackup_hour*3600
|
self.autobackup_start_time += self.autobackup_hour*3600
|
||||||
self.create_backup()
|
self.create_backup()
|
||||||
|
|
||||||
def create_backup(self):
|
def create_backup(self):
|
||||||
io.log_info ("Creating backup...", end='\r')
|
io.log_info ("Creating backup...", end='\r')
|
||||||
|
|
||||||
if not self.autobackups_path.exists():
|
if not self.autobackups_path.exists():
|
||||||
self.autobackups_path.mkdir(exist_ok=True)
|
self.autobackups_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
|
bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
|
||||||
bckp_filename_list += [ str(self.get_summary_path()), str(self.model_data_path) ]
|
bckp_filename_list += [ str(self.get_summary_path()), str(self.model_data_path) ]
|
||||||
|
|
||||||
for i in range(24,0,-1):
|
for i in range(24,0,-1):
|
||||||
idx_str = '%.2d' % i
|
idx_str = '%.2d' % i
|
||||||
next_idx_str = '%.2d' % (i+1)
|
next_idx_str = '%.2d' % (i+1)
|
||||||
|
@ -427,7 +432,7 @@ class ModelBase(object):
|
||||||
return imagelib.equalize_and_stack_square (images)
|
return imagelib.equalize_and_stack_square (images)
|
||||||
|
|
||||||
def generate_next_samples(self):
|
def generate_next_samples(self):
|
||||||
sample = []
|
sample = []
|
||||||
for generator in self.generator_list:
|
for generator in self.generator_list:
|
||||||
if generator.is_initialized():
|
if generator.is_initialized():
|
||||||
sample.append ( generator.generate_next() )
|
sample.append ( generator.generate_next() )
|
||||||
|
@ -445,7 +450,7 @@ class ModelBase(object):
|
||||||
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
self.loss_history.append ( [float(loss[1]) for loss in losses] )
|
||||||
|
|
||||||
if (not io.is_colab() and self.iter % 10 == 0) or \
|
if (not io.is_colab() and self.iter % 10 == 0) or \
|
||||||
(io.is_colab() and self.iter % 100 == 0):
|
(io.is_colab() and self.iter % 100 == 0):
|
||||||
plist = []
|
plist = []
|
||||||
|
|
||||||
if io.is_colab():
|
if io.is_colab():
|
||||||
|
@ -455,8 +460,15 @@ class ModelBase(object):
|
||||||
plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
|
plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ]
|
||||||
|
|
||||||
if self.write_preview_history:
|
if self.write_preview_history:
|
||||||
plist += [ (self.get_static_preview(), str (self.preview_history_path / ('%.6d.jpg' % (self.iter))) ) ]
|
previews = self.get_static_previews()
|
||||||
|
for i in range(len(previews)):
|
||||||
|
name, bgr = previews[i]
|
||||||
|
path = self.preview_history_path / name
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ]
|
||||||
|
if not io.is_colab():
|
||||||
|
plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ]
|
||||||
|
|
||||||
for preview, filepath in plist:
|
for preview, filepath in plist:
|
||||||
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
|
||||||
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
|
||||||
|
@ -508,7 +520,7 @@ class ModelBase(object):
|
||||||
|
|
||||||
def get_summary_path(self):
|
def get_summary_path(self):
|
||||||
return self.get_strpath_storage_for_file('summary.txt')
|
return self.get_strpath_storage_for_file('summary.txt')
|
||||||
|
|
||||||
def get_summary_text(self):
|
def get_summary_text(self):
|
||||||
###Generate text summary of model hyperparameters
|
###Generate text summary of model hyperparameters
|
||||||
#Find the longest key name and value string. Used as column widths.
|
#Find the longest key name and value string. Used as column widths.
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue