refactorings

This commit is contained in:
Colombo 2019-12-22 19:00:59 +04:00
parent e0e8970ab9
commit 754d6c385c
13 changed files with 243 additions and 104 deletions

View file

@ -31,7 +31,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.add_sample_idx = add_sample_idx
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
self.samples_len = len(samples_host)
self.samples_len = len(samples_host.get_list())
if self.samples_len == 0:
raise ValueError('No training data provided.')
@ -40,7 +40,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
if random_ct_samples_path is not None:
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
ct_index_host = mp_utils.IndexHost( len(ct_samples_host) )
ct_index_host = mp_utils.IndexHost( len(ct_samples_host.get_list()) )
else:
ct_samples_host = None
ct_index_host = None
@ -76,7 +76,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
for n_batch in range(bs):
sample = samples[ indexes[n_batch] ]
sample_idx = indexes[n_batch]
sample = samples[ sample_idx ]
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
try:
@ -94,9 +95,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
batches[i].append ( x[i] )
if self.add_sample_idx:
batches[i_sample_idx].append (idx)
batches[i_sample_idx].append (sample_idx)
yield [ np.array(batch) for batch in batches]
@staticmethod
def get_person_id_max_count(samples_path):
return SampleHost.get_person_id_max_count(samples_path)