mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
optimized sample generator
This commit is contained in:
parent
b5c234dac3
commit
21b25038ac
6 changed files with 201 additions and 160 deletions
|
@ -7,8 +7,8 @@ import numpy as np
|
|||
from facelib import LandmarksProcessor
|
||||
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
|
||||
SampleType)
|
||||
from utils import iter_utils
|
||||
from utils import mp_utils
|
||||
from utils import iter_utils, mp_utils
|
||||
|
||||
|
||||
'''
|
||||
arg
|
||||
|
@ -30,8 +30,13 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
||||
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
|
||||
self.samples_len = len(samples_host.get_list())
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
else:
|
||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6)
|
||||
|
||||
samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count)
|
||||
self.samples_len = len(samples_clis[0])
|
||||
|
||||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
@ -39,18 +44,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
index_host = mp_utils.IndexHost(self.samples_len)
|
||||
|
||||
if random_ct_samples_path is not None:
|
||||
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
|
||||
ct_index_host = mp_utils.IndexHost( len(ct_samples_host.get_list()) )
|
||||
ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count)
|
||||
ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) )
|
||||
else:
|
||||
ct_samples_host = None
|
||||
ct_samples_clis = None
|
||||
ct_index_host = None
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
else:
|
||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -72,13 +75,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
while True:
|
||||
batches = None
|
||||
|
||||
indexes = index_host.get(bs)
|
||||
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
|
||||
indexes = index_host.multi_get(bs)
|
||||
ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
|
||||
|
||||
batch_samples = samples.multi_get (indexes)
|
||||
batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None
|
||||
|
||||
for n_batch in range(bs):
|
||||
sample_idx = indexes[n_batch]
|
||||
sample = samples[ sample_idx ]
|
||||
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
|
||||
sample = batch_samples[n_batch]
|
||||
ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
|
||||
|
||||
try:
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue