This commit is contained in:
Colombo 2019-12-22 19:10:57 +04:00
parent 754d6c385c
commit 9ed0111824

View file

@ -38,18 +38,13 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
if self.samples_len == 0: if self.samples_len == 0:
raise ValueError('No training data provided.') raise ValueError('No training data provided.')
persons_name_idxs = {} unique_person_names = { sample.person_name for sample in samples }
persons_name_idxs = { person_name : [] for person_name in unique_person_names }
for i,sample in enumerate(samples): for i,sample in enumerate(samples):
person_name = sample.person_name persons_name_idxs[sample.person_name].append (i)
if person_name not in persons_name_idxs: indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
persons_name_idxs[person_name] = []
persons_name_idxs[person_name].append (i)
indexes2D = [ persons_name_idxs[person_name] for person_name in sorted(list(persons_name_idxs.keys())) ]
index2d_host = mp_utils.Index2DHost(indexes2D) index2d_host = mp_utils.Index2DHost(indexes2D)
if self.debug: if self.debug:
self.generators_count = 1 self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )] self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]