refactoring

This commit is contained in:
iperov 2021-11-16 15:28:42 +04:00
parent 99462a1a17
commit 5d011368a9
2 changed files with 101 additions and 90 deletions

View file

@ -43,6 +43,7 @@ class FaceAlignerTrainerApp:
self._is_previewing_samples = False
self._new_preview_data : 'PreviewData' = None
self._last_save_time = None
self._req_is_training = None
# settings / params
self._model_data = None
@ -52,7 +53,7 @@ class FaceAlignerTrainerApp:
self._resolution = None
self._iteration = None
self._autosave_period = None
self._is_training = None
self._is_training = False
self._loss_history = {}
# Generators
@ -96,14 +97,9 @@ class FaceAlignerTrainerApp:
def set_autosave_period(self, mins : int) -> None:
    """
    Remember how often the model should be saved automatically.

    mins    interval between autosaves, in minutes.
    """
    self._autosave_period = mins
def get_is_training(self) -> bool: return self._is_training
def get_is_training(self) -> bool:
    """
    Report the training flag as the caller should see it.

    A value stored in `_req_is_training` (not None) takes precedence;
    otherwise the value of `_is_training` is returned.
    """
    requested = self._req_is_training
    if requested is None:
        return self._is_training
    return requested
def set_is_training(self, training : bool):
if self._is_training != training:
if training:
self._last_save_time = time.time()
else:
self._last_save_time = None
self._is_training = training
self._req_is_training = training
def get_loss_history(self):
    """Return the loss-history mapping itself (not a copy)."""
    history = self._loss_history
    return history
def set_loss_history(self, lh):
    """
    Replace the stored loss history.

    lh      object to store; it is kept by reference, not copied.
    """
    self._loss_history = lh
@ -238,71 +234,36 @@ class FaceAlignerTrainerApp:
self._model.to(self._device)
self._model_optimizer.load_state_dict(self._model_optimizer.state_dict())
if self._is_training or \
self._is_previewing_samples or \
self._ev_request_preview.is_set():
training_data = self._training_generator.get_next_data(wait=False)
if training_data is not None and \
training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay
if self._is_training:
self._model_optimizer.zero_grad()
if self._ev_request_preview.is_set() or \
self._is_training:
# Inference for both preview and training
img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) )
if self._is_training:
# Training optimization step
shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
loss_t.backward()
self._model_optimizer.step()
loss = loss_t.detach().cpu().numpy()
rec_loss_history = self._loss_history.get('reconstruct', None)
if rec_loss_history is None:
rec_loss_history = self._loss_history['reconstruct'] = []
rec_loss_history.append(float(loss))
self.set_iteration( self.get_iteration() + 1 )
if self._ev_request_preview.is_set():
self._ev_request_preview.clear()
# Preview request
pd = PreviewData()
pd.training_data = training_data
pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
self._new_preview_data = pd
if self._is_previewing_samples:
self._new_viewing_data = training_data
if self._is_training:
if self._last_save_time is not None:
while (time.time()-self._last_save_time)/60 >= self._autosave_period:
self._last_save_time += self._autosave_period*60
self._ev_request_save.set()
if self._req_is_training is not None:
if self._req_is_training != self._is_training:
if self._req_is_training:
self._last_save_time = time.time()
else:
self._last_save_time = None
torch.cuda.empty_cache()
self._is_training = self._req_is_training
self._req_is_training = None
self.main_loop_training()
if self._is_training and self._last_save_time is not None:
while (time.time()-self._last_save_time)/60 >= self._autosave_period:
self._last_save_time += self._autosave_period*60
self._ev_request_save.set()
if self._ev_request_export_model.is_set():
self._ev_request_export_model.clear()
print('Exporting...')
self.export()
print('Exporting done.')
dc.Diacon.update_dlg()
dc.Diacon.get_current_dlg().recreate().set_current()
if self._ev_request_save.is_set():
self._ev_request_save.clear()
print('Saving...')
self.save()
print('Saving done.')
dc.Diacon.update_dlg()
dc.Diacon.get_current_dlg().recreate().set_current()
if self._ev_request_quit.is_set():
self._ev_request_quit.clear()
@ -310,6 +271,53 @@ class FaceAlignerTrainerApp:
time.sleep(0.005)
def main_loop_training(self):
    """
    Run one pass of sample fetching, inference, optimization and preview.

    Separated function, because torch tensor references must be freed from
    python locals: everything created here goes out of scope on return.

    Depending on the current flags this may:
      * run an optimizer step and append the loss to
        `_loss_history['reconstruct']`, then advance the iteration counter
        (when `_is_training` is set);
      * publish a `PreviewData` in `_new_preview_data` and clear
        `_ev_request_preview` (when a preview was requested);
      * publish the fetched sample in `_new_viewing_data`
        (when `_is_previewing_samples` is set).

    Nothing happens when no flag is active, when the generator has no sample
    ready, or when the sample's resolution differs from the current one.
    """
    # Read the preview request exactly once. The event is set from a dialog
    # callback, so polling it again later in this pass could disagree with
    # the decision to run inference and reach the preview branch without a
    # prediction. A request that arrives after this point stays set and is
    # served on the next pass.
    preview_requested = self._ev_request_preview.is_set()
    is_training = self._is_training

    if not (is_training or self._is_previewing_samples or preview_requested):
        return

    training_data = self._training_generator.get_next_data(wait=False)
    if training_data is None:
        return
    if training_data.resolution != self.get_resolution():
        # Skip if resolution is different, due to delay
        return

    if is_training:
        self._model_optimizer.zero_grad()

    if preview_requested or is_training:
        # Inference for both preview and training
        img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
        shift_uni_mats_pred_t = self._model(img_aligned_shifted_t).view( (-1,2,3) )

        if is_training:
            # Training optimization step
            shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
            loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
            loss_t.backward()
            self._model_optimizer.step()

            loss = loss_t.detach().cpu().numpy()
            self._loss_history.setdefault('reconstruct', []).append(float(loss))

            self.set_iteration( self.get_iteration() + 1 )

        if preview_requested:
            # Preview request: only cleared when it was actually observed.
            self._ev_request_preview.clear()
            pd = PreviewData()
            pd.training_data = training_data
            pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
            self._new_preview_data = pd

    if self._is_previewing_samples:
        self._new_viewing_data = training_data
def get_main_dlg(self):
last_loss = 0
@ -334,7 +342,7 @@ class FaceAlignerTrainerApp:
dc.DlgChoice(short_name='p', row_def='| Show current preview.',
on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())),
dc.DlgChoice(short_name='t', row_def=f'| Training | {self._is_training}',
dc.DlgChoice(short_name='t', row_def=f'| Training | {self.get_is_training()}',
on_choose=lambda dlg: (self.set_is_training(not self.get_is_training()), dlg.recreate().set_current()) ),
dc.DlgChoice(short_name='reset', row_def='| Reset model.',
@ -371,8 +379,8 @@ class FaceAlignerTrainerApp:
lh_ar = np.array(lh[-d*max_lines:], np.float32)
lh_ar = lh_ar.reshape( (max_lines, d))
lh_ar_max, lh_ar_min, lh_ar_mean, lh_ar_median = lh_ar.max(-1), lh_ar.min(-1), lh_ar.mean(-1), np.median(lh_ar, -1)
print( '\n'.join( f'max:[{max_value:.5f}] min:[{min_value:.5f}] mean:[{mean_value:.5f}] median:[{median_value:.5f}]' for max_value, min_value, mean_value, median_value in zip(lh_ar_max, lh_ar_min, lh_ar_mean, lh_ar_median) ) )
dlg.recreate().set_current()