mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
optimized sample generator
This commit is contained in:
parent
b5c234dac3
commit
21b25038ac
6 changed files with 201 additions and 160 deletions
|
@ -20,14 +20,17 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
||||
self.samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
|
||||
else:
|
||||
self.generators_count = min ( generators_count, len(self.samples) )
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]
|
||||
self.generators_count = generators_count
|
||||
|
||||
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
|
||||
|
||||
if self.debug:
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
|
||||
else:
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -39,8 +42,9 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
generator = self.generators[self.generator_counter % len(self.generators) ]
|
||||
return next(generator)
|
||||
|
||||
def batch_func(self, generator_id):
|
||||
samples = self.samples
|
||||
def batch_func(self, param):
|
||||
generator_id, samples = param
|
||||
|
||||
samples_len = len(samples)
|
||||
if samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
@ -56,10 +60,8 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
shuffle_idxs = []
|
||||
|
||||
while True:
|
||||
|
||||
batches = None
|
||||
for n_batch in range(self.batch_size):
|
||||
|
||||
if len(shuffle_idxs) == 0:
|
||||
shuffle_idxs = samples_idxs.copy()
|
||||
np.random.shuffle (shuffle_idxs)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue