diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py index ea6a47f..ea2544f 100644 --- a/mainscripts/Extractor.py +++ b/mainscripts/Extractor.py @@ -62,10 +62,13 @@ class ExtractSubprocessor(SubprocessorBase): def get_devices_for_type (self, type, multi_gpu): if (type == 'rects' or type == 'landmarks'): - if not multi_gpu: - devices = [nnlib.device.getBestDeviceIdx()] - else: + if multi_gpu: devices = nnlib.device.getDevicesWithAtLeastTotalMemoryGB(2) + else: + devices = [nnlib.device.getBestDeviceIdx()] + if devices[0] == -1: + devices = [] + devices = [ (idx, nnlib.device.getDeviceName(idx), nnlib.device.getDeviceVRAMTotalGb(idx) ) for idx in devices] elif type == 'final': @@ -82,32 +85,35 @@ class ExtractSubprocessor(SubprocessorBase): 'output_dir': str(self.output_path), 'detector': self.detector} - if self.cpu_only: - num_processes = 1 - if not self.manual and self.type == 'rects' and self.detector == 'mt': - num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) ) - - for i in range(0, num_processes ): - client_dict = base_dict.copy() - client_dict['device_idx'] = 0 - client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i), - client_dict['device_type'] = 'CPU' - - yield client_dict['device_name'], {}, client_dict + if not self.cpu_only: + devices = self.get_devices_for_type(self.type, self.multi_gpu) + if len(devices) != 0: + for (device_idx, device_name, device_total_vram_gb) in devices: + num_processes = 1 + if not self.manual and self.type == 'rects' and self.detector == 'mt': + num_processes = int ( max (1, device_total_vram_gb / 2) ) + + for i in range(0, num_processes ): + client_dict = base_dict.copy() + client_dict['device_idx'] = device_idx + client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i) + client_dict['device_type'] = 'GPU' + + yield client_dict['device_name'], {}, client_dict + return + print ("No capable GPU's found, falling back to CPU mode.") + + num_processes = 1 + if not self.manual and self.type == 'rects' and self.detector == 'mt': + num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) ) - else: - for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu): - num_processes = 1 - if not self.manual and self.type == 'rects' and self.detector == 'mt': - num_processes = int ( max (1, device_total_vram_gb / 2) ) - - for i in range(0, num_processes ): - client_dict = base_dict.copy() - client_dict['device_idx'] = device_idx - client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i) - client_dict['device_type'] = 'GPU' - - yield client_dict['device_name'], {}, client_dict + for i in range(0, num_processes ): + client_dict = base_dict.copy() + client_dict['device_idx'] = 0 + client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i), + client_dict['device_type'] = 'CPU' + + yield client_dict['device_name'], {}, client_dict #override def get_no_process_started_message(self):