Merge pull request #42 from faceshiftlabs/feat/ping-pong-2

Feat/ping pong 2
This commit is contained in:
Jeremy Hummel 2019-08-23 19:10:31 -07:00 committed by GitHub
commit 1dff2283d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 121 additions and 108 deletions

View file

@ -14,6 +14,7 @@ import imagelib
from interact import interact as io from interact import interact as io
from nnlib import nnlib from nnlib import nnlib
from samplelib import SampleGeneratorBase from samplelib import SampleGeneratorBase
from samplelib.SampleGeneratorPingPong import PingPongOptions, Paddle
from utils import Path_utils, std_utils from utils import Path_utils, std_utils
from utils.cv2_utils import * from utils.cv2_utils import *
@ -26,11 +27,11 @@ class ModelBase(object):
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None, def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
ask_enable_autobackup=True, ask_enable_autobackup=True,
ask_write_preview_history=True, ask_write_preview_history=True,
ask_target_iter=True, ask_target_iter=True,
ask_batch_size=True, ask_batch_size=True,
ask_sort_by_yaw=True, ask_sort_by_yaw=True,
ask_random_flip=True, ask_random_flip=True,
ask_src_scale_mod=True): ask_src_scale_mod=True):
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1) device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
@ -57,7 +58,7 @@ class ModelBase(object):
self.training_data_src_path = training_data_src_path self.training_data_src_path = training_data_src_path
self.training_data_dst_path = training_data_dst_path self.training_data_dst_path = training_data_dst_path
self.pretraining_data_path = pretraining_data_path self.pretraining_data_path = pretraining_data_path
self.src_images_paths = None self.src_images_paths = None
self.dst_images_paths = None self.dst_images_paths = None
self.src_yaw_images_paths = None self.src_yaw_images_paths = None
@ -67,7 +68,7 @@ class ModelBase(object):
self.debug = debug self.debug = debug
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None) self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
self.paddle = 'pong' self.paddle = Paddle.PONG
self.iter = 0 self.iter = 0
self.options = {} self.options = {}
@ -128,7 +129,7 @@ class ModelBase(object):
" new person") " new person")
else: else:
choose_preview_history = False choose_preview_history = False
if ask_target_iter: if ask_target_iter:
if (self.iter == 0 or ask_override): if (self.iter == 0 or ask_override):
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0)) self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
@ -157,7 +158,7 @@ class ModelBase(object):
self.options['ping_pong'] = self.options.get('ping_pong', False) self.options['ping_pong'] = self.options.get('ping_pong', False)
self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000) self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)
if ask_sort_by_yaw: if ask_sort_by_yaw:
if (self.iter == 0 or ask_override): if (self.iter == 0 or ask_override):
default_sort_by_yaw = self.options.get('sort_by_yaw', False) default_sort_by_yaw = self.options.get('sort_by_yaw', False)
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):" self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:%s):"
@ -182,7 +183,7 @@ class ModelBase(object):
" get a better result."), -30, 30) " get a better result."), -30, 30)
else: else:
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0) self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
self.autobackup = self.options.get('autobackup', False) self.autobackup = self.options.get('autobackup', False)
if not self.autobackup and 'autobackup' in self.options: if not self.autobackup and 'autobackup' in self.options:
self.options.pop('autobackup') self.options.pop('autobackup')
@ -210,15 +211,18 @@ class ModelBase(object):
self.onInitializeOptions(self.iter == 0, ask_override) self.onInitializeOptions(self.iter == 0, ask_override)
self.ping_pong_options = PingPongOptions(enabled=self.options['ping_pong'],
iterations=self.ping_pong_iter,
model_iter=self.iter,
paddle=self.paddle,
batch_cap=self.batch_size)
nnlib.import_all(self.device_config) nnlib.import_all(self.device_config)
self.keras = nnlib.keras self.keras = nnlib.keras
self.K = nnlib.keras.backend self.K = nnlib.keras.backend
self.onInitialize() self.onInitialize()
self.options['batch_size'] = self.batch_size
self.paddle = self.options.get('paddle', 'ping')
if self.debug or self.batch_size == 0: if self.debug or self.batch_size == 0:
self.batch_size = 1 self.batch_size = 1
@ -231,10 +235,10 @@ class ModelBase(object):
self.get_model_name())) self.get_model_name()))
self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.autobackups_path = self.model_path / ('%d_%s_autobackups' % (self.device_args['force_gpu_idx'],
self.get_model_name())) self.get_model_name()))
if self.autobackup: if self.autobackup:
self.autobackup_current_hour = time.localtime().tm_hour self.autobackup_current_hour = time.localtime().tm_hour
if not self.autobackups_path.exists(): if not self.autobackups_path.exists():
self.autobackups_path.mkdir(exist_ok=True) self.autobackups_path.mkdir(exist_ok=True)
@ -253,7 +257,7 @@ class ModelBase(object):
if not isinstance(generator, SampleGeneratorBase): if not isinstance(generator, SampleGeneratorBase):
raise ValueError('training data generator is not subclass of SampleGeneratorBase') raise ValueError('training data generator is not subclass of SampleGeneratorBase')
if self.sample_for_preview is None or choose_preview_history: if self.sample_for_preview is None or choose_preview_history:
if choose_preview_history and io.is_support_windows(): if choose_preview_history and io.is_support_windows():
wnd_name = "[p] - next. [enter] - confirm." wnd_name = "[p] - next. [enter] - confirm."
io.named_window(wnd_name) io.named_window(wnd_name)
@ -273,25 +277,25 @@ class ModelBase(object):
break break
elif key == ord('p'): elif key == ord('p'):
break break
try: try:
io.process_messages(0.1) io.process_messages(0.1)
except KeyboardInterrupt: except KeyboardInterrupt:
choosed = True choosed = True
io.destroy_window(wnd_name) io.destroy_window(wnd_name)
else: else:
self.sample_for_preview = self.generate_next_sample() self.sample_for_preview = self.generate_next_sample()
self.last_sample = self.sample_for_preview self.last_sample = self.sample_for_preview
###Generate text summary of model hyperparameters ###Generate text summary of model hyperparameters
#Find the longest key name and value string. Used as column widths. #Find the longest key name and value string. Used as column widths.
width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" width_name = max([len(k) for k in self.options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration"
width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.iter)), len(self.get_model_name())]) + 1 # Single space buffer to right edge width_value = max([len(str(x)) for x in self.options.values()] + [len(str(self.iter)), len(self.get_model_name())]) + 1 # Single space buffer to right edge
if not self.device_config.cpu_only: #Check length of GPU names if not self.device_config.cpu_only: #Check length of GPU names
width_value = max([len(nnlib.device.getDeviceName(idx))+1 for idx in self.device_config.gpu_idxs] + [width_value]) width_value = max([len(nnlib.device.getDeviceName(idx))+1 for idx in self.device_config.gpu_idxs] + [width_value])
width_total = width_name + width_value + 2 #Plus 2 for ": " width_total = width_name + width_value + 2 #Plus 2 for ": "
model_summary_text = [] model_summary_text = []
model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary model_summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
@ -299,13 +303,13 @@ class ModelBase(object):
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
model_summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='] # Iter model_summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.iter): <{width_value}}=='] # Iter
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
model_summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options model_summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
for key in self.options.keys(): for key in self.options.keys():
model_summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs model_summary_text += [f'=={key: >{width_name}}: {str(self.options[key]): <{width_value}}=='] # self.options key/value pairs
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
model_summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info model_summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
if self.device_config.multi_gpu: if self.device_config.multi_gpu:
@ -318,10 +322,10 @@ class ModelBase(object):
model_summary_text += [f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='] # GPU hardware device index model_summary_text += [f'=={"Device index": >{width_name}}: {idx: <{width_value}}=='] # GPU hardware device index
model_summary_text += [f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='] # GPU name model_summary_text += [f'=={"Name": >{width_name}}: {nnlib.device.getDeviceName(idx): <{width_value}}=='] # GPU name
vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB' # GPU VRAM - Formated as #.## (or ##.##) vram_str = f'{nnlib.device.getDeviceVRAMTotalGb(idx):.2f}GB' # GPU VRAM - Formated as #.## (or ##.##)
model_summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}=='] model_summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}==']
model_summary_text += [f'=={" "*width_total}=='] model_summary_text += [f'=={" "*width_total}==']
model_summary_text += [f'=={"="*width_total}=='] model_summary_text += [f'=={"="*width_total}==']
if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] <= 2: # Low VRAM warning if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] <= 2: # Low VRAM warning
model_summary_text += ["/!\\"] model_summary_text += ["/!\\"]
model_summary_text += ["/!\\ WARNING:"] model_summary_text += ["/!\\ WARNING:"]
@ -329,7 +333,7 @@ class ModelBase(object):
model_summary_text += ["/!\\ If training does not start, close all programs and try again."] model_summary_text += ["/!\\ If training does not start, close all programs and try again."]
model_summary_text += ["/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."] model_summary_text += ["/!\\ Also you can disable Windows Aero Desktop to increase available VRAM."]
model_summary_text += ["/!\\"] model_summary_text += ["/!\\"]
model_summary_text = "\n".join (model_summary_text) model_summary_text = "\n".join (model_summary_text)
self.model_summary_text = model_summary_text self.model_summary_text = model_summary_text
io.log_info(model_summary_text) io.log_info(model_summary_text)
@ -413,7 +417,7 @@ class ModelBase(object):
def save(self): def save(self):
self.options['batch_size'] = self.batch_size self.options['batch_size'] = self.batch_size
self.options['paddle'] = self.paddle self.options['paddle'] = self.ping_pong_options.paddle
summary_path = self.get_strpath_storage_for_file('summary.txt') summary_path = self.get_strpath_storage_for_file('summary.txt')
Path( summary_path ).write_text(self.model_summary_text) Path( summary_path ).write_text(self.model_summary_text)
self.onSave() self.onSave()
@ -428,8 +432,8 @@ class ModelBase(object):
bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in bckp_filename_list = [self.get_strpath_storage_for_file(filename) for _, filename in
self.get_model_filename_list()] self.get_model_filename_list()]
bckp_filename_list += [ str(summary_path), str(self.model_data_path) ] bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]
if self.autobackup: if self.autobackup:
current_hour = time.localtime().tm_hour current_hour = time.localtime().tm_hour
if self.autobackup_current_hour != current_hour: if self.autobackup_current_hour != current_hour:
@ -438,20 +442,20 @@ class ModelBase(object):
for i in range(15,0,-1): for i in range(15,0,-1):
idx_str = '%.2d' % i idx_str = '%.2d' % i
next_idx_str = '%.2d' % (i+1) next_idx_str = '%.2d' % (i+1)
idx_backup_path = self.autobackups_path / idx_str idx_backup_path = self.autobackups_path / idx_str
next_idx_packup_path = self.autobackups_path / next_idx_str next_idx_packup_path = self.autobackups_path / next_idx_str
if idx_backup_path.exists(): if idx_backup_path.exists():
if i == 15: if i == 15:
Path_utils.delete_all_files(idx_backup_path) Path_utils.delete_all_files(idx_backup_path)
else: else:
next_idx_packup_path.mkdir(exist_ok=True) next_idx_packup_path.mkdir(exist_ok=True)
Path_utils.move_all_files (idx_backup_path, next_idx_packup_path) Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)
if i == 1: if i == 1:
idx_backup_path.mkdir(exist_ok=True) idx_backup_path.mkdir(exist_ok=True)
for filename in bckp_filename_list: for filename in bckp_filename_list:
shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) ) shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )
previews = self.get_previews() previews = self.get_previews()
@ -495,7 +499,7 @@ class ModelBase(object):
model.save_weights( filename + '.tmp' ) model.save_weights( filename + '.tmp' )
rename_list = model_filename_list rename_list = model_filename_list
""" """
#unused #unused
, optimizer_filename_list=[] , optimizer_filename_list=[]
@ -519,7 +523,7 @@ class ModelBase(object):
except Exception as e: except Exception as e:
print ("Unable to save ", opt_filename) print ("Unable to save ", opt_filename)
""" """
for _, filename in rename_list: for _, filename in rename_list:
filename = self.get_strpath_storage_for_file(filename) filename = self.get_strpath_storage_for_file(filename)
source_filename = Path(filename+'.tmp') source_filename = Path(filename+'.tmp')
@ -528,7 +532,7 @@ class ModelBase(object):
if target_filename.exists(): if target_filename.exists():
target_filename.unlink() target_filename.unlink()
source_filename.rename ( str(target_filename) ) source_filename.rename ( str(target_filename) )
def debug_one_iter(self): def debug_one_iter(self):
images = [] images = []
for generator in self.generator_list: for generator in self.generator_list:
@ -542,12 +546,11 @@ class ModelBase(object):
return [ generator.generate_next() for generator in self.generator_list] return [ generator.generate_next() for generator in self.generator_list]
def train_one_iter(self): def train_one_iter(self):
# if self.iter == 1 and self.options.get('ping_pong', False):
if self.iter == 1 and self.options.get('ping_pong', False): # self.set_batch_size(1)
self.set_batch_size(1) # self.paddle = 'ping'
self.paddle = 'ping' # elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size: # self.set_batch_size(self.batch_cap)
self.set_batch_size(self.batch_cap)
sample = self.generate_next_sample() sample = self.generate_next_sample()
iter_time = time.time() iter_time = time.time()
losses = self.onTrainOneIter(sample, self.generator_list) losses = self.onTrainOneIter(sample, self.generator_list)
@ -574,21 +577,6 @@ class ModelBase(object):
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img ) cv2_imwrite (filepath, img )
if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
if self.batch_size == self.batch_cap:
self.paddle = 'pong'
if self.batch_size > self.batch_cap:
self.set_batch_size(self.batch_cap)
self.paddle = 'pong'
if self.batch_size == 1:
self.paddle = 'ping'
if self.paddle == 'ping':
self.save()
self.set_batch_size(self.batch_size + 1)
else:
self.save()
self.set_batch_size(self.batch_size - 1)
self.iter += 1 self.iter += 1
return self.iter, iter_time, self.batch_size return self.iter, iter_time, self.batch_size
@ -656,8 +644,8 @@ class ModelBase(object):
lh_height = 100 lh_height = 100
lh_img = np.ones ( (lh_height,w,c) ) * 0.1 lh_img = np.ones ( (lh_height,w,c) ) * 0.1
if len(loss_history) != 0: if len(loss_history) != 0:
loss_count = len(loss_history[0]) loss_count = len(loss_history[0])
lh_len = len(loss_history) lh_len = len(loss_history)

