Mirror of https://github.com/iperov/DeepFaceLive, synced 2025-07-31 12:10:14 -07:00.
Commit 6e84cbe8b6: upd xlib.torch
Parent: fd0ca499bf
3 changed files with 74 additions and 4 deletions
```diff
@@ -25,9 +25,13 @@ class TorchDeviceInfo:
         return self._index
 
     def get_name(self) -> str:
         if self.is_cpu():
             return 'CPU'
         return self._name
 
+    def get_total_memory(self) -> int:
+        if self.is_cpu():
+            return 0
+        return self._total_memory
 
     def __eq__(self, other):
```
```diff
@@ -49,17 +53,23 @@ class TorchDeviceInfo:
 
 _torch_devices = None
 
-def get_cpu_device_info() -> TorchDeviceInfo: return TorchDeviceInfo(index=-1)
+def get_cpu_device_info() -> TorchDeviceInfo:
+    return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
 
-def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]:
+def get_device_info_by_index( index : int ) -> Union[TorchDeviceInfo, None]:
     """
     index if -1, returns CPU Device info
     """
     if index == -1:
         return get_cpu_device_info()
     for device in get_available_devices_info(include_cpu=False):
         if device.get_index() == index:
             return device
 
     return None
 
 def get_device(device_info : TorchDeviceInfo) -> torch.device:
     """
     get physical torch.device from TorchDeviceInfo
     """
     if device_info.is_cpu():
         return torch.device('cpu')
     return torch.device(f'cuda:{device_info.get_index()}')
```
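For context, a short sketch of how these helpers fit together. The caller below is hypothetical, not part of the commit, and assumes the surrounding module exposes `get_available_devices_info` as referenced above:

```python
# Hypothetical usage of the helpers above: prefer CUDA device 0,
# fall back to the CPU pseudo-device when it is not available.
device_info = get_device_info_by_index(0)
if device_info is None:
    device_info = get_cpu_device_info()

device = get_device(device_info)  # torch.device('cuda:0') or torch.device('cpu')
print(device_info.get_name(), device_info.get_total_memory())
```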
xlib/torch/optim/AdaBelief.py (new file, 59 lines)

@@ -0,0 +1,59 @@
```python
import torch
import numpy as np


class AdaBelief(torch.optim.Optimizer):
    """
    AdaBelief optimizer: like Adam, but the second moment tracks the
    squared deviation of the gradient from its running mean instead of
    the raw squared gradient. This variant applies no bias correction
    and supports optional learning-rate dropout.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                       weight_decay=0, lr_dropout=1.0):
        defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        super(AdaBelief, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdaBelief, self).__setstate__(state)

    def reset(self):
        # Reinitialize per-parameter state: step counter and both EMAs.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['m_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                state['v_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                beta1, beta2 = group['betas']

                # Lazy state initialization on the first step for this parameter.
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['m_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
                    state['v_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)

                # L2 regularization folded into the gradient.
                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                state['step'] += 1

                # m_t: EMA of the gradient; v_t: EMA of the squared deviation
                # of the gradient from m_t (the AdaBelief twist).
                m_t, v_t = state['m_t'], state['v_t']
                m_t.mul_(beta1).add_(grad, alpha=1 - beta1)
                v_t.mul_(beta2).add_((grad - m_t)**2, alpha=1 - beta2)

                # Parameter delta, scaled by the adaptive step size.
                v_diff = (-group['lr'] * m_t).div_(v_t.sqrt().add_(group['eps']))

                # Learning-rate dropout: keep each element of the update
                # with probability lr_dropout, zero the rest.
                if group['lr_dropout'] < 1.0:
                    lrd_rand = torch.ones_like(p.data)
                    v_diff *= torch.bernoulli(lrd_rand * group['lr_dropout'])

                p.data.add_(v_diff)
```
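For reference, the update implemented in `step()` above is the AdaBelief rule (Zhuang et al., 2020), here without the usual bias correction; writing g_t for the (weight-decayed) gradient:

```latex
m_t      = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
s_t      = \beta_2 s_{t-1} + (1 - \beta_2)\, (g_t - m_t)^2
\theta_t = \theta_{t-1} - \mathrm{lr} \cdot \frac{m_t}{\sqrt{s_t} + \epsilon}
```

Compared with Adam, the second moment tracks (g_t - m_t)^2 rather than g_t^2, so step sizes grow when the gradient stays close to its running mean. Note that this implementation increments `state['step']` but never uses it for bias-correction terms.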
xlib/torch/optim/__init__.py (new file, 1 line)

@@ -0,0 +1 @@

```python
from .AdaBelief import AdaBelief
```
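A minimal usage sketch for the exported optimizer, assuming the package root is on the import path; the model, data, and hyperparameters below are placeholders:

```python
import torch
import torch.nn.functional as F

from xlib.torch.optim import AdaBelief

# Toy model and data, only to exercise the optimizer's interface.
model = torch.nn.Linear(8, 1)
opt = AdaBelief(model.parameters(), lr=1e-3, lr_dropout=0.7)

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(100):
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    opt.step()  # unlike stock torch optimizers, this step() takes no closure argument
```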