From 64785513586a8fadc48c7f3c061e804349c5efef Mon Sep 17 00:00:00 2001
From: iperov
Date: Mon, 20 Dec 2021 17:52:52 +0400
Subject: [PATCH] _dev parts upd

---
 .../FaceAligner/FaceAlignerTrainerApp.py | 44 ++++++++++++++-----
 xlib/torch/optim/AdaBelief.py            |  9 ++--
 2 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py b/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
index 43a11d7..c53a75c 100644
--- a/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
+++ b/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py
@@ -53,12 +53,16 @@ class FaceAlignerTrainerApp:
         self._device_info = None
         self._batch_size = None
         self._resolution = None
+        self._learning_rate = None
         self._random_warp = None
         self._iteration = None
         self._autosave_period = None
         self._is_training = False
         self._loss_history = {}
 
+        self._model = None
+        self._model_optimizer = None
+
         # Generators
         self._training_generator = TrainingDataGenerator(faceset_path)
 
@@ -90,6 +94,12 @@ class FaceAlignerTrainerApp:
         self._resolution = resolution
         self._training_generator.set_resolution(resolution)
 
+    def get_learning_rate(self) -> float: return self._learning_rate
+    def set_learning_rate(self, learning_rate : float):
+        self._learning_rate = learning_rate
+        if self._model_optimizer is not None:
+            self._model_optimizer.set_lr(learning_rate)
+
     def get_random_warp(self) -> bool: return self._random_warp
     def set_random_warp(self, random_warp : bool):
         self._random_warp = random_warp
@@ -118,6 +128,7 @@ class FaceAlignerTrainerApp:
         self.set_device_info( lib_torch.get_device_info_by_index(model_data.get('device_index', -1)) )
         self.set_batch_size( model_data.get('batch_size', 64) )
         self.set_resolution( model_data.get('resolution', 224) )
+        self.set_learning_rate( model_data.get('learning_rate', 5e-5) )
         self.set_random_warp( model_data.get('random_warp', True) )
         self.set_iteration( model_data.get('iteration', 0) )
         self.set_autosave_period( model_data.get('autosave_period', 25) )
@@ -234,7 +245,7 @@ class FaceAlignerTrainerApp:
 
     def main_loop(self):
         while not self._is_quit:
-            
+
             if self._ev_request_reset_model.is_set():
                 self._ev_request_reset_model.clear()
                 self.reset_model(load=False)
@@ -267,20 +278,20 @@ class FaceAlignerTrainerApp:
                 self.export()
                 print('Exporting done.')
                 dc.Diacon.get_current_dlg().recreate().set_current()
-            
+
             if self._ev_request_save.is_set():
                 self._ev_request_save.clear()
                 print('Saving...')
                 self.save()
                 print('Saving done.')
                 dc.Diacon.get_current_dlg().recreate().set_current()
-            
+
             if self._ev_request_quit.is_set():
                 self._ev_request_quit.clear()
                 self._is_quit = True
-            
+
             time.sleep(0.005)
-        
+
 
     def main_loop_training(self):
         """
         separated function, because torch tensors refences must be freed from python locals
         """
         if self._is_training or \
            self._is_previewing_samples or \
            self._ev_request_preview.is_set():
-            
+
             training_data = self._training_generator.get_next_data(wait=False)
             if training_data is not None and \
                training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay
 
                 if self._is_training:
                     self._model_optimizer.zero_grad()
-            
+
                 if self._ev_request_preview.is_set() or \
                    self._is_training:
                     # Inference for both preview and training
                     img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
                     shift_uni_mats_pred_t = self._model(img_aligned_shifted_t)#.view( (-1,2,3) )
 
                 if self._is_training:
-                    # Training optimization step 
+                    # Training optimization step
                     shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
                     loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
                     loss_t.backward()
@@ -313,7 +324,7 @@ class FaceAlignerTrainerApp:
                         rec_loss_history = self._loss_history['reconstruct'] = []
                     rec_loss_history.append(float(loss))
                     self.set_iteration( self.get_iteration() + 1 )
-                
+
                 if self._ev_request_preview.is_set():
                     self._ev_request_preview.clear()
                     # Preview request
@@ -321,7 +332,7 @@ class FaceAlignerTrainerApp:
                     pd.training_data = training_data
                     pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
                     self._new_preview_data = pd
-                
+
                 if self._is_previewing_samples:
                     self._new_viewing_data = training_data
 
@@ -339,6 +350,9 @@ class FaceAlignerTrainerApp:
                 dc.DlgChoice(short_name='d', row_def=f'| Device | {self._device_info}',
                              on_choose=lambda dlg: self.get_training_device_dlg(dlg).set_current()),
 
+                dc.DlgChoice(short_name='lr', row_def=f'| Learning rate | {self.get_learning_rate()}',
+                             on_choose=lambda dlg: self.get_learning_rate_dlg(parent_dlg=dlg).set_current() ),
+
                 dc.DlgChoice(short_name='i', row_def=f'| Iteration | {self.get_iteration()}',
                              on_choose=lambda dlg: self.get_iteration_dlg(parent_dlg=dlg).set_current() ),
 
@@ -435,6 +449,13 @@ class FaceAlignerTrainerApp:
                             on_back = lambda dlg: parent_dlg.recreate().set_current(),
                             top_rows_def='|c9 Set iteration', )
 
+    def get_learning_rate_dlg(self, parent_dlg):
+        return dc.DlgNumber(is_float=True, min_value=0, max_value=0.1,
+                            on_value = lambda dlg, value: (self.set_learning_rate(value), parent_dlg.recreate().set_current()),
+                            on_recreate = lambda dlg: self.get_learning_rate_dlg(parent_dlg),
+                            on_back = lambda dlg: parent_dlg.recreate().set_current(),
+                            top_rows_def='|c9 Set learning rate', )
+
     def get_training_device_dlg(self, parent_dlg):
         return DlgTorchDevicesInfo(on_device_choice = lambda dlg, device_info: (self.set_device_info(device_info), parent_dlg.recreate().set_current()),
                                    on_recreate = lambda dlg: self.get_training_device_dlg(parent_dlg),
@@ -442,7 +463,6 @@ class FaceAlignerTrainerApp:
                                    top_rows_def='|c9 Choose device' )
 
-
 class DlgTorchDevicesInfo(dc.DlgChoices):
     def __init__(self,
                  on_device_choice : Callable = None,
                  on_device_multi_choice : Callable = None,
diff --git a/xlib/torch/optim/AdaBelief.py b/xlib/torch/optim/AdaBelief.py
index ee06716..ec095b0 100644
--- a/xlib/torch/optim/AdaBelief.py
+++ b/xlib/torch/optim/AdaBelief.py
@@ -6,10 +6,7 @@ class AdaBelief(torch.optim.Optimizer):
         defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
         super(AdaBelief, self).__init__(params, defaults)
-
-    def __setstate__(self, state):
-        super(AdaBelief, self).__setstate__(state)
-
+
     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
@@ -17,6 +14,10 @@ class AdaBelief(torch.optim.Optimizer):
                 state['step'] = 0
                 state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
                 state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
+
+    def set_lr(self, lr):
+        for group in self.param_groups:
+            group['lr'] = lr
 
     def step(self):
         for group in self.param_groups:
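
The runtime learning-rate path added above can be exercised end to end: FaceAlignerTrainerApp.set_learning_rate() stores the value and, once self._model_optimizer exists, forwards it to AdaBelief.set_lr(), which rewrites 'lr' in every param group. A minimal standalone sketch of that flow, assuming the repo root is on sys.path so xlib imports; only the constructor keyword names are known from the defaults dict above, so the values passed below are illustrative assumptions, not the repo's defaults:

    import torch
    from xlib.torch.optim.AdaBelief import AdaBelief

    model = torch.nn.Linear(4, 2)   # stand-in for the aligner model

    # Keyword names mirror the defaults dict in AdaBelief.__init__;
    # the concrete values here are assumptions for the sketch.
    opt = AdaBelief(model.parameters(), lr=5e-5, lr_dropout=1.0,
                    betas=(0.9, 0.999), eps=1e-8, weight_decay=0)

    # What set_learning_rate() triggers when a model optimizer exists:
    opt.set_lr(1e-4)                # rewrites 'lr' in every param group
    assert all(g['lr'] == 1e-4 for g in opt.param_groups)

Starting lr=5e-5 matches the model_data.get('learning_rate', 5e-5) default loaded by the trainer above.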