This commit is contained in:
Colombo 2019-12-22 19:10:57 +04:00
parent 754d6c385c
commit 9ed0111824

View file

@ -37,19 +37,14 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
if self.samples_len == 0:
raise ValueError('No training data provided.')
persons_name_idxs = {}
for i,sample in enumerate(samples):
person_name = sample.person_name
if person_name not in persons_name_idxs:
persons_name_idxs[person_name] = []
persons_name_idxs[person_name].append (i)
indexes2D = [ persons_name_idxs[person_name] for person_name in sorted(list(persons_name_idxs.keys())) ]
unique_person_names = { sample.person_name for sample in samples }
persons_name_idxs = { person_name : [] for person_name in unique_person_names }
for i,sample in enumerate(samples):
persons_name_idxs[sample.person_name].append (i)
indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ]
index2d_host = mp_utils.Index2DHost(indexes2D)
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]