diff --git a/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py b/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py index e045021..ccef4af 100644 --- a/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py +++ b/apps/trainers/FaceAligner/FaceAlignerTrainerApp.py @@ -43,6 +43,7 @@ class FaceAlignerTrainerApp: self._is_previewing_samples = False self._new_preview_data : 'PreviewData' = None self._last_save_time = None + self._req_is_training = None # settings / params self._model_data = None @@ -52,7 +53,7 @@ class FaceAlignerTrainerApp: self._resolution = None self._iteration = None self._autosave_period = None - self._is_training = None + self._is_training = False self._loss_history = {} # Generators @@ -96,14 +97,9 @@ class FaceAlignerTrainerApp: def set_autosave_period(self, mins : int): self._autosave_period = mins - def get_is_training(self) -> bool: return self._is_training + def get_is_training(self) -> bool: return self._req_is_training if self._req_is_training is not None else self._is_training def set_is_training(self, training : bool): - if self._is_training != training: - if training: - self._last_save_time = time.time() - else: - self._last_save_time = None - self._is_training = training + self._req_is_training = training def get_loss_history(self): return self._loss_history def set_loss_history(self, lh): self._loss_history = lh @@ -238,71 +234,36 @@ class FaceAlignerTrainerApp: self._model.to(self._device) self._model_optimizer.load_state_dict(self._model_optimizer.state_dict()) - if self._is_training or \ - self._is_previewing_samples or \ - self._ev_request_preview.is_set(): - training_data = self._training_generator.get_next_data(wait=False) - if training_data is not None and \ - training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay - - if self._is_training: - self._model_optimizer.zero_grad() - - if self._ev_request_preview.is_set() or \ - self._is_training: - # Inference for both preview and training - img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device) - shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) ) - - if self._is_training: - # Training optimization step - shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device) - - loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0 - loss_t.backward() - self._model_optimizer.step() - - loss = loss_t.detach().cpu().numpy() - - rec_loss_history = self._loss_history.get('reconstruct', None) - if rec_loss_history is None: - rec_loss_history = self._loss_history['reconstruct'] = [] - rec_loss_history.append(float(loss)) - - self.set_iteration( self.get_iteration() + 1 ) - - if self._ev_request_preview.is_set(): - self._ev_request_preview.clear() - # Preview request - pd = PreviewData() - pd.training_data = training_data - pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy() - self._new_preview_data = pd - - if self._is_previewing_samples: - self._new_viewing_data = training_data - - if self._is_training: - if self._last_save_time is not None: - while (time.time()-self._last_save_time)/60 >= self._autosave_period: - self._last_save_time += self._autosave_period*60 - self._ev_request_save.set() + if self._req_is_training is not None: + if self._req_is_training != self._is_training: + if self._req_is_training: + self._last_save_time = time.time() + else: + self._last_save_time = None + torch.cuda.empty_cache() + self._is_training = self._req_is_training + self._req_is_training = None + + self.main_loop_training() + if self._is_training and self._last_save_time is not None: + while (time.time()-self._last_save_time)/60 >= self._autosave_period: + self._last_save_time += self._autosave_period*60 + self._ev_request_save.set() if self._ev_request_export_model.is_set(): self._ev_request_export_model.clear() print('Exporting...') self.export() print('Exporting done.') - dc.Diacon.update_dlg() + dc.Diacon.get_current_dlg().recreate().set_current() if self._ev_request_save.is_set(): self._ev_request_save.clear() print('Saving...') self.save() print('Saving done.') - - dc.Diacon.update_dlg() + dc.Diacon.get_current_dlg().recreate().set_current() if self._ev_request_quit.is_set(): self._ev_request_quit.clear() @@ -310,6 +271,53 @@ class FaceAlignerTrainerApp: time.sleep(0.005) + def main_loop_training(self): + """ + separated function, because torch tensors refences must be freed from python locals + """ + if self._is_training or \ + self._is_previewing_samples or \ + self._ev_request_preview.is_set(): + training_data = self._training_generator.get_next_data(wait=False) + if training_data is not None and \ + training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay + + if self._is_training: + self._model_optimizer.zero_grad() + + if self._ev_request_preview.is_set() or \ + self._is_training: + # Inference for both preview and training + img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device) + shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) ) + + if self._is_training: + # Training optimization step + shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device) + + loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0 + loss_t.backward() + self._model_optimizer.step() + + loss = loss_t.detach().cpu().numpy() + + rec_loss_history = self._loss_history.get('reconstruct', None) + if rec_loss_history is None: + rec_loss_history = self._loss_history['reconstruct'] = [] + rec_loss_history.append(float(loss)) + + self.set_iteration( self.get_iteration() + 1 ) + + if self._ev_request_preview.is_set(): + self._ev_request_preview.clear() + # Preview request + pd = PreviewData() + pd.training_data = training_data + pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy() + self._new_preview_data = pd + + if self._is_previewing_samples: + self._new_viewing_data = training_data def get_main_dlg(self): last_loss = 0 @@ -334,7 +342,7 @@ class FaceAlignerTrainerApp: dc.DlgChoice(short_name='p', row_def='| Show current preview.', on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())), - dc.DlgChoice(short_name='t', row_def=f'| Training | {self._is_training}', + dc.DlgChoice(short_name='t', row_def=f'| Training | {self.get_is_training()}', on_choose=lambda dlg: (self.set_is_training(not self.get_is_training()), dlg.recreate().set_current()) ), dc.DlgChoice(short_name='reset', row_def='| Reset model.', @@ -371,8 +379,8 @@ class FaceAlignerTrainerApp: lh_ar = np.array(lh[-d*max_lines:], np.float32) lh_ar = lh_ar.reshape( (max_lines, d)) lh_ar_max, lh_ar_min, lh_ar_mean, lh_ar_median = lh_ar.max(-1), lh_ar.min(-1), lh_ar.mean(-1), np.median(lh_ar, -1) - - + + print( '\n'.join( f'max:[{max_value:.5f}] min:[{min_value:.5f}] mean:[{mean_value:.5f}] median:[{median_value:.5f}]' for max_value, min_value, mean_value, median_value in zip(lh_ar_max, lh_ar_min, lh_ar_mean, lh_ar_median) ) ) dlg.recreate().set_current() diff --git a/xlib/console/diacon/Diacon.py b/xlib/console/diacon/Diacon.py index 5f37af3..a6ddbd5 100644 --- a/xlib/console/diacon/Diacon.py +++ b/xlib/console/diacon/Diacon.py @@ -16,8 +16,8 @@ class EDlgMode(IntEnum): class DlgChoice: - def __init__(self, short_name : str = None, - row_def : str = None, + def __init__(self, short_name : str = None, + row_def : str = None, on_choose : Callable = None): if len(short_name) == 0: raise ValueError('Zero len short_name is not valid.') @@ -32,18 +32,18 @@ class DlgChoice: class Dlg: def __init__(self, on_recreate : Callable[ [], 'Dlg'] = None, on_back : Callable = None, - top_rows_def : Union[str, List[str]] = None, - bottom_rows_def : Union[str, List[str]] = None, + top_rows_def : Union[str, List[str]] = None, + bottom_rows_def : Union[str, List[str]] = None, ): """ base class for Diacon dialogs. """ self._on_recreate = on_recreate self._on_back = on_back - + self._top_rows_def = top_rows_def self._bottom_rows_def = bottom_rows_def - + def recreate(self): """ """ @@ -51,27 +51,27 @@ class Dlg: return self._on_recreate(self) else: raise Exception('on_recreate() is not defined.') - + def set_current(self, print=True): Diacon.update_dlg(self, print=print) - + def handle_user_input(self, s : str): """ """ mode = self.on_user_input(s.strip()) - + if mode == EDlgMode.UNHANDLED: mode = EDlgMode.RELOAD if mode == EDlgMode.WRONG_INPUT: print('\nWrong input') mode = EDlgMode.RELOAD - + if mode == EDlgMode.RELOAD: self.recreate().set_current() if mode == EDlgMode.BACK: if self._on_back is not None: self._on_back(self) - + #overridable def on_user_input(self, s : str) -> EDlgMode: if len(s) == 0: @@ -80,13 +80,13 @@ class Dlg: if s == '<': return EDlgMode.BACK return EDlgMode.UNHANDLED - + def print(self, table_width_max=80, col_spacing = 3): """ print dialog """ table_def : List[str]= [] - + trd = self._top_rows_def brd = self._bottom_rows_def if trd is not None: @@ -99,7 +99,7 @@ class Dlg: table_def.append('|99') table_def = self.on_print(table_def) - + if brd is not None: if not isinstance(brd, (list,tuple)): brd = [brd] @@ -119,7 +119,7 @@ class Dlg: def on_print(self, table_lines : List[Tuple[str,str]]): return table_lines - + class DlgNumber(Dlg): def __init__(self, is_float : bool, @@ -131,7 +131,7 @@ class DlgNumber(Dlg): on_value : Callable[ [Dlg, Number], None] = None, on_recreate : Callable[ [], 'Dlg'] = None, on_back : Callable = None, - top_rows_def : Union[str, List[str]] = None, + top_rows_def : Union[str, List[str]] = None, bottom_rows_def : Union[str, List[str]] = None, ): super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def) @@ -139,7 +139,7 @@ class DlgNumber(Dlg): raise ValueError('min_value > max_value') if clip_min_value is not None and clip_max_value is not None and clip_min_value > clip_max_value: raise ValueError('clip_min_value > clip_max_value') - + self._is_float = is_float self._current_value = current_value self._min_value = min_value @@ -152,21 +152,21 @@ class DlgNumber(Dlg): def on_print(self, table_def : List[str]): minv, maxv = self._min_value, self._max_value - + if self._is_float: line = '| * | Enter float number' else: line = '| * | Enter integer number' - + if minv is not None and maxv is None: line += f' in range: [{minv} ... )' elif minv is None and maxv is not None: line += f' in range: ( ... {maxv} ]' elif minv is not None and maxv is not None: line += f' in range: [{minv} ... {maxv} ]' - + table_def.append(line) - + return table_def #overridable @@ -192,7 +192,7 @@ class DlgNumber(Dlg): if self._clip_max_value is not None: if v > self._clip_max_value: v = self._clip_max_value - + if self._on_value is not None: self._on_value(self, v) return EDlgMode.HANDLED @@ -206,7 +206,7 @@ class DlgChoices(Dlg): on_multi_choice : Callable[ [ List[DlgChoice] ], None] = None, on_recreate : Callable[ [Dlg], Dlg] = None, on_back : Callable = None, - top_rows_def : Union[str, List[str]] = None, + top_rows_def : Union[str, List[str]] = None, bottom_rows_def : Union[str, List[str]] = None, ): super().__init__(on_recreate=on_recreate, on_back=on_back, top_rows_def=top_rows_def, bottom_rows_def=bottom_rows_def) @@ -252,7 +252,7 @@ class DlgChoices(Dlg): else: id = x[0] choices_id.append(id) - + if len(set(choices_id)) != len(choices_id): # Duplicate input return EDlgMode.WRONG_INPUT @@ -261,7 +261,7 @@ class DlgChoices(Dlg): on_choose = self._choices[id].get_on_choose() if on_choose is not None: on_choose(self) - + if self._on_multi_choice is not None: self._on_multi_choice(choices_id) @@ -304,6 +304,9 @@ class _Diacon: self._dialog_t = None self._input_t = None + def get_current_dlg(self) -> Union[Dlg, None]: + return self._current_dlg + def _input_thread(self,): while self._started: if self._input_request: @@ -335,7 +338,7 @@ class _Diacon: if input_result is not None: if self._current_dlg is not None: - self._current_dlg.handle_user_input(input_result) + self._current_dlg.handle_user_input(input_result) continue time.sleep(0.005) @@ -360,7 +363,7 @@ class _Diacon: """ if not self._started: self.start() - + self._new_dlg = (new_dlg, print) Diacon = _Diacon()