View file

@ -487,7 +487,8 @@ class SAEModel(ModelBase):
'apply_ct': apply_random_ct} for i in range(ms_count)] + \ 'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in 'resolution': resolution // (2 ** i)} for i in
range(ms_count)] range(ms_count)],
ping_pong=self.ping_pong_options,
), ),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
@ -500,7 +501,8 @@ class SAEModel(ModelBase):
range(ms_count)] + \ range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M), [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in 'resolution': resolution // (2 ** i)} for i in
range(ms_count)]) range(ms_count)],
ping_pong=self.ping_pong_options,)
]) ])
# override # override
@ -540,38 +542,8 @@ class SAEModel(ModelBase):
# override # override
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.batch_size = batch_size self.batch_size = batch_size
self.set_training_data_generators(None) for generators in self.get_training_data_generators():
self.set_training_data_generators([ generators.update_batch(batch_size)
SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
0.05]) + self.src_scale_mod / 100.0),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution, 'apply_ct': apply_random_ct}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i),
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)]
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types=[{'types': (
t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution}] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)])
])
# override # override
def onTrainOneIter(self, generators_samples, generators_list): def onTrainOneIter(self, generators_samples, generators_list):
@ -705,7 +677,7 @@ class SAEModel(ModelBase):
return func return func
SAEModel.downscale = downscale SAEModel.downscale = downscale
#def downscale (dim, padding='zero', norm='', act='', **kwargs): #def downscale (dim, padding='zero', norm='', act='', **kwargs):
# def func(x): # def func(x):
# return BlurPool()( Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=1, padding=padding)(x)) ) ) # return BlurPool()( Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=1, padding=padding)(x)) ) )

View file

@ -6,7 +6,7 @@ You can implement your own SampleGenerator
class SampleGeneratorBase(object): class SampleGeneratorBase(object):
def __init__ (self, samples_path, debug, batch_size): def __init__(self, samples_path, debug, batch_size):
if samples_path is None: if samples_path is None:
raise Exception('samples_path is None') raise Exception('samples_path is None')
@ -15,21 +15,21 @@ class SampleGeneratorBase(object):
self.batch_size = 1 if self.debug else batch_size self.batch_size = 1 if self.debug else batch_size
self.last_generation = None self.last_generation = None
self.active = True self.active = True
def set_active(self, is_active): def set_active(self, is_active):
self.active = is_active self.active = is_active
def generate_next(self): def generate_next(self):
if not self.active and self.last_generation is not None: if not self.active and self.last_generation is not None:
return self.last_generation return self.last_generation
self.last_generation = next(self) self.last_generation = next(self)
return self.last_generation return self.last_generation
#overridable # overridable
def __iter__(self): def __iter__(self):
#implement your own iterator # implement your own iterator
return self return self
def __next__(self): def __next__(self):
#implement your own iterator # implement your own iterator
return None return None

View file

@ -5,8 +5,9 @@ import cv2
import numpy as np import numpy as np
from facelib import LandmarksProcessor from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
SampleType) SampleType)
from samplelib.SampleGeneratorPingPong import PingPongOptions, SampleGeneratorPingPong
from utils import iter_utils from utils import iter_utils
@ -19,12 +20,12 @@ output_sample_types = [
''' '''
class SampleGeneratorFace(SampleGeneratorBase): class SampleGeneratorFace(SampleGeneratorPingPong):
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
**kwargs): ping_pong=PingPongOptions(), **kwargs):
super().__init__(samples_path, debug, batch_size) super().__init__(samples_path, debug, batch_size=batch_size, ping_pong=ping_pong)
self.sample_process_options = sample_process_options self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx self.add_sample_idx = add_sample_idx
@ -64,6 +65,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
def __next__(self): def __next__(self):
self.generator_counter += 1 self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators)] generator = self.generators[self.generator_counter % len(self.generators)]
super().__next__()
return next(generator) return next(generator)
def batch_func(self, param): def batch_func(self, param):

View file

@ -0,0 +1,50 @@
from enum import Enum
from samplelib import SampleGeneratorBase
class Paddle(Enum):
PING = 'ping' # Ascending
PONG = 'pong' # Descending
class PingPongOptions:
def __init__(self, enabled=False, iterations=1000, model_iter=1, paddle=Paddle.PING, batch_cap=1):
self.enabled = enabled
self.iterations = iterations
self.model_iter = model_iter
self.paddle = paddle
self.batch_cap = batch_cap
class SampleGeneratorPingPong(SampleGeneratorBase):
def __init__(self, *args, batch_size, ping_pong=PingPongOptions()):
self.ping_pong = ping_pong
super().__init__(*args, batch_size)
def __next__(self):
if self.ping_pong.enabled and self.ping_pong.model_iter % self.ping_pong.iterations == 0 \
and self.ping_pong.model_iter != 0:
# If batch size is greater then batch cap, set it to batch cap
if self.batch_size > self.ping_pong.batch_cap:
self.batch_size = self.ping_pong.batch_cap
# If we are at the batch cap, switch to PONG (descend)
if self.batch_size == self.ping_pong.batch_cap:
self.paddle = Paddle.PONG
# Else if we are at 1, switch to PING (ascend)
elif self.batch_size == 1:
self.paddle = Paddle.PING
# If PING (ascending) increase the batch size
if self.paddle is Paddle.PING:
self.batch_size += 1
# Else decrease the batch size
else:
self.batch_size -= 1
self.ping_pong.model_iter += 1
super().__next__()

View file

@ -4,5 +4,6 @@ from .SampleLoader import SampleLoader
from .SampleProcessor import SampleProcessor from .SampleProcessor import SampleProcessor
from .SampleGeneratorBase import SampleGeneratorBase from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorPingPong import SampleGeneratorPingPong
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal