from typing import List

class TorchDeviceInfo:
    """
    Represents picklable torch device info
    """
    def __init__(self, index=None, name=None, total_memory=None):
        self._index : int = index
        self._name : str = name
        self._total_memory : int = total_memory

    def __getstate__(self):
        return self.__dict__.copy()

    def __setstate__(self, d):
        self.__init__()
        self.__dict__.update(d)

    def is_cpu(self) -> bool: return self._index == -1

    def get_index(self) -> int:
        return self._index

    def get_name(self) -> str:
        return self._name

    def get_total_memory(self) -> int:
        return self._total_memory

    def __eq__(self, other):
        if self is not None and other is not None and isinstance(self, TorchDeviceInfo) and isinstance(other, TorchDeviceInfo):
            return self._index == other._index
        return False

    def __hash__(self):
        return self._index

    def __str__(self):
        if self.is_cpu():
            return "CPU"
        else:
            return f"[{self._index}] {self._name} [{(self._total_memory / 1024**3) :.3}Gb]"

    def __repr__(self):
        return f'{self.__class__.__name__} object: ' + self.__str__()
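
# A minimal pickling sketch (illustrative, not part of the original module); the
# __getstate__/__setstate__ pair above is what makes device descriptions safe to
# send across process boundaries. The device values below are hypothetical.
#
#   import pickle
#   dev = TorchDeviceInfo(index=0, name='ExampleGPU', total_memory=8 * 1024**3)
#   restored = pickle.loads(pickle.dumps(dev))
#   assert restored == dev   # __eq__ compares devices by index only
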
# class TorchDevicesInfo:
#     """
#     picklable list of TorchDeviceInfo
#     """
#     def __init__(self, devices : List[TorchDeviceInfo] = None):
#         if devices is None:
#             devices = []
#         self._devices = devices
#
#     def __getstate__(self):
#         return self.__dict__.copy()
#
#     def __setstate__(self, d):
#         self.__init__()
#         self.__dict__.update(d)
#
#     def add(self, device_or_devices : TorchDeviceInfo):
#         if isinstance(device_or_devices, TorchDeviceInfo):
#             if device_or_devices not in self._devices:
#                 self._devices.append(device_or_devices)
#         elif isinstance(device_or_devices, TorchDevicesInfo):
#             for device in device_or_devices:
#                 self.add(device)
#
#     def copy(self):
#         return copy.deepcopy(self)
#
#     def get_count(self): return len(self._devices)
#
#     def get_largest_total_memory_device(self) -> TorchDeviceInfo:
#         raise NotImplementedError()
#         result = None
#         idx_mem = 0
#         for device in self._devices:
#             mem = device.get_total_memory()
#             if result is None or (mem is not None and mem > idx_mem):
#                 result = device
#                 idx_mem = mem
#         return result
#
#     def get_smallest_total_memory_device(self) -> TorchDeviceInfo:
#         raise NotImplementedError()
#         result = None
#         idx_mem = sys.maxsize
#         for device in self._devices:
#             mem = device.get_total_memory()
#             if result is None or (mem is not None and mem < idx_mem):
#                 result = device
#                 idx_mem = mem
#         return result
#
#     def __len__(self):
#         return len(self._devices)
#
#     def __getitem__(self, key):
#         result = self._devices[key]
#         if isinstance(key, slice):
#             return self.__class__(result)
#         return result
#
#     def __iter__(self):
#         for device in self._devices:
#             yield device
#
#     def __str__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__str__() for device in self._devices ]) + ']'
#     def __repr__(self): return f'{self.__class__.__name__}:[' + ', '.join([ device.__repr__() for device in self._devices ]) + ']'

_torch_devices = None   # cache for get_available_devices()


def get_cpu_device() -> TorchDeviceInfo:
    # The CPU is represented by the sentinel index -1 (see TorchDeviceInfo.is_cpu).
    return TorchDeviceInfo(index=-1, name='CPU', total_memory=0)


def get_available_devices(include_cpu=True, cpu_only=False) -> List[TorchDeviceInfo]:
    """
    returns a list of available TorchDeviceInfo
    """
    global _torch_devices
    if _torch_devices is None:
        import torch
        devices = []

        if not cpu_only:
            # Enumerate CUDA devices visible to torch.
            for i in range (torch.cuda.device_count()):
                device_props = torch.cuda.get_device_properties(i)
                devices.append ( TorchDeviceInfo(index=i, name=device_props.name, total_memory=device_props.total_memory))

        if include_cpu or cpu_only:
            devices.append ( get_cpu_device() )

        # The result is cached; subsequent calls return the same list and ignore the arguments.
        _torch_devices = devices
    return _torch_devices
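
# Usage sketch (illustrative, not part of the original module):
#
#   devices = get_available_devices()                 # CUDA devices first, then the CPU entry
#   for dev in devices:
#       print(dev)                                    # e.g. "[0] <gpu name> [<memory>Gb]" or "CPU"
#
#   # Prefer the CUDA device with the most memory, falling back to the CPU if none is available.
#   gpus = [d for d in devices if not d.is_cpu()]
#   chosen = max(gpus, key=lambda d: d.get_total_memory()) if gpus else get_cpu_device()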