diff --git a/samplelib/SampleGeneratorFacePerson.py b/samplelib/SampleGeneratorFacePerson.py index c5d53e9..f0a927f 100644 --- a/samplelib/SampleGeneratorFacePerson.py +++ b/samplelib/SampleGeneratorFacePerson.py @@ -37,19 +37,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase): if self.samples_len == 0: raise ValueError('No training data provided.') - - persons_name_idxs = {} - - for i,sample in enumerate(samples): - person_name = sample.person_name - if person_name not in persons_name_idxs: - persons_name_idxs[person_name] = [] - persons_name_idxs[person_name].append (i) - - indexes2D = [ persons_name_idxs[person_name] for person_name in sorted(list(persons_name_idxs.keys())) ] + + unique_person_names = { sample.person_name for sample in samples } + persons_name_idxs = { person_name : [] for person_name in unique_person_names } + for i,sample in enumerate(samples): + persons_name_idxs[sample.person_name].append (i) + indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ] index2d_host = mp_utils.Index2DHost(indexes2D) - if self.debug: self.generators_count = 1 self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]