diff --git a/xlib/torch/device.py b/xlib/torch/device.py
index c94cfac..5a12437 100644
--- a/xlib/torch/device.py
+++ b/xlib/torch/device.py
@@ -25,9 +25,13 @@ class TorchDeviceInfo:
         return self._index
 
     def get_name(self) -> str:
+        if self.is_cpu():
+            return 'CPU'
         return self._name
 
     def get_total_memory(self) -> int:
+        if self.is_cpu():
+            return 0
         return self._total_memory
 
     def __eq__(self, other):
@@ -49,17 +53,23 @@ class TorchDeviceInfo:
 
 
 _torch_devices = None
 
-def get_cpu_device_info() -> TorchDeviceInfo:
-    return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
+def get_cpu_device_info() -> TorchDeviceInfo: return TorchDeviceInfo(index=-1)
 
-def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]:
+def get_device_info_by_index( index : int ) -> Union[TorchDeviceInfo, None]:
+    """
+    if index == -1, returns CPU device info
+    """
+    if index == -1:
+        return get_cpu_device_info()
     for device in get_available_devices_info(include_cpu=False):
         if device.get_index() == index:
             return device
-    return None
 
 def get_device(device_info : TorchDeviceInfo) -> torch.device:
+    """
+    get physical torch.device from TorchDeviceInfo
+    """
     if device_info.is_cpu():
         return torch.device('cpu')
     return torch.device(f'cuda:{device_info.get_index()}')
diff --git a/xlib/torch/optim/AdaBelief.py b/xlib/torch/optim/AdaBelief.py
new file mode 100644
index 0000000..aedc91f
--- /dev/null
+++ b/xlib/torch/optim/AdaBelief.py
@@ -0,0 +1,52 @@
+import torch
+
+class AdaBelief(torch.optim.Optimizer):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
+                 weight_decay=0, lr_dropout=1.0):
+
+        defaults = dict(lr=lr, lr_dropout=lr_dropout, betas=betas, eps=eps, weight_decay=weight_decay)
+        super(AdaBelief, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdaBelief, self).__setstate__(state)
+
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                state['step'] = 0
+                state['m_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                state['v_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+    def step(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+
+                beta1, beta2 = group['betas']
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['m_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    state['v_t'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+
+                if group['weight_decay'] != 0:
+                    grad.add_(p.data, alpha=group['weight_decay'])
+
+                state['step'] += 1
+
+                m_t, v_t = state['m_t'], state['v_t']
+                m_t.mul_(beta1).add_(grad, alpha=1 - beta1)
+                v_t.mul_(beta2).add_((grad - m_t)**2, alpha=1 - beta2)
+
+                v_diff = (-group['lr'] * m_t).div_(v_t.sqrt().add_(group['eps']))
+
+                if group['lr_dropout'] < 1.0:
+                    lrd_rand = torch.ones_like(p.data)
+                    v_diff *= torch.bernoulli(lrd_rand * group['lr_dropout'])
+
+                p.data.add_(v_diff)
diff --git a/xlib/torch/optim/__init__.py b/xlib/torch/optim/__init__.py
new file mode 100644
index 0000000..3c4058d
--- /dev/null
+++ b/xlib/torch/optim/__init__.py
@@ -0,0 +1 @@
+from .AdaBelief import AdaBelief
\ No newline at end of file
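
A quick sketch of how the reworked device helpers behave after this change: passing -1 to get_device_info_by_index now resolves to the CPU device info instead of falling through to None, and get_name()/get_total_memory() special-case the CPU entry. The import path below is an assumption that mirrors the file location (the package may re-export these elsewhere):

# Hypothetical usage; import path assumed from xlib/torch/device.py.
from xlib.torch.device import get_device_info_by_index, get_device

cpu_info = get_device_info_by_index(-1)   # -1 now resolves to CPU device info
print(cpu_info.get_name())                # 'CPU', via the new is_cpu() branch
print(get_device(cpu_info))               # torch.device('cpu')

gpu_info = get_device_info_by_index(0)    # first CUDA device, or None if absent
if gpu_info is not None:
    print(get_device(gpu_info))           # torch.device('cuda:0')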
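
For context, step() implements the AdaBelief update without bias correction: m_t = beta1*m_(t-1) + (1-beta1)*g, v_t = beta2*v_(t-1) + (1-beta2)*(g - m_t)^2, then theta <- theta - lr * m_t / (sqrt(v_t) + eps). When lr_dropout < 1.0, each element of the update is kept with probability lr_dropout via a Bernoulli mask, so the default of 1.0 skips that branch entirely. A minimal usage sketch, assuming the package is importable as xlib.torch.optim (per the new __init__.py); the model, data, and hyperparameter values are placeholders:

# Minimal training-step sketch; model and data are placeholders.
import torch
import torch.nn.functional as F
from xlib.torch.optim import AdaBelief

model = torch.nn.Linear(8, 1)
opt = AdaBelief(model.parameters(), lr=1e-3, lr_dropout=0.7)

x, y = torch.randn(32, 8), torch.randn(32, 1)

opt.zero_grad()
loss = F.mse_loss(model(x), y)
loss.backward()
opt.step()   # in-place AdaBelief update; ~70% of update elements kept per step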