Mirror of https://github.com/iperov/DeepFaceLive, synced 2025-07-16 10:03:42 -07:00

This commit is contained in:
  parent 69f71ddecd
  commit d40fce6a5a
4 changed files with 99 additions and 75 deletions
@@ -94,7 +94,7 @@ class FaceAlignerTrainerApp:

    def set_random_warp(self, random_warp : bool):
        self._random_warp = random_warp
        self._training_generator.set_random_warp(random_warp)


    def get_iteration(self) -> int: return self._iteration
    def set_iteration(self, iteration : int):
        self._iteration = iteration

@@ -234,7 +234,7 @@ class FaceAlignerTrainerApp:

    def main_loop(self):

        while not self._is_quit:

            if self._ev_request_reset_model.is_set():
                self._ev_request_reset_model.clear()
                self.reset_model(load=False)

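The reset handling above is the same request pattern used throughout main_loop: a UI callback sets a threading.Event, and the loop polls it, clears it, and services the request. A minimal sketch of that pattern; the event and handler names here are invented for illustration, not taken from the project:

    import threading, time

    # Hypothetical illustration of the poll-and-clear event pattern used in main_loop.
    request_reset = threading.Event()

    def reset_model():
        print('model reset')

    def main_loop(max_polls=3):
        polls = 0
        while polls < max_polls:             # the real loop runs until a quit flag is set
            if request_reset.is_set():
                request_reset.clear()        # consume the request so it fires only once
                reset_model()
            time.sleep(0.005)                # same short sleep as the trainer loop
            polls += 1

    request_reset.set()                      # e.g. triggered from a dialog callback
    main_loop()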
@@ -253,7 +253,7 @@ class FaceAlignerTrainerApp:
                    torch.cuda.empty_cache()
                self._is_training = self._req_is_training
                self._req_is_training = None

            self.main_loop_training()

            if self._is_training and self._last_save_time is not None:

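The truncated check at the end of this hunk gates the periodic autosave on self._last_save_time. A sketch of that kind of time-based trigger; the 25-second interval and helper name are assumptions for illustration, not values from the diff:

    import time

    SAVE_INTERVAL_SEC = 25.0        # assumed interval, not from the diff

    last_save_time = time.time()

    def maybe_save(is_training: bool):
        """Save only while training and only when the interval has elapsed."""
        global last_save_time
        if is_training and last_save_time is not None:
            if time.time() - last_save_time >= SAVE_INTERVAL_SEC:
                last_save_time = time.time()
                print('autosaving checkpoint...')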
@@ -267,20 +267,20 @@ class FaceAlignerTrainerApp:
                self.export()
                print('Exporting done.')
                dc.Diacon.get_current_dlg().recreate().set_current()

            if self._ev_request_save.is_set():
                self._ev_request_save.clear()
                print('Saving...')
                self.save()
                print('Saving done.')
                dc.Diacon.get_current_dlg().recreate().set_current()

            if self._ev_request_quit.is_set():
                self._ev_request_quit.clear()
                self._is_quit = True

            time.sleep(0.005)

    def main_loop_training(self):
        """
        separated function, because torch tensor references must be freed from python locals

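The docstring above is the rationale for keeping the training step in its own method: CUDA memory held by tensor locals can only be reclaimed once those Python references go out of scope, so the work is done in a separate function and torch.cuda.empty_cache() is called from the loop afterwards. A minimal sketch of the idea (guarded by a CUDA availability check; the tiny model is a stand-in, not the project's network):

    import torch

    def training_step(model, batch_t):
        # All intermediate tensors live only in this frame; when the function
        # returns, the locals are dropped and their CUDA memory becomes freeable.
        out_t = model(batch_t)
        loss_t = out_t.square().mean()
        loss_t.backward()
        return float(loss_t.detach().cpu())

    if torch.cuda.is_available():
        device = torch.device('cuda')
        model = torch.nn.Linear(8, 8).to(device)
        loss = training_step(model, torch.randn(4, 8, device=device))
        torch.cuda.empty_cache()   # release cached blocks freed by the finished step
        print(loss)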
@@ -288,36 +288,32 @@ class FaceAlignerTrainerApp:
        if self._is_training or \
           self._is_previewing_samples or \
           self._ev_request_preview.is_set():

            training_data = self._training_generator.get_next_data(wait=False)
            if training_data is not None and \
               training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay

                if self._is_training:
                    self._model_optimizer.zero_grad()

                if self._ev_request_preview.is_set() or \
                   self._is_training:
                    # Inference for both preview and training
                    img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
                    shift_uni_mats_pred_t = self._model(img_aligned_shifted_t)#.view( (-1,2,3) )

                if self._is_training:
                    # Training optimization step
                    shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)

                    loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
                    loss_t.backward()
                    self._model_optimizer.step()

                    loss = loss_t.detach().cpu().numpy()

                    rec_loss_history = self._loss_history.get('reconstruct', None)
                    if rec_loss_history is None:
                        rec_loss_history = self._loss_history['reconstruct'] = []
                    rec_loss_history.append(float(loss))

                    self.set_iteration( self.get_iteration() + 1 )

                if self._ev_request_preview.is_set():
                    self._ev_request_preview.clear()
                    # Preview request

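For reference, the optimization step in this hunk is a plain squared-error regression between the predicted and target uniform shift matrices, scaled by 10.0. A self-contained sketch under assumed shapes (a batch of 2x3 affine matrices); the placeholder model, optimizer choice, and tensor sizes are illustrative assumptions, not the project's actual FaceAligner setup:

    import torch

    # Placeholder regressor: flattens an image batch into a (N, 2, 3) matrix prediction.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3*64*64, 6))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    img_aligned_shifted_t = torch.randn(8, 3, 64, 64)   # stand-in for the generator batch
    shift_uni_mats_t      = torch.randn(8, 2, 3)        # stand-in for the target shift matrices

    optimizer.zero_grad()
    shift_uni_mats_pred_t = model(img_aligned_shifted_t).view(-1, 2, 3)
    loss_t = (shift_uni_mats_pred_t - shift_uni_mats_t).square().mean() * 10.0
    loss_t.backward()
    optimizer.step()

    print(float(loss_t.detach().cpu()))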
@@ -325,7 +321,7 @@ class FaceAlignerTrainerApp:
                    pd.training_data = training_data
                    pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
                    self._new_preview_data = pd

                if self._is_previewing_samples:
                    self._new_viewing_data = training_data

@@ -348,7 +344,7 @@ class FaceAlignerTrainerApp:

            dc.DlgChoice(short_name='l', row_def=f'| Print loss history | Last loss = {last_loss:.5f} ',
                         on_choose=self.on_main_dlg_print_loss_history ),

            dc.DlgChoice(short_name='p', row_def='| Show current preview.',
                         on_choose=lambda dlg: (self._ev_request_preview.set(), dlg.recreate().set_current())),

@@ -375,8 +371,6 @@ class FaceAlignerTrainerApp:

    def on_main_dlg_quit(self, dlg):
        self._ev_request_quit.set()

    def on_main_dlg_print_loss_history(self, dlg):
        max_lines = 20
        for key in self._loss_history.keys():

@@ -384,14 +378,16 @@ class FaceAlignerTrainerApp:

            print(f'Loss history for: {key}')

            lh_len = len(lh)
            if lh_len >= max_lines:
                d = len(lh) // max_lines

                lh_ar = np.array(lh[-d*max_lines:], np.float32)
                lh_ar = lh_ar.reshape( (max_lines, d))
                lh_ar_max, lh_ar_min, lh_ar_mean, lh_ar_median = lh_ar.max(-1), lh_ar.min(-1), lh_ar.mean(-1), np.median(lh_ar, -1)

                print( '\n'.join( f'max:[{max_value:.5f}] min:[{min_value:.5f}] mean:[{mean_value:.5f}] median:[{median_value:.5f}]' for max_value, min_value, mean_value, median_value in zip(lh_ar_max, lh_ar_min, lh_ar_mean, lh_ar_median) ) )

        dlg.recreate().set_current()

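The change in this hunk adds a guard so the history is only summarized once at least max_lines values exist. The summary itself reshapes the most recent d*max_lines losses into max_lines buckets and prints per-bucket statistics; a self-contained sketch of that idea, using synthetic loss values:

    import numpy as np

    max_lines = 20
    lh = list(np.random.rand(137))            # synthetic loss history

    lh_len = len(lh)
    if lh_len >= max_lines:                   # the guard added in this commit
        d = lh_len // max_lines               # values folded into each printed row
        lh_ar = np.array(lh[-d*max_lines:], np.float32).reshape(max_lines, d)
        for row in lh_ar:
            print(f'max:[{row.max():.5f}] min:[{row.min():.5f}] '
                  f'mean:[{row.mean():.5f}] median:[{np.median(row):.5f}]')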
@@ -401,7 +397,7 @@ class FaceAlignerTrainerApp:

            dc.DlgChoice(short_name='v', row_def=f'| Previewing samples | {self._is_previewing_samples}',
                         on_choose=self.on_sample_generator_dlg_previewing_last_samples,
                         ),

            dc.DlgChoice(short_name='rw', row_def=f'| Random warp | {self.get_random_warp()}',
                         on_choose=lambda dlg: (self.set_random_warp(not self.get_random_warp()), dlg.recreate().set_current()) ),

@@ -469,4 +465,4 @@ class DlgTorchDevicesInfo(dc.DlgChoices):

class PreviewData:
    training_data : Data = None
    shift_uni_mats_pred = None