diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 1eef85d..4782848 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -49,7 +49,10 @@ class SampleGeneratorFace(SampleGeneratorBase): self.generators_random_seed = generators_random_seed samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, person_id_mode=person_id_mode) - self.total_samples_count = len(samples) + self.samples_len = len(samples) + + if self.samples_len == 0: + raise ValueError('No training data provided.') ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None self.random_ct_sample_chance = 100 @@ -58,14 +61,14 @@ class SampleGeneratorFace(SampleGeneratorBase): self.generators_count = 1 self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )] else: - self.generators_count = min ( generators_count, len(samples) ) + self.generators_count = min ( generators_count, self.samples_len ) self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ] self.generator_counter = -1 #overridable def get_total_sample_count(self): - return self.total_samples_count + return self.samples_len def __iter__(self): return self @@ -86,9 +89,6 @@ class SampleGeneratorFace(SampleGeneratorBase): ct_samples_len = len(ct_samples) if ct_samples is not None else 0 - if len(samples_idxs) == 0: - raise ValueError('No training data provided.') - if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: if all ( [ samples[idx] == None for idx in samples_idxs] ): raise ValueError('Not enough training data. Gather more faces!')