fix multigpu extractor

This commit is contained in:
iperov 2019-01-13 13:59:22 +04:00
parent 91bb5221d4
commit 4625bcec1c

View file

@ -62,10 +62,13 @@ class ExtractSubprocessor(SubprocessorBase):
def get_devices_for_type (self, type, multi_gpu): def get_devices_for_type (self, type, multi_gpu):
if (type == 'rects' or type == 'landmarks'): if (type == 'rects' or type == 'landmarks'):
if not multi_gpu: if multi_gpu:
devices = [nnlib.device.getBestDeviceIdx()]
else:
devices = nnlib.device.getDevicesWithAtLeastTotalMemoryGB(2) devices = nnlib.device.getDevicesWithAtLeastTotalMemoryGB(2)
else:
devices = [nnlib.device.getBestDeviceIdx()]
if devices[0] == -1:
devices = []
devices = [ (idx, nnlib.device.getDeviceName(idx), nnlib.device.getDeviceVRAMTotalGb(idx) ) for idx in devices] devices = [ (idx, nnlib.device.getDeviceName(idx), nnlib.device.getDeviceVRAMTotalGb(idx) ) for idx in devices]
elif type == 'final': elif type == 'final':
@ -82,32 +85,35 @@ class ExtractSubprocessor(SubprocessorBase):
'output_dir': str(self.output_path), 'output_dir': str(self.output_path),
'detector': self.detector} 'detector': self.detector}
if self.cpu_only: if not self.cpu_only:
num_processes = 1 devices = self.get_devices_for_type(self.type, self.multi_gpu)
if not self.manual and self.type == 'rects' and self.detector == 'mt': if len(devices) != 0:
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) ) for (device_idx, device_name, device_total_vram_gb) in devices:
num_processes = 1
if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = int ( max (1, device_total_vram_gb / 2) )
for i in range(0, num_processes ): for i in range(0, num_processes ):
client_dict = base_dict.copy() client_dict = base_dict.copy()
client_dict['device_idx'] = 0 client_dict['device_idx'] = device_idx
client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i), client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i)
client_dict['device_type'] = 'CPU' client_dict['device_type'] = 'GPU'
yield client_dict['device_name'], {}, client_dict yield client_dict['device_name'], {}, client_dict
return
print ("No capable GPU's found, falling back to CPU mode.")
else: num_processes = 1
for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu): if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = 1 num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = int ( max (1, device_total_vram_gb / 2) )
for i in range(0, num_processes ): for i in range(0, num_processes ):
client_dict = base_dict.copy() client_dict = base_dict.copy()
client_dict['device_idx'] = device_idx client_dict['device_idx'] = 0
client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i) client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i),
client_dict['device_type'] = 'GPU' client_dict['device_type'] = 'CPU'
yield client_dict['device_name'], {}, client_dict yield client_dict['device_name'], {}, client_dict
#override #override
def get_no_process_started_message(self): def get_no_process_started_message(self):