mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-14 00:53:48 -07:00
fix multigpu extractor
This commit is contained in:
parent
91bb5221d4
commit
4625bcec1c
1 changed files with 34 additions and 28 deletions
|
@ -62,10 +62,13 @@ class ExtractSubprocessor(SubprocessorBase):
|
|||
|
||||
def get_devices_for_type (self, type, multi_gpu):
|
||||
if (type == 'rects' or type == 'landmarks'):
|
||||
if not multi_gpu:
|
||||
devices = [nnlib.device.getBestDeviceIdx()]
|
||||
else:
|
||||
if multi_gpu:
|
||||
devices = nnlib.device.getDevicesWithAtLeastTotalMemoryGB(2)
|
||||
else:
|
||||
devices = [nnlib.device.getBestDeviceIdx()]
|
||||
if devices[0] == -1:
|
||||
devices = []
|
||||
|
||||
devices = [ (idx, nnlib.device.getDeviceName(idx), nnlib.device.getDeviceVRAMTotalGb(idx) ) for idx in devices]
|
||||
|
||||
elif type == 'final':
|
||||
|
@ -82,21 +85,10 @@ class ExtractSubprocessor(SubprocessorBase):
|
|||
'output_dir': str(self.output_path),
|
||||
'detector': self.detector}
|
||||
|
||||
if self.cpu_only:
|
||||
num_processes = 1
|
||||
if not self.manual and self.type == 'rects' and self.detector == 'mt':
|
||||
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
|
||||
|
||||
for i in range(0, num_processes ):
|
||||
client_dict = base_dict.copy()
|
||||
client_dict['device_idx'] = 0
|
||||
client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i),
|
||||
client_dict['device_type'] = 'CPU'
|
||||
|
||||
yield client_dict['device_name'], {}, client_dict
|
||||
|
||||
else:
|
||||
for (device_idx, device_name, device_total_vram_gb) in self.get_devices_for_type(self.type, self.multi_gpu):
|
||||
if not self.cpu_only:
|
||||
devices = self.get_devices_for_type(self.type, self.multi_gpu)
|
||||
if len(devices) != 0:
|
||||
for (device_idx, device_name, device_total_vram_gb) in devices:
|
||||
num_processes = 1
|
||||
if not self.manual and self.type == 'rects' and self.detector == 'mt':
|
||||
num_processes = int ( max (1, device_total_vram_gb / 2) )
|
||||
|
@ -107,6 +99,20 @@ class ExtractSubprocessor(SubprocessorBase):
|
|||
client_dict['device_name'] = device_name if num_processes == 1 else '%s #%d' % (device_name,i)
|
||||
client_dict['device_type'] = 'GPU'
|
||||
|
||||
yield client_dict['device_name'], {}, client_dict
|
||||
return
|
||||
print ("No capable GPU's found, falling back to CPU mode.")
|
||||
|
||||
num_processes = 1
|
||||
if not self.manual and self.type == 'rects' and self.detector == 'mt':
|
||||
num_processes = int ( max (1, multiprocessing.cpu_count() / 2 ) )
|
||||
|
||||
for i in range(0, num_processes ):
|
||||
client_dict = base_dict.copy()
|
||||
client_dict['device_idx'] = 0
|
||||
client_dict['device_name'] = 'CPU' if num_processes == 1 else 'CPU #%d' % (i),
|
||||
client_dict['device_type'] = 'CPU'
|
||||
|
||||
yield client_dict['device_name'], {}, client_dict
|
||||
|
||||
#override
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue