mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-14 02:37:01 -07:00
_dev parts upd
This commit is contained in:
parent
d30dd4b46d
commit
6478551358
2 changed files with 37 additions and 16 deletions
|
@ -6,10 +6,7 @@ class AdaBelief(torch.optim.Optimizer):
|
|||
|
||||
defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super(AdaBelief, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super(AdaBelief, self).__setstate__(state)
|
||||
|
||||
|
||||
def reset(self):
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
|
@ -17,6 +14,10 @@ class AdaBelief(torch.optim.Optimizer):
|
|||
state['step'] = 0
|
||||
state['m_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
|
||||
state['v_t'] = torch.zeros_like(p.data,memory_format=torch.preserve_format)
|
||||
|
||||
def set_lr(self, lr):
|
||||
for group in self.param_groups:
|
||||
group['lr'] = lr
|
||||
|
||||
def step(self):
|
||||
for group in self.param_groups:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue