refactoring

This commit is contained in:
iperov 2018-12-22 19:37:25 +04:00
parent 0d7387165a
commit 44798c2b85

View file

@ -60,10 +60,8 @@ class ExtractSubprocessor(SubprocessorBase):
cv2.setMouseCallback(self.wnd_name, onMouse, self.param)
def get_devices_for_type (self, type, multi_gpu, cpu_only):
if cpu_only:
devices = [ (0, 'CPU', 0 ) ]
elif (type == 'rects' or type == 'landmarks'):
def get_devices_for_type (self, type, multi_gpu):
if (type == 'rects' or type == 'landmarks'):
if not multi_gpu:
devices = [gpufmkmgr.getBestDeviceIdx()]
else:
@ -76,27 +74,41 @@ class ExtractSubprocessor(SubprocessorBase):
return devices
#override
def process_info_generator(self):
for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu, self.cpu_only):
def process_info_generator(self):
base_dict = {'type' : self.type,
'image_size': self.image_size,
'face_type': self.face_type,
'debug': self.debug,
'output_dir': str(self.output_path),
'detector': self.detector}
if self.cpu_only:
num_processes = 1
if not self.manual and self.type == 'rects' and self.detector == 'mt':
if self.cpu_only:
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
else:
num_processes = int ( max (1, device_total_vram_gb / 2) )
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
for i in range(0, num_processes ):
device_name_for_process = device_name if num_processes == 1 else '%s #%d' % (device_name,i)
yield device_name_for_process, {}, {'type' : self.type,
'device_idx' : device_idx,
'device_name' : device_name_for_process,
'device_type' : 'CPU' if self.cpu_only else 'GPU',
'image_size': self.image_size,
'face_type': self.face_type,
'debug': self.debug,
'output_dir': str(self.output_path),
'detector': self.detector}
client_dict = base_dict.copy()
client_dict['device_idx'] = 0
client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i),
client_dict['device_type'] = 'CPU'
yield client_dict['device_name'], {}, client_dict
else:
for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu):
num_processes = 1
if not self.manual and self.type == 'rects' and self.detector == 'mt':
num_processes = int ( max (1, device_total_vram_gb / 2) )
for i in range(0, num_processes ):
client_dict = base_dict.copy()
client_dict['device_idx'] = device_idx
client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i)
client_dict['device_type'] = 'GPU'
yield client_dict['device_name'], {}, client_dict
#override
def get_no_process_started_message(self):
if (self.type == 'rects' or self.type == 'landmarks'):