fix devicelib error line:61

This commit is contained in:
plucky 2018-12-29 23:57:58 +08:00
commit 111bced890

View file

@ -1,7 +1,7 @@
from .pynvml import * from .pynvml import *
class devicelib: class devicelib:
class Config(): class Config():
force_best_gpu_idx = -1 force_best_gpu_idx = -1
multi_gpu = False multi_gpu = False
force_gpu_idxs = None force_gpu_idxs = None
@ -11,10 +11,10 @@ class devicelib:
allow_growth = True allow_growth = True
float16 = False float16 = False
cpu_only = False cpu_only = False
def __init__ (self, force_best_gpu_idx = -1, def __init__ (self, force_best_gpu_idx = -1,
multi_gpu = False, multi_gpu = False,
force_gpu_idxs = None, force_gpu_idxs = None,
choose_worst_gpu = False, choose_worst_gpu = False,
allow_growth = True, allow_growth = True,
float16 = False, float16 = False,
@ -28,15 +28,15 @@ class devicelib:
self.force_best_gpu_idx = force_best_gpu_idx self.force_best_gpu_idx = force_best_gpu_idx
self.multi_gpu = multi_gpu self.multi_gpu = multi_gpu
self.force_gpu_idxs = force_gpu_idxs self.force_gpu_idxs = force_gpu_idxs
self.choose_worst_gpu = choose_worst_gpu self.choose_worst_gpu = choose_worst_gpu
self.allow_growth = allow_growth self.allow_growth = allow_growth
self.gpu_idxs = [] self.gpu_idxs = []
if force_gpu_idxs is not None: if force_gpu_idxs is not None:
for idx in force_gpu_idxs.split(','): for idx in force_gpu_idxs.split(','):
idx = int(idx) idx = int(idx)
if devicelib.isValidDeviceIdx(idx): if devicelib.isValidDeviceIdx(idx):
self.gpu_idxs.append(idx) self.gpu_idxs.append(idx)
else: else:
gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and devicelib.isValidDeviceIdx(force_best_gpu_idx)) else devicelib.getBestDeviceIdx() if not choose_worst_gpu else devicelib.getWorstDeviceIdx() gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and devicelib.isValidDeviceIdx(force_best_gpu_idx)) else devicelib.getBestDeviceIdx() if not choose_worst_gpu else devicelib.getWorstDeviceIdx()
if gpu_idx != -1: if gpu_idx != -1:
@ -46,22 +46,22 @@ class devicelib:
self.multi_gpu = False self.multi_gpu = False
else: else:
self.gpu_idxs = [gpu_idx] self.gpu_idxs = [gpu_idx]
if len(self.gpu_idxs) == 0: if len(self.gpu_idxs) == 0:
self.cpu_only = True self.cpu_only = True
else: else:
self.cpu_only = False self.cpu_only = False
self.gpu_total_vram_gb = devicelib.getDeviceVRAMTotalGb ( self.gpu_idxs[0] ) self.gpu_total_vram_gb = devicelib.getDeviceVRAMTotalGb ( self.gpu_idxs[0] )
@staticmethod @staticmethod
def hasNVML(): def hasNVML():
try: try:
nvmlInit() nvmlInit()
nvmlShutdown() nvmlShutdown()
except e: except:
return False return False
return True return True
@staticmethod @staticmethod
def getDevicesWithAtLeastFreeMemory(freememsize): def getDevicesWithAtLeastFreeMemory(freememsize):
result = [] result = []
@ -71,12 +71,12 @@ class devicelib:
handle = nvmlDeviceGetHandleByIndex(i) handle = nvmlDeviceGetHandleByIndex(i)
memInfo = nvmlDeviceGetMemoryInfo( handle ) memInfo = nvmlDeviceGetMemoryInfo( handle )
if (memInfo.total - memInfo.used) >= freememsize: if (memInfo.total - memInfo.used) >= freememsize:
result.append (i) result.append (i)
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb): def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb):
result = [] result = []
@ -86,43 +86,43 @@ class devicelib:
handle = nvmlDeviceGetHandleByIndex(i) handle = nvmlDeviceGetHandleByIndex(i)
memInfo = nvmlDeviceGetMemoryInfo( handle ) memInfo = nvmlDeviceGetMemoryInfo( handle )
if (memInfo.total) >= totalmemsize_gb*1024*1024*1024: if (memInfo.total) >= totalmemsize_gb*1024*1024*1024:
result.append (i) result.append (i)
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getAllDevicesIdxsList (): def getAllDevicesIdxsList ():
result = [] result = []
try: try:
nvmlInit() nvmlInit()
result = [ i for i in range(0, nvmlDeviceGetCount() ) ] result = [ i for i in range(0, nvmlDeviceGetCount() ) ]
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getDeviceVRAMFree (idx): def getDeviceVRAMFree (idx):
result = 0 result = 0
try: try:
nvmlInit() nvmlInit()
if idx < nvmlDeviceGetCount(): if idx < nvmlDeviceGetCount():
handle = nvmlDeviceGetHandleByIndex(idx) handle = nvmlDeviceGetHandleByIndex(idx)
memInfo = nvmlDeviceGetMemoryInfo( handle ) memInfo = nvmlDeviceGetMemoryInfo( handle )
result = (memInfo.total - memInfo.used) result = (memInfo.total - memInfo.used)
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getDeviceVRAMTotalGb (idx): def getDeviceVRAMTotalGb (idx):
result = 0 result = 0
try: try:
nvmlInit() nvmlInit()
if idx < nvmlDeviceGetCount(): if idx < nvmlDeviceGetCount():
handle = nvmlDeviceGetHandleByIndex(idx) handle = nvmlDeviceGetHandleByIndex(idx)
memInfo = nvmlDeviceGetMemoryInfo( handle ) memInfo = nvmlDeviceGetMemoryInfo( handle )
result = memInfo.total / (1024*1024*1024) result = memInfo.total / (1024*1024*1024)
@ -131,7 +131,7 @@ class devicelib:
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getBestDeviceIdx(): def getBestDeviceIdx():
idx = -1 idx = -1
@ -149,13 +149,13 @@ class devicelib:
except: except:
pass pass
return idx return idx
@staticmethod @staticmethod
def getWorstDeviceIdx(): def getWorstDeviceIdx():
idx = -1 idx = -1
try: try:
nvmlInit() nvmlInit()
idx_mem = sys.maxsize idx_mem = sys.maxsize
for i in range(0, nvmlDeviceGetCount() ): for i in range(0, nvmlDeviceGetCount() ):
handle = nvmlDeviceGetHandleByIndex(i) handle = nvmlDeviceGetHandleByIndex(i)
@ -168,42 +168,42 @@ class devicelib:
except: except:
pass pass
return idx return idx
@staticmethod @staticmethod
def isValidDeviceIdx(idx): def isValidDeviceIdx(idx):
result = False result = False
try: try:
nvmlInit() nvmlInit()
result = (idx < nvmlDeviceGetCount()) result = (idx < nvmlDeviceGetCount())
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getDeviceIdxsEqualModel(idx): def getDeviceIdxsEqualModel(idx):
result = [] result = []
try: try:
nvmlInit() nvmlInit()
idx_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode() idx_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
for i in range(0, nvmlDeviceGetCount() ): for i in range(0, nvmlDeviceGetCount() ):
if nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)).decode() == idx_name: if nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)).decode() == idx_name:
result.append (i) result.append (i)
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result
@staticmethod @staticmethod
def getDeviceName (idx): def getDeviceName (idx):
result = '' result = ''
try: try:
nvmlInit() nvmlInit()
if idx < nvmlDeviceGetCount(): if idx < nvmlDeviceGetCount():
result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode() result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
nvmlShutdown() nvmlShutdown()
except: except:
pass pass
return result return result