mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-06 04:52:14 -07:00
refactoring
This commit is contained in:
parent
99462a1a17
commit
5d011368a9
2 changed files with 101 additions and 90 deletions
|
@ -43,6 +43,7 @@ class FaceAlignerTrainerApp:
|
|||
self._is_previewing_samples = False
|
||||
self._new_preview_data : 'PreviewData' = None
|
||||
self._last_save_time = None
|
||||
self._req_is_training = None
|
||||
|
||||
# settings / params
|
||||
self._model_data = None
|
||||
|
@ -52,7 +53,7 @@ class FaceAlignerTrainerApp:
|
|||
self._resolution = None
|
||||
self._iteration = None
|
||||
self._autosave_period = None
|
||||
self._is_training = None
|
||||
self._is_training = False
|
||||
self._loss_history = {}
|
||||
|
||||
# Generators
|
||||
|
@ -96,14 +97,9 @@ class FaceAlignerTrainerApp:
|
|||
def set_autosave_period(self, mins : int):
|
||||
self._autosave_period = mins
|
||||
|
||||
def get_is_training(self) -> bool: return self._is_training
|
||||
def get_is_training(self) -> bool: return self._req_is_training if self._req_is_training is not None else self._is_training
|
||||
def set_is_training(self, training : bool):
|
||||
if self._is_training != training:
|
||||
if training:
|
||||
self._last_save_time = time.time()
|
||||
else:
|
||||
self._last_save_time = None
|
||||
self._is_training = training
|
||||
self._req_is_training = training
|
||||
|
||||
def get_loss_history(self): return self._loss_history
|
||||
def set_loss_history(self, lh): self._loss_history = lh
|
||||
|
@ -238,6 +234,47 @@ class FaceAlignerTrainerApp:
|
|||
self._model.to(self._device)
|
||||
self._model_optimizer.load_state_dict(self._model_optimizer.state_dict())
|
||||
|
||||
if self._req_is_training is not None:
|
||||
if self._req_is_training != self._is_training:
|
||||
if self._req_is_training:
|
||||
self._last_save_time = time.time()
|
||||
else:
|
||||
self._last_save_time = None
|
||||
torch.cuda.empty_cache()
|
||||
self._is_training = self._req_is_training
|
||||
self._req_is_training = None
|
||||
|
||||
self.main_loop_training()
|
||||
|
||||
if self._is_training and self._last_save_time is not None:
|
||||
while (time.time()-self._last_save_time)/60 >= self._autosave_period:
|
||||
self._last_save_time += self._autosave_period*60
|
||||
self._ev_request_save.set()
|
||||
|
||||
if self._ev_request_export_model.is_set():
|
||||
self._ev_request_export_model.clear()
|
||||
print('Exporting...')
|
||||
self.export()
|
||||
print('Exporting done.')
|
||||
dc.Diacon.get_current_dlg().recreate().set_current()
|
||||
|
||||
if self._ev_request_save.is_set():
|
||||
self._ev_request_save.clear()
|
||||
print('Saving...')
|
||||
self.save()
|
||||
print('Saving done.')
|
||||
dc.Diacon.get_current_dlg().recreate().set_current()
|
||||
|
||||
if self._ev_request_quit.is_set():
|
||||
self._ev_request_quit.clear()
|
||||
self._is_quit = True
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
def main_loop_training(self):
|
||||
"""
|
||||
separated function, because torch tensors refences must be freed from python locals
|
||||
"""
|
||||
if self._is_training or \
|
||||
self._is_previewing_samples or \
|
||||
self._ev_request_preview.is_set():
|
||||
|
@ -282,35 +319,6 @@ class FaceAlignerTrainerApp:
|
|||
if self._is_previewing_samples:
|
||||
self._new_viewing_data = training_data
|
||||
|
||||
if self._is_training:
|
||||
if self._last_save_time is not None:
|
||||
while (time.time()-self._last_save_time)/60 >= self._autosave_period:
|
||||
self._last_save_time += self._autosave_period*60
|
||||
self._ev_request_save.set()
|
||||
|
||||
|
||||
if self._ev_request_export_model.is_set():
|
||||
self._ev_request_export_model.clear()
|
||||
print('Exporting...')
|
||||
self.export()
|
||||
print('Exporting done.')
|
||||
dc.Diacon.update_dlg()
|
||||
|
||||
if self._ev_request_save.is_set():
|
||||
self._ev_request_save.clear()
|
||||
print('Saving...')
|
||||
self.save()
|
||||
print('Saving done.')
|
||||
|
||||
dc.Diacon.update_dlg()
|
||||
|
||||
if self._ev_request_quit.is_set():
|
||||
self._ev_request_quit.clear()
|
||||
self._is_quit = True
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
|
||||
def get_main_dlg(self):
|
||||
last_loss = 0
|
||||
rec_loss_history = self._loss_history.get('reconstruct', None)
|
||||
|
@ -334,7 +342,7 @@ class FaceAlignerTrainerApp:
|
|||
dc.DlgChoice(short_name='p', row_def='| Show current preview.',
|
||||
on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())),
|
||||
|
||||
dc.DlgChoice(short_name='t', row_def=f'| Training | {self._is_training}',
|
||||
dc.DlgChoice(short_name='t', row_def=f'| Training | {self.get_is_training()}',
|
||||
on_choose=lambda dlg: (self.set_is_training(not self.get_is_training()), dlg.recreate().set_current()) ),
|
||||
|
||||
dc.DlgChoice(short_name='reset', row_def='| Reset model.',
|
||||
|
|
|
@ -304,6 +304,9 @@ class _Diacon:
|
|||
self._dialog_t = None
|
||||
self._input_t = None
|
||||
|
||||
def get_current_dlg(self) -> Union[Dlg, None]:
|
||||
return self._current_dlg
|
||||
|
||||
def _input_thread(self,):
|
||||
while self._started:
|
||||
if self._input_request:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue