Move ping pong logic into generator class
parent f2ed44f3c4
commit 1ecc6f1b62

5 changed files with 122 additions and 78 deletions
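Ping-pong training oscillates the effective batch size: every `ping_pong_iter` iterations the batch size moves one step, ramping up ("ping") from 1 to a cap and then back down ("pong"), and repeats. This commit moves that scheduling out of `ModelBase.train_one_iter` and into a new generator base class, `SampleGeneratorPingPong`, so the sample generators manage their own batch size. (Hunks below that show no added or removed lines appear to carry only whitespace-level changes, which the mirror's flattened rendering does not preserve.) A minimal sketch of the schedule itself, as an illustration rather than repo code:

# Illustrative sketch only (not repo code): the batch-size schedule that
# "ping pong" implements, one step every `period` iterations, bouncing
# between 1 and `batch_cap`.
def ping_pong_schedule(total_iters, period=1000, batch_cap=4):
    batch_size, paddle = 1, 'ping'
    sizes = []
    for it in range(1, total_iters + 1):
        if it % period == 0:
            if batch_size >= batch_cap:
                paddle = 'pong'      # at the cap: start descending
            elif batch_size == 1:
                paddle = 'ping'      # back at 1: start ascending
            batch_size += 1 if paddle == 'ping' else -1
        sizes.append(batch_size)
    return sizes

# One value per 1000 iterations: 2, 3, 4, 3, 2, 1, 2, 3, ...
print(ping_pong_schedule(9000, period=1000, batch_cap=4)[999::1000])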
models/ModelBase.py
@@ -14,6 +14,7 @@ import imagelib
 from interact import interact as io
 from nnlib import nnlib
 from samplelib import SampleGeneratorBase
+from samplelib.SampleGeneratorPingPong import PingPongOptions, Paddle
 from utils import Path_utils, std_utils
 from utils.cv2_utils import *
 
@@ -26,11 +27,11 @@ class ModelBase(object):
 
     def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
                  ask_enable_autobackup=True,
                  ask_write_preview_history=True,
                  ask_target_iter=True,
                  ask_batch_size=True,
                  ask_sort_by_yaw=True,
                  ask_random_flip=True,
                  ask_src_scale_mod=True):
 
         device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
@@ -57,7 +58,7 @@ class ModelBase(object):
         self.training_data_src_path = training_data_src_path
         self.training_data_dst_path = training_data_dst_path
         self.pretraining_data_path = pretraining_data_path
 
         self.src_images_paths = None
         self.dst_images_paths = None
         self.src_yaw_images_paths = None
@@ -67,7 +68,7 @@ class ModelBase(object):
         self.debug = debug
         self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
 
-        self.paddle = 'pong'
+        self.paddle = Paddle.PONG
 
         self.iter = 0
         self.options = {}
@@ -128,7 +129,7 @@ class ModelBase(object):
                                                       " new person")
             else:
                 choose_preview_history = False
 
         if ask_target_iter:
             if (self.iter == 0 or ask_override):
                 self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
@@ -157,7 +158,7 @@ class ModelBase(object):
             self.options['ping_pong'] = self.options.get('ping_pong', False)
             self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)
 
         if ask_sort_by_yaw:
             if (self.iter == 0 or ask_override):
                 default_sort_by_yaw = self.options.get('sort_by_yaw', False)
                 self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):"
@@ -182,7 +183,7 @@ class ModelBase(object):
                                                             " get a better result."), -30, 30)
         else:
             self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
 
         self.autobackup = self.options.get('autobackup', False)
         if not self.autobackup and 'autobackup' in self.options:
             self.options.pop('autobackup')
@@ -210,15 +211,18 @@ class ModelBase(object):
 
         self.onInitializeOptions(self.iter == 0, ask_override)
 
+        self.ping_pong_options = PingPongOptions(enabled=self.options['ping_pong'],
+                                                 iterations=self.ping_pong_iter,
+                                                 model_iter=self.iter,
+                                                 paddle=self.paddle,
+                                                 batch_cap=self.batch_size)
 
         nnlib.import_all(self.device_config)
         self.keras = nnlib.keras
         self.K = nnlib.keras.backend
 
         self.onInitialize()
 
-        self.options['batch_size'] = self.batch_size
-        self.paddle = self.options.get('paddle', 'ping')
 
         if self.debug or self.batch_size == 0:
             self.batch_size = 1
@@ -231,10 +235,10 @@ class ModelBase(object):
                                                                                  self.get_model_name()))
             self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'],
                                                                               self.get_model_name()))
 
             if self.autobackup:
                 self.autobackup_current_hour = time.localtime().tm_hour
 
                 if not self.autobackups_path.exists():
                     self.autobackups_path.mkdir(exist_ok=True)
 
@@ -253,7 +257,7 @@ class ModelBase(object):
             if not isinstance(generator, SampleGeneratorBase):
                 raise ValueError('training data generator is not subclass of SampleGeneratorBase')
 
         if self.sample_for_preview is None or choose_preview_history:
             if choose_preview_history and io.is_support_windows():
                 wnd_name = "[p] - next. [enter] - confirm."
                 io.named_window(wnd_name)
@@ -273,25 +277,25 @@ class ModelBase(object):
                             break
                         elif key == ord('p'):
                             break
 
                         try:
                             io.process_messages(0.1)
                         except KeyboardInterrupt:
                             choosed = True
 
                 io.destroy_window(wnd_name)
             else:
                 self.sample_for_preview = self.generate_next_sample()
             self.last_sample = self.sample_for_preview
 
         ###Generate text summary of model hyperparameters
         #Find the longest key name and value string. Used as column widths.
         width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
         width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.iter)), len(self.get_model_name())]) + 1 # Single space buffer to right edge
         if not self.device_config.cpu_only: #Check length of GPU names
             width_value = max([len(nnlib.device.getDeviceName(idx))+1 for idx in self.device_config.gpu_idxs] + [width_value])
         width_total = width_name + width_value + 2 #Plus 2 for ": "
 
         model_summary_text = []
         model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary
         model_summary_text += [f'=={" "*width_total}==']
@@ -299,13 +303,13 @@ class ModelBase(object):
         model_summary_text += [f'=={" "*width_total}==']
         model_summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='] # Iter
         model_summary_text += [f'=={" "*width_total}==']
 
         model_summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options
         model_summary_text += [f'=={" "*width_total}==']
         for key in self.options.keys():
             model_summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs
         model_summary_text += [f'=={" "*width_total}==']
 
         model_summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
         model_summary_text += [f'=={" "*width_total}==']
         if self.device_config.multi_gpu:
@@ -318,10 +322,10 @@ class ModelBase(object):
                 model_summary_text += [f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='] # GPU hardware device index
                 model_summary_text += [f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='] # GPU name
                 vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB' # GPU VRAM - Formated as #.## (or ##.##)
                 model_summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}==']
         model_summary_text += [f'=={" "*width_total}==']
         model_summary_text += [f'=={"="*width_total}==']
 
         if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] <= 2: # Low VRAM warning
             model_summary_text += ["/!\\"]
             model_summary_text += ["/!\\ WARNING:"]
@@ -329,7 +333,7 @@ class ModelBase(object):
             model_summary_text += ["/!\\ If training does not start, close all programs and try again."]
             model_summary_text += ["/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."]
             model_summary_text += ["/!\\"]
 
         model_summary_text = "\n".join (model_summary_text)
         self.model_summary_text = model_summary_text
         io.log_info(model_summary_text)
@@ -413,7 +417,7 @@ class ModelBase(object):
 
     def save(self):
         self.options['batch_size'] = self.batch_size
-        self.options['paddle'] = self.paddle
+        self.options['paddle'] = self.ping_pong_options.paddle
         summary_path = self.get_strpath_storage_for_file('summary.txt')
         Path( summary_path ).write_text(self.model_summary_text)
         self.onSave()
@@ -428,8 +432,8 @@ class ModelBase(object):
 
         bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in
                               self.get_model_filename_list()]
         bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]
 
         if self.autobackup:
             current_hour = time.localtime().tm_hour
             if self.autobackup_current_hour != current_hour:
@@ -438,20 +442,20 @@ class ModelBase(object):
                 for i in range(15,0,-1):
                     idx_str = '%.2d' % i
                     next_idx_str = '%.2d' % (i+1)
 
                     idx_backup_path = self.autobackups_path / idx_str
                     next_idx_packup_path = self.autobackups_path / next_idx_str
 
                     if idx_backup_path.exists():
                         if i == 15:
                             Path_utils.delete_all_files(idx_backup_path)
                         else:
                             next_idx_packup_path.mkdir(exist_ok=True)
                             Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)
 
                     if i == 1:
                         idx_backup_path.mkdir(exist_ok=True)
                         for filename in bckp_filename_list:
                             shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )
 
         previews = self.get_previews()
@@ -495,7 +499,7 @@ class ModelBase(object):
             model.save_weights( filename + '.tmp' )
 
         rename_list = model_filename_list
 
         """
         #unused
         , optimizer_filename_list=[]
@@ -519,7 +523,7 @@ class ModelBase(object):
             except Exception as e:
                 print ("Unable to save ", opt_filename)
         """
 
         for _, filename in rename_list:
             filename = self.get_strpath_storage_for_file(filename)
             source_filename = Path(filename+'.tmp')
@@ -528,7 +532,7 @@ class ModelBase(object):
             if target_filename.exists():
                 target_filename.unlink()
             source_filename.rename ( str(target_filename) )
 
     def debug_one_iter(self):
         images = []
         for generator in self.generator_list:
@@ -542,12 +546,11 @@ class ModelBase(object):
         return [ generator.generate_next() for generator in self.generator_list]
 
     def train_one_iter(self):
-        if self.iter == 1 and self.options.get('ping_pong', False):
-            self.set_batch_size(1)
-            self.paddle = 'ping'
-        elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
-            self.set_batch_size(self.batch_cap)
-
+        # if self.iter == 1 and self.options.get('ping_pong', False):
+        #     self.set_batch_size(1)
+        #     self.paddle = 'ping'
+        # elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
+        #     self.set_batch_size(self.batch_cap)
         sample = self.generate_next_sample()
         iter_time = time.time()
         losses = self.onTrainOneIter(sample, self.generator_list)
@@ -574,21 +577,6 @@ class ModelBase(object):
                     img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
                     cv2_imwrite (filepath, img )
 
-        if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
-            if self.batch_size == self.batch_cap:
-                self.paddle = 'pong'
-            if self.batch_size > self.batch_cap:
-                self.set_batch_size(self.batch_cap)
-                self.paddle = 'pong'
-            if self.batch_size == 1:
-                self.paddle = 'ping'
-            if self.paddle == 'ping':
-                self.save()
-                self.set_batch_size(self.batch_size + 1)
-            else:
-                self.save()
-                self.set_batch_size(self.batch_size - 1)
-
         self.iter += 1
 
         return self.iter, iter_time, self.batch_size
@@ -656,8 +644,8 @@ class ModelBase(object):
 
         lh_height = 100
         lh_img = np.ones ( (lh_height,w,c) ) * 0.1
 
         if len(loss_history) != 0:
             loss_count = len(loss_history[0])
             lh_len = len(loss_history)
 
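Two things worth noting about the ModelBase side of the move. First, the block deleted from the bottom of train_one_iter called self.save() at every paddle step; the generator-based replacement (new file below) only adjusts the batch size, so a checkpoint at each turn of the ramp is no longer implied. Second, ModelBase builds a single PingPongOptions object and, in the SAE hunks below, hands that same object to both the src and dst generators; since each generator's __next__ increments the shared model_iter, the counter advances once per generator draw rather than once per training iteration, which with two generators roughly halves the effective period. A tiny self-contained demo of that shared-state reading (illustrative names only, not repo code):

# Demo: one shared options object polled by two generators advances
# its counter twice per training iteration.
class Options:
    def __init__(self):
        self.model_iter = 0

class Generator:
    def __init__(self, shared_options):
        self.opts = shared_options
    def __next__(self):
        self.opts.model_iter += 1   # what SampleGeneratorPingPong.__next__ does
        return 'batch'

opts = Options()
src, dst = Generator(opts), Generator(opts)
for _ in range(1000):               # 1000 training iterations...
    next(src), next(dst)
print(opts.model_iter)              # ...advance the shared counter to 2000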
models/Model_SAE/Model.py
@@ -487,7 +487,8 @@ class SAEModel(ModelBase):
                                             'apply_ct': apply_random_ct} for i in range(ms_count)] + \
                                           [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
                                             'resolution': resolution // (2 ** i)} for i in
-                                            range(ms_count)]
+                                            range(ms_count)],
+                                    ping_pong=self.ping_pong_options,
                                     ),
 
                 SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
@@ -500,7 +501,8 @@ class SAEModel(ModelBase):
                                             range(ms_count)] + \
                                           [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
                                             'resolution': resolution // (2 ** i)} for i in
-                                            range(ms_count)])
+                                            range(ms_count)],
+                                    ping_pong=self.ping_pong_options,)
             ])
 
         # override
@@ -557,7 +559,8 @@ class SAEModel(ModelBase):
                                             'apply_ct': apply_random_ct} for i in range(ms_count)] + \
                                           [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
                                             'resolution': resolution // (2 ** i)} for i in
-                                            range(ms_count)]
+                                            range(ms_count)],
+                                    ping_pong=self.ping_pong_options,
                                     ),
 
                 SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
@@ -570,7 +573,8 @@ class SAEModel(ModelBase):
                                             range(ms_count)] + \
                                           [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
                                             'resolution': resolution // (2 ** i)} for i in
-                                            range(ms_count)])
+                                            range(ms_count)],
+                                    ping_pong=self.ping_pong_options,)
             ])
 
         # override
@@ -705,7 +709,7 @@ class SAEModel(ModelBase):
             return func
 
         SAEModel.downscale = downscale
 
         #def downscale (dim, padding='zero', norm='', act='', **kwargs):
         #    def func(x):
         #        return BlurPool()( Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=1, padding=padding)(x)) ) )
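Wiring-wise, the SAE model now threads its one ping_pong_options object into both face generators through a new keyword argument. A hedged sketch of the resulting call shape; the workspace path and values here are illustrative, and see the caveat after the new file below about the keyword-only batch_size:

# Hedged usage sketch, mirroring the hunks above (path and values are
# illustrative, not taken from the repo).
from samplelib import SampleGeneratorFace
from samplelib.SampleGeneratorPingPong import PingPongOptions

ping_pong = PingPongOptions(enabled=True, iterations=1000, batch_cap=8)
src_generator = SampleGeneratorFace('workspace/data_src/aligned',
                                    debug=False, batch_size=8,
                                    ping_pong=ping_pong)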
samplelib/SampleGeneratorBase.py
@@ -6,7 +6,7 @@ You can implement your own SampleGenerator
 class SampleGeneratorBase(object):
 
 
-    def __init__ (self, samples_path, debug, batch_size):
+    def __init__(self, samples_path, debug, batch_size):
         if samples_path is None:
             raise Exception('samples_path is None')
 
@@ -15,21 +15,21 @@ class SampleGeneratorBase(object):
         self.batch_size = 1 if self.debug else batch_size
         self.last_generation = None
         self.active = True
 
     def set_active(self, is_active):
         self.active = is_active
 
     def generate_next(self):
         if not self.active and self.last_generation is not None:
             return self.last_generation
         self.last_generation = next(self)
         return self.last_generation
 
-    #overridable
+    # overridable
     def __iter__(self):
-        #implement your own iterator
+        # implement your own iterator
         return self
 
     def __next__(self):
-        #implement your own iterator
+        # implement your own iterator
         return None
 
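For orientation, the base-class contract the new subclass plugs into: generate_next() pulls a fresh batch via next(self) unless the generator has been deactivated, in which case it replays the cached last batch. A condensed, self-contained restatement (not repo code):

# Condensed restatement of SampleGeneratorBase's caching behavior.
class Base:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.last_generation = None
        self.active = True

    def generate_next(self):
        if not self.active and self.last_generation is not None:
            return self.last_generation          # inactive: replay last batch
        self.last_generation = next(self)
        return self.last_generation

    def __iter__(self):
        return self

    def __next__(self):
        return ('fresh batch', self.batch_size)  # stand-in for real sampling

g = Base(4)
print(g.generate_next())   # draws a fresh batch
g.active = False
print(g.generate_next())   # replays the cached batch instead of sampling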
samplelib/SampleGeneratorFace.py
@@ -5,8 +5,9 @@ import cv2
 import numpy as np
 
 from facelib import LandmarksProcessor
-from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
+from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
                        SampleType)
+from samplelib.SampleGeneratorPingPong import PingPongOptions
 from utils import iter_utils
 
 
@@ -19,12 +20,12 @@ output_sample_types = [
 '''
 
 
-class SampleGeneratorFace(SampleGeneratorBase):
+class SampleGeneratorFace(SampleGeneratorPingPong):
     def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
                  random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
                  output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
-                 **kwargs):
-        super().__init__(samples_path, debug, batch_size)
+                 ping_pong=PingPongOptions(), **kwargs):
+        super().__init__(samples_path, debug, batch_size, ping_pong)
         self.sample_process_options = sample_process_options
         self.output_sample_types = output_sample_types
         self.add_sample_idx = add_sample_idx
@@ -64,6 +65,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
     def __next__(self):
         self.generator_counter += 1
         generator = self.generators[self.generator_counter % len(self.generators)]
+        super().__next__()
         return next(generator)
 
     def batch_func(self, param):
samplelib/SampleGeneratorPingPong.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+from enum import Enum
+from samplelib import SampleGeneratorBase
+
+
+class Paddle(Enum):
+    PING = 'ping'  # Ascending
+    PONG = 'pong'  # Descending
+
+
+class PingPongOptions:
+    def __init__(self, enabled=False, iterations=1000, model_iter=1, paddle=Paddle.PING, batch_cap=1):
+        self.enabled = enabled
+        self.iterations = iterations
+        self.model_iter = model_iter
+        self.paddle = paddle
+        self.batch_cap = batch_cap
+
+
+class SampleGeneratorPingPong(SampleGeneratorBase):
+    def __init__(self, *args, batch_size, ping_pong=PingPongOptions()):
+        self.ping_pong = ping_pong
+        super().__init__(*args, batch_size)
+
+    def __next__(self):
+        if self.ping_pong.enabled and self.ping_pong.model_iter % self.ping_pong.iterations == 0 \
+                and self.ping_pong.model_iter != 0:
+
+            # If batch size is greater then batch cap, set it to batch cap
+            if self.batch_size > self.ping_pong.batch_cap:
+                self.batch_size = self.ping_pong.batch_cap
+
+            # If we are at the batch cap, switch to PONG (descend)
+            if self.batch_size == self.ping_pong.batch_cap:
+                self.paddle = Paddle.PONG
+            # Else if we are at 1, switch to PING (ascend)
+            elif self.batch_size == 1:
+                self.paddle = Paddle.PING
+
+            # If PING (ascending) increase the batch size
+            if self.paddle is Paddle.PING:
+                self.batch_size += 1
+            # Else decrease the batch size
+            else:
+                self.batch_size -= 1
+
+        self.ping_pong.model_iter += 1
+        super().__next__()
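A few hedged observations on the new file. As the hunks read here, batch_size is keyword-only in SampleGeneratorPingPong.__init__ (it follows *args), while SampleGeneratorFace passes it positionally via super().__init__(samples_path, debug, batch_size, ping_pong); as written that call would raise a TypeError, so some follow-up adjustment seems assumed. Also, ping_pong=PingPongOptions() is a mutable default evaluated once, so every generator constructed without an explicit argument would share one options object, and nothing initializes self.paddle in __init__; the condensation below seeds it from the options object, which is one charitable reading. A self-contained demo of the oscillation logic (base class stubbed out, names illustrative):

from enum import Enum

class Paddle(Enum):
    PING = 'ping'   # ascending
    PONG = 'pong'   # descending

class PingPongOptions:
    def __init__(self, enabled=False, iterations=1000, model_iter=1,
                 paddle=Paddle.PING, batch_cap=1):
        self.enabled = enabled
        self.iterations = iterations
        self.model_iter = model_iter
        self.paddle = paddle
        self.batch_cap = batch_cap

class PingPongDemo:
    """Stand-in for SampleGeneratorPingPong with the same __next__ logic."""
    def __init__(self, ping_pong):
        self.ping_pong = ping_pong
        self.batch_size = 1
        self.paddle = ping_pong.paddle

    def step(self):
        pp = self.ping_pong
        if pp.enabled and pp.model_iter % pp.iterations == 0 and pp.model_iter != 0:
            if self.batch_size > pp.batch_cap:      # clamp to the cap
                self.batch_size = pp.batch_cap
            if self.batch_size == pp.batch_cap:     # at the cap: descend
                self.paddle = Paddle.PONG
            elif self.batch_size == 1:              # at the floor: ascend
                self.paddle = Paddle.PING
            self.batch_size += 1 if self.paddle is Paddle.PING else -1
        pp.model_iter += 1

demo = PingPongDemo(PingPongOptions(enabled=True, iterations=3, batch_cap=3))
trace = []
for _ in range(21):
    demo.step()
    trace.append(demo.batch_size)
print(trace)  # batch size bounces 1 -> 3 -> 1 -> ... one step every 3 draws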