Dev parts update: add an adjustable learning-rate setting to the FaceAligner trainer and a `set_lr` method to the AdaBelief optimizer

This commit is contained in:
iperov 2021-12-20 17:52:52 +04:00
commit 6478551358
2 changed files with 37 additions and 16 deletions

View file

@ -53,12 +53,16 @@ class FaceAlignerTrainerApp:
self._device_info = None
self._batch_size = None
self._resolution = None
self._learning_rate = None
self._random_warp = None
self._iteration = None
self._autosave_period = None
self._is_training = False
self._loss_history = {}
self._model = None
self._model_optimizer = None

# Generators
self._training_generator = TrainingDataGenerator(faceset_path)
@ -90,6 +94,12 @@ class FaceAlignerTrainerApp:
self._resolution = resolution self._resolution = resolution
self._training_generator.set_resolution(resolution) self._training_generator.set_resolution(resolution)
def get_learning_rate(self) -> float:
    """Return the currently configured learning rate."""
    return self._learning_rate
def set_learning_rate(self, learning_rate : float):
    """Store the new learning rate and push it to the active model optimizer, if one exists."""
    self._learning_rate = learning_rate
    optimizer = self._model_optimizer
    if optimizer is not None:
        optimizer.set_lr(learning_rate)
def get_random_warp(self) -> bool: return self._random_warp
def set_random_warp(self, random_warp : bool):
    self._random_warp = random_warp
@ -118,6 +128,7 @@ class FaceAlignerTrainerApp:
self.set_device_info( lib_torch.get_device_info_by_index(model_data.get('device_index', -1)) )
self.set_batch_size( model_data.get('batch_size', 64) )
self.set_resolution( model_data.get('resolution', 224) )
self.set_learning_rate( model_data.get('learning_rate', 5e-5) )
self.set_random_warp( model_data.get('random_warp', True) )
self.set_iteration( model_data.get('iteration', 0) )
self.set_autosave_period( model_data.get('autosave_period', 25) )
@ -339,6 +350,9 @@ class FaceAlignerTrainerApp:
dc.DlgChoice(short_name='d', row_def=f'| Device | {self._device_info}',
             on_choose=lambda dlg: self.get_training_device_dlg(dlg).set_current()),

dc.DlgChoice(short_name='lr', row_def=f'| Learning rate | {self.get_learning_rate()}',
             on_choose=lambda dlg: self.get_learning_rate_dlg(parent_dlg=dlg).set_current() ),

dc.DlgChoice(short_name='i', row_def=f'| Iteration | {self.get_iteration()}',
             on_choose=lambda dlg: self.get_iteration_dlg(parent_dlg=dlg).set_current() ),
@ -435,6 +449,13 @@ class FaceAlignerTrainerApp:
on_back = lambda dlg: parent_dlg.recreate().set_current(),
top_rows_def='|c9 Set iteration', )
def get_learning_rate_dlg(self, parent_dlg):
    """Build the console dialog for entering a new learning rate.

    The dialog accepts a float in [0, 0.1]. On confirmation it applies the
    value via set_learning_rate() (which also updates the live optimizer)
    and then returns control to `parent_dlg` by recreating it.
    """
    return dc.DlgNumber(is_float=True, min_value=0, max_value=0.1,
                        # (set_learning_rate(...), parent_dlg.recreate()...) tuple form runs both
                        # actions in one lambda expression; the tuple value itself is unused.
                        on_value = lambda dlg, value: (self.set_learning_rate(value), parent_dlg.recreate().set_current()),
                        on_recreate = lambda dlg: self.get_learning_rate_dlg(parent_dlg),
                        on_back = lambda dlg: parent_dlg.recreate().set_current(),
                        top_rows_def='|c9 Set learning rate', )
def get_training_device_dlg(self, parent_dlg):
    return DlgTorchDevicesInfo(on_device_choice = lambda dlg, device_info: (self.set_device_info(device_info), parent_dlg.recreate().set_current()),
                               on_recreate = lambda dlg: self.get_training_device_dlg(parent_dlg),
@ -442,7 +463,6 @@ class FaceAlignerTrainerApp:
top_rows_def='|c9 Choose device'
)

class DlgTorchDevicesInfo(dc.DlgChoices):
    def __init__(self, on_device_choice : Callable = None,
                 on_device_multi_choice : Callable = None,

View file

@ -7,9 +7,6 @@ class AdaBelief(torch.optim.Optimizer):
defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
super(AdaBelief, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdaBelief, self).__setstate__(state)
def reset(self):
    for group in self.param_groups:
        for p in group['params']:
@ -18,6 +15,10 @@ class AdaBelief(torch.optim.Optimizer):
state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
def set_lr(self, lr):
    """Apply the learning rate `lr` to every parameter group of this optimizer."""
    for param_group in self.param_groups:
        param_group.update(lr=lr)
def step(self):
    for group in self.param_groups:
        for p in group['params']: