This commit is contained in:
Colombo 2019-11-24 19:51:07 +04:00
parent 1bfd65abe5
commit 77b390c04b
4 changed files with 150 additions and 25 deletions

View file

@ -24,8 +24,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
random_ct_samples_path=None,
sample_process_options=SampleProcessor.Options(),
output_sample_types=[],
person_id_mode=False,
add_sample_idx=False,
use_caching=False,
generators_count=2,
generators_random_seed=None,
**kwargs):
@ -34,7 +34,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
self.person_id_mode = person_id_mode
if sort_by_yaw_target_samples_path is not None:
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@ -48,7 +47,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, person_id_mode=person_id_mode)
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, use_caching=use_caching)
np.random.shuffle(samples)
self.samples_len = len(samples)
@ -149,19 +148,12 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.add_sample_idx:
batches += [ [] ]
i_sample_idx = len(batches)-1
if self.person_id_mode:
batches += [ [] ]
i_person_id = len(batches)-1
for i in range(len(x)):
batches[i].append ( x[i] )
if self.add_sample_idx:
batches[i_sample_idx].append (idx)
if self.person_id_mode:
batches[i_person_id].append ( np.array([sample.person_id]) )
break