This commit is contained in:
iperov 2021-11-13 18:15:54 +04:00
commit 97991ca5cf
2 changed files with 8 additions and 13 deletions

View file

@ -1,5 +1,4 @@
import torch
import numpy as np
class AdaBelief(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
@ -39,21 +38,15 @@ class AdaBelief(torch.optim.Optimizer):
grad.add_(p.data, alpha=group['weight_decay'])
state['step'] += 1
m_t, v_t = state['m_t'], state['v_t']
m_t.mul_(beta1).add_( grad , alpha=1 - beta1)
v_t.mul_(beta2).add_( (grad - m_t)**2 , alpha=1 - beta2)
v_diff = (-group['lr'] * m_t).div_( v_t.sqrt().add_(group['eps']) )
if group['lr_dropout'] < 1.0:
lrd_rand = torch.ones_like(p.data)
v_diff *= torch.bernoulli(lrd_rand * group['lr_dropout'] )
# from xlib.console.diacon import Diacon
# Diacon.stop()
# import code
# code.interact(local=dict(globals(), **locals()))
p.data.add_(v_diff)