Mirror of https://github.com/iperov/DeepFaceLive, synced 2025-08-14 02:37:01 -07:00

Commit 6478551358 (parent d30dd4b46d): _dev parts upd

2 changed files with 37 additions and 16 deletions
File 1 of 2 (class FaceAlignerTrainerApp):

@@ -53,12 +53,16 @@ class FaceAlignerTrainerApp:
        self._device_info = None
        self._batch_size = None
        self._resolution = None
        self._learning_rate = None
        self._random_warp = None
        self._iteration = None
        self._autosave_period = None
        self._is_training = False
        self._loss_history = {}

        self._model = None
        self._model_optimizer = None

        # Generators
        self._training_generator = TrainingDataGenerator(faceset_path)
@@ -90,6 +94,12 @@ class FaceAlignerTrainerApp:
        self._resolution = resolution
        self._training_generator.set_resolution(resolution)

    def get_learning_rate(self) -> float: return self._learning_rate
    def set_learning_rate(self, learning_rate : float):
        self._learning_rate = learning_rate
        if self._model_optimizer is not None:
            self._model_optimizer.set_lr(learning_rate)

    def get_random_warp(self) -> bool: return self._random_warp
    def set_random_warp(self, random_warp : bool):
        self._random_warp = random_warp
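Note: stock torch.optim optimizers expose no set_lr(); this commit adds one to
AdaBelief further down. A minimal self-contained sketch of the same
propagation pattern with a standard optimizer (the model and values here are
hypothetical):

    import torch

    model = torch.nn.Linear(4, 2)                        # stand-in model
    opt = torch.optim.Adam(model.parameters(), lr=5e-5)

    def set_learning_rate(optimizer, lr):
        # update 'lr' in every param group, as AdaBelief.set_lr does below
        for group in optimizer.param_groups:
            group['lr'] = lr

    set_learning_rate(opt, 1e-4)
    print(opt.param_groups[0]['lr'])                     # 0.0001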
@@ -118,6 +128,7 @@ class FaceAlignerTrainerApp:
        self.set_device_info( lib_torch.get_device_info_by_index(model_data.get('device_index', -1)) )
        self.set_batch_size( model_data.get('batch_size', 64) )
        self.set_resolution( model_data.get('resolution', 224) )
        self.set_learning_rate( model_data.get('learning_rate', 5e-5) )
        self.set_random_warp( model_data.get('random_warp', True) )
        self.set_iteration( model_data.get('iteration', 0) )
        self.set_autosave_period( model_data.get('autosave_period', 25) )
@@ -234,7 +245,7 @@ class FaceAlignerTrainerApp:
    def main_loop(self):
        while not self._is_quit:

            if self._ev_request_reset_model.is_set():
                self._ev_request_reset_model.clear()
                self.reset_model(load=False)
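The _ev_request_* flags polled here behave like threading.Event objects set
from another (UI) thread; their construction is outside this diff, so that is
an assumption. A minimal sketch of the same poll-and-clear pattern:

    import threading, time

    ev_request_save = threading.Event()   # hypothetical request flag
    ev_request_quit = threading.Event()

    def main_loop():
        while not ev_request_quit.is_set():
            if ev_request_save.is_set():
                ev_request_save.clear()   # consume the request
                print('Saving...')
            time.sleep(0.005)             # same poll interval as above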
@@ -267,20 +278,20 @@ class FaceAlignerTrainerApp:
                self.export()
                print('Exporting done.')
                dc.Diacon.get_current_dlg().recreate().set_current()

            if self._ev_request_save.is_set():
                self._ev_request_save.clear()
                print('Saving...')
                self.save()
                print('Saving done.')
                dc.Diacon.get_current_dlg().recreate().set_current()

            if self._ev_request_quit.is_set():
                self._ev_request_quit.clear()
                self._is_quit = True

            time.sleep(0.005)

    def main_loop_training(self):
        """
        separated function, because torch tensor references must be freed from python locals
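The docstring's rationale: tensors held in a function's locals are released
when the function returns, so GPU memory does not outlive the step. A minimal
sketch of the idea (model, optimizer, and batch here are hypothetical):

    import torch

    def train_step(model, opt, batch_t):
        # every tensor below is a local; refcounts drop on return,
        # letting CUDA buffers be reclaimed between iterations
        pred_t = model(batch_t)
        loss_t = pred_t.square().mean()
        loss_t.backward()
        opt.step()
        opt.zero_grad()
        return float(loss_t)   # hand back a plain float, not the tensor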
@@ -288,21 +299,21 @@ class FaceAlignerTrainerApp:
        if self._is_training or \
           self._is_previewing_samples or \
           self._ev_request_preview.is_set():

            training_data = self._training_generator.get_next_data(wait=False)
            if training_data is not None and \
               training_data.resolution == self.get_resolution(): # Skip if resolution is different, due to delay
                if self._is_training:
                    self._model_optimizer.zero_grad()

                if self._ev_request_preview.is_set() or \
                   self._is_training:
                    # Inference for both preview and training
                    img_aligned_shifted_t = torch.tensor(training_data.img_aligned_shifted).to(self._device)
                    shift_uni_mats_pred_t = self._model(img_aligned_shifted_t)#.view( (-1,2,3) )

                if self._is_training:
                    # Training optimization step
                    shift_uni_mats_t = torch.tensor(training_data.shift_uni_mats).to(self._device)
                    loss_t = (shift_uni_mats_pred_t-shift_uni_mats_t).square().mean()*10.0
                    loss_t.backward()
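The loss above is plain mean-squared error between predicted and ground-truth
uniform shift matrices, scaled by 10. An equivalent standalone form (the
(8, 2, 3) shape is an assumption taken from the commented-out .view call):

    import torch
    import torch.nn.functional as F

    pred   = torch.randn(8, 2, 3, requires_grad=True)  # dummy predictions
    target = torch.randn(8, 2, 3)                      # dummy ground truth
    loss = F.mse_loss(pred, target) * 10.0   # == (pred - target).square().mean() * 10.0
    loss.backward()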
@@ -313,7 +324,7 @@ class FaceAlignerTrainerApp:
                        rec_loss_history = self._loss_history['reconstruct'] = []
                    rec_loss_history.append(float(loss))
                    self.set_iteration( self.get_iteration() + 1 )

                if self._ev_request_preview.is_set():
                    self._ev_request_preview.clear()
                    # Preview request
@@ -321,7 +332,7 @@ class FaceAlignerTrainerApp:
                    pd.training_data = training_data
                    pd.shift_uni_mats_pred = shift_uni_mats_pred_t.detach().cpu().numpy()
                    self._new_preview_data = pd

                if self._is_previewing_samples:
                    self._new_viewing_data = training_data
@@ -339,6 +350,9 @@ class FaceAlignerTrainerApp:
            dc.DlgChoice(short_name='d', row_def=f'| Device | {self._device_info}',
                         on_choose=lambda dlg: self.get_training_device_dlg(dlg).set_current()),

            dc.DlgChoice(short_name='lr', row_def=f'| Learning rate | {self.get_learning_rate()}',
                         on_choose=lambda dlg: self.get_learning_rate_dlg(parent_dlg=dlg).set_current() ),

            dc.DlgChoice(short_name='i', row_def=f'| Iteration | {self.get_iteration()}',
                         on_choose=lambda dlg: self.get_iteration_dlg(parent_dlg=dlg).set_current() ),
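Each Diacon choice pairs a short key with a handler that opens a child dialog.
A dependency-free sketch of the same menu pattern in plain Python (Diacon is
internal to DeepFaceLive; this is an illustration, not its API):

    choices = {
        'lr': ('Learning rate', lambda: print('open learning-rate dialog')),
        'i':  ('Iteration',     lambda: print('open iteration dialog')),
    }
    key = input('choice> ').strip()
    if key in choices:
        label, handler = choices[key]
        handler()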
@@ -435,6 +449,13 @@ class FaceAlignerTrainerApp:
                            on_back = lambda dlg: parent_dlg.recreate().set_current(),
                            top_rows_def='|c9 Set iteration', )

    def get_learning_rate_dlg(self, parent_dlg):
        return dc.DlgNumber(is_float=True, min_value=0, max_value=0.1,
                            on_value = lambda dlg, value: (self.set_learning_rate(value), parent_dlg.recreate().set_current()),
                            on_recreate = lambda dlg: self.get_learning_rate_dlg(parent_dlg),
                            on_back = lambda dlg: parent_dlg.recreate().set_current(),
                            top_rows_def='|c9 Set learning rate', )

    def get_training_device_dlg(self, parent_dlg):
        return DlgTorchDevicesInfo(on_device_choice = lambda dlg, device_info: (self.set_device_info(device_info), parent_dlg.recreate().set_current()),
                                   on_recreate = lambda dlg: self.get_training_device_dlg(parent_dlg),
@@ -442,7 +463,6 @@ class FaceAlignerTrainerApp:
                                   top_rows_def='|c9 Choose device'
                                   )

class DlgTorchDevicesInfo(dc.DlgChoices):
    def __init__(self, on_device_choice : Callable = None,
                       on_device_multi_choice : Callable = None,
File 2 of 2 (class AdaBelief):

@@ -6,10 +6,7 @@ class AdaBelief(torch.optim.Optimizer):
        defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdaBelief, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdaBelief, self).__setstate__(state)

    def reset(self):
        for group in self.param_groups:
            for p in group['params']:
@@ -17,6 +14,10 @@ class AdaBelief(torch.optim.Optimizer):
                state['step'] = 0
                state['m_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                state['v_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

    def set_lr(self, lr):
        for group in self.param_groups:
            group['lr'] = lr

    def step(self):
        for group in self.param_groups:
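The step() body is truncated by the diff. For reference, the published
AdaBelief update (Zhuang et al., 2020) that the m_t/v_t state above
corresponds to, as a hedged single-parameter sketch (the diff does not show
what this step() actually implements):

    import torch

    def adabelief_update(p, grad, m_t, v_t, step, lr, betas=(0.9, 0.999), eps=1e-16):
        b1, b2 = betas
        m_t.mul_(b1).add_(grad, alpha=1 - b1)            # EMA of gradients
        diff = grad - m_t
        v_t.mul_(b2).addcmul_(diff, diff, value=1 - b2)  # "belief": spread of grad around m_t
        denom = (v_t / (1 - b2 ** step)).sqrt_().add_(eps)
        p.data.addcdiv_(m_t, denom, value=-lr / (1 - b1 ** step))  # step starts at 1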