From 111bced890cad1781f38f0c09e4c94600a9a6cbf Mon Sep 17 00:00:00 2001 From: plucky Date: Sat, 29 Dec 2018 23:57:58 +0800 Subject: [PATCH] fix devicelib error line:61 --- nnlib/devicelib.py | 74 +++++++++++++++++++++++----------------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/nnlib/devicelib.py b/nnlib/devicelib.py index 81f0a7f..51811d8 100644 --- a/nnlib/devicelib.py +++ b/nnlib/devicelib.py @@ -1,7 +1,7 @@ from .pynvml import * class devicelib: - class Config(): + class Config(): force_best_gpu_idx = -1 multi_gpu = False force_gpu_idxs = None @@ -11,10 +11,10 @@ class devicelib: allow_growth = True float16 = False cpu_only = False - - def __init__ (self, force_best_gpu_idx = -1, - multi_gpu = False, - force_gpu_idxs = None, + + def __init__ (self, force_best_gpu_idx = -1, + multi_gpu = False, + force_gpu_idxs = None, choose_worst_gpu = False, allow_growth = True, float16 = False, @@ -28,15 +28,15 @@ class devicelib: self.force_best_gpu_idx = force_best_gpu_idx self.multi_gpu = multi_gpu self.force_gpu_idxs = force_gpu_idxs - self.choose_worst_gpu = choose_worst_gpu + self.choose_worst_gpu = choose_worst_gpu self.allow_growth = allow_growth - + self.gpu_idxs = [] if force_gpu_idxs is not None: for idx in force_gpu_idxs.split(','): idx = int(idx) if devicelib.isValidDeviceIdx(idx): - self.gpu_idxs.append(idx) + self.gpu_idxs.append(idx) else: gpu_idx = force_best_gpu_idx if (force_best_gpu_idx >= 0 and devicelib.isValidDeviceIdx(force_best_gpu_idx)) else devicelib.getBestDeviceIdx() if not choose_worst_gpu else devicelib.getWorstDeviceIdx() if gpu_idx != -1: @@ -46,22 +46,22 @@ class devicelib: self.multi_gpu = False else: self.gpu_idxs = [gpu_idx] - + if len(self.gpu_idxs) == 0: self.cpu_only = True else: self.cpu_only = False self.gpu_total_vram_gb = devicelib.getDeviceVRAMTotalGb ( self.gpu_idxs[0] ) - + @staticmethod def hasNVML(): try: nvmlInit() nvmlShutdown() - except e: + except: return False - return True - + return True + @staticmethod def getDevicesWithAtLeastFreeMemory(freememsize): result = [] @@ -71,12 +71,12 @@ class devicelib: handle = nvmlDeviceGetHandleByIndex(i) memInfo = nvmlDeviceGetMemoryInfo( handle ) if (memInfo.total - memInfo.used) >= freememsize: - result.append (i) + result.append (i) nvmlShutdown() except: pass return result - + @staticmethod def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb): result = [] @@ -86,43 +86,43 @@ class devicelib: handle = nvmlDeviceGetHandleByIndex(i) memInfo = nvmlDeviceGetMemoryInfo( handle ) if (memInfo.total) >= totalmemsize_gb*1024*1024*1024: - result.append (i) + result.append (i) nvmlShutdown() except: pass return result - + @staticmethod def getAllDevicesIdxsList (): result = [] try: - nvmlInit() - result = [ i for i in range(0, nvmlDeviceGetCount() ) ] + nvmlInit() + result = [ i for i in range(0, nvmlDeviceGetCount() ) ] nvmlShutdown() except: pass return result - + @staticmethod def getDeviceVRAMFree (idx): result = 0 try: nvmlInit() - if idx < nvmlDeviceGetCount(): + if idx < nvmlDeviceGetCount(): handle = nvmlDeviceGetHandleByIndex(idx) memInfo = nvmlDeviceGetMemoryInfo( handle ) - result = (memInfo.total - memInfo.used) + result = (memInfo.total - memInfo.used) nvmlShutdown() except: pass return result - + @staticmethod def getDeviceVRAMTotalGb (idx): result = 0 try: nvmlInit() - if idx < nvmlDeviceGetCount(): + if idx < nvmlDeviceGetCount(): handle = nvmlDeviceGetHandleByIndex(idx) memInfo = nvmlDeviceGetMemoryInfo( handle ) result = memInfo.total / (1024*1024*1024) @@ -131,7 +131,7 @@ class devicelib: except: pass return result - + @staticmethod def getBestDeviceIdx(): idx = -1 @@ -149,13 +149,13 @@ class devicelib: except: pass return idx - + @staticmethod def getWorstDeviceIdx(): idx = -1 try: - nvmlInit() - + nvmlInit() + idx_mem = sys.maxsize for i in range(0, nvmlDeviceGetCount() ): handle = nvmlDeviceGetHandleByIndex(i) @@ -168,42 +168,42 @@ class devicelib: except: pass return idx - + @staticmethod def isValidDeviceIdx(idx): result = False try: - nvmlInit() + nvmlInit() result = (idx < nvmlDeviceGetCount()) nvmlShutdown() except: pass return result - + @staticmethod def getDeviceIdxsEqualModel(idx): result = [] try: - nvmlInit() + nvmlInit() idx_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode() for i in range(0, nvmlDeviceGetCount() ): if nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)).decode() == idx_name: result.append (i) - + nvmlShutdown() except: pass return result - + @staticmethod def getDeviceName (idx): result = '' try: - nvmlInit() - if idx < nvmlDeviceGetCount(): + nvmlInit() + if idx < nvmlDeviceGetCount(): result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode() nvmlShutdown() except: pass - return result \ No newline at end of file + return result