fix multigpu extractor

This commit is contained in:
iperov 2019-01-13 13:59:22 +04:00
parent 91bb5221d4
commit 4625bcec1c

View file

@ -62,10 +62,13 @@ class ExtractSubprocessor(SubprocessorBase):
def get_devices_for_type (self, type, multi_gpu):
if (type == 'rects' or type == 'landmarks'):
if not multi_gpu:
devices = [nnlib.device.getBestDeviceIdx()]
else:
if multi_gpu:
devices = nnlib.device.getDevicesWithAtLeastTotalMemoryGB(2)
else:
devices = [nnlib.device.getBestDeviceIdx()]
if devices[0] == -1:
devices = []
devices = [ (idx, nnlib.device.getDeviceName(idx), nnlib.device.getDeviceVRAMTotalGb(idx) ) for idx in devices]
elif type == 'final':
@ -82,21 +85,10 @@ class ExtractSubprocessor(SubprocessorBase):
'output_dir': str(self.output_path),
'detector': self.detector}
if self.cpu_only:
num_processes = 1
if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
for i in range(0, num_processes ):
client_dict = base_dict.copy()
client_dict['device_idx'] = 0
client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i),
client_dict['device_type'] = 'CPU'
yield client_dict['device_name'], {}, client_dict
else:
for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu):
if not self.cpu_only:
devices = self.get_devices_for_type(self.type, self.multi_gpu)
if len(devices) != 0:
for (device_idx, device_name, device_total_vram_gb) in devices:
num_processes = 1
if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = int ( max (1, device_total_vram_gb / 2) )
@ -107,6 +99,20 @@ class ExtractSubprocessor(SubprocessorBase):
client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i)
client_dict['device_type'] = 'GPU'
yield client_dict['device_name'], {}, client_dict
return
print ("No capable GPU's found, falling back to CPU mode.")
num_processes = 1
if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
for i in range(0, num_processes ):
client_dict = base_dict.copy()
client_dict['device_idx'] = 0
client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i),
client_dict['device_type'] = 'CPU'
yield client_dict['device_name'], {}, client_dict
#override