upd xlib.torch

This commit is contained in:
iperov 2021-11-11 23:09:26 +04:00
parent fd0ca499bf
commit 6e84cbe8b6
3 changed files with 74 additions and 4 deletions

View file

@ -25,9 +25,13 @@ class TorchDeviceInfo:
return self._index
def get_name(self) -> str:
if self.is_cpu():
return 'CPU'
return self._name
def get_total_memory(self) -> int:
if self.is_cpu():
return 0
return self._total_memory
def __eq__(self, other):
@ -49,17 +53,23 @@ class TorchDeviceInfo:
_torch_devices = None
def get_cpu_device_info() -> TorchDeviceInfo:
return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)
def get_cpu_device_info() -> TorchDeviceInfo: return TorchDeviceInfo(index=-1)
def get_device_info_by_index(index) -> Union[TorchDeviceInfo, None]:
def get_device_info_by_index( index : int ) -> Union[TorchDeviceInfo, None]:
"""
index if -1, returns CPU Device info
"""
if index == -1:
return get_cpu_device_info()
for device in get_available_devices_info(include_cpu=False):
if device.get_index() == index:
return device
return None
def get_device(device_info : TorchDeviceInfo) -> torch.device:
"""
get physical torch.device from TorchDeviceInfo
"""
if device_info.is_cpu():
return torch.device('cpu')
return torch.device(f'cuda:{device_info.get_index()}')