optimized sample generator

This commit is contained in:
Colombo 2020-01-05 11:53:31 +04:00
parent b5c234dac3
commit 21b25038ac
6 changed files with 201 additions and 160 deletions

View file

@ -20,14 +20,17 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
else:
self.generators_count = min ( generators_count, len(self.samples) )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]
self.generators_count = generators_count
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -39,8 +42,9 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
def batch_func(self, generator_id):
samples = self.samples
def batch_func(self, param):
generator_id, samples = param
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')
@ -56,10 +60,8 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
shuffle_idxs = []
while True:
batches = None
for n_batch in range(self.batch_size):
if len(shuffle_idxs) == 0:
shuffle_idxs = samples_idxs.copy()
np.random.shuffle (shuffle_idxs)