diff --git a/nnlib/device.py b/nnlib/device.py index 144de43..b61fb16 100644 --- a/nnlib/device.py +++ b/nnlib/device.py @@ -282,19 +282,20 @@ def get_plaidML_devices(): plaidML_devices = [] # Using plaidML OpenCL backend to determine system devices and has_nvidia_device try: - os.environ['PLAIDML_EXPERIMENTAL'] = 'false' #this enables work plaidML without run 'plaidml-setup' + os.environ['PLAIDML_EXPERIMENTAL'] = 'true' #this enables work plaidML without run 'plaidml-setup' import plaidml ctx = plaidml.Context() - for d in plaidml.devices(ctx, return_all=True)[0]: - details = json.loads(d.details) - if details['type'] == 'CPU': #skipping opencl-CPU - continue - if 'nvidia' in details['vendor'].lower(): - has_nvidia_device = True - plaidML_devices += [ {'id':d.id, - 'globalMemSize' : int(details['globalMemSize']), - 'description' : d.description.decode() - }] + for devices in plaidml.devices(ctx, return_all=True): + for d in devices: + details = json.loads(d.details) + if 'type' not in details or details['type'] == 'CPU': #skipping opencl-CPU + continue + if 'vendor' in details and 'nvidia' in details['vendor'].lower(): + has_nvidia_device = True + plaidML_devices += [ {'id':d.id, + 'globalMemSize' : int(details['globalMemSize']), + 'description' : d.description.decode() + }] ctx.shutdown() except: pass