fix check no training data provided

This commit is contained in:
Colombo 2019-10-25 00:27:34 +04:00
parent 75eeef0a96
commit b9c41a269d

View file

@ -49,7 +49,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, person_id_mode=person_id_mode)
self.total_samples_count = len(samples)
self.samples_len = len(samples)
if self.samples_len == 0:
raise ValueError('No training data provided.')
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
self.random_ct_sample_chance = 100
@ -58,14 +61,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
else:
self.generators_count = min ( generators_count, len(samples) )
self.generators_count = min ( generators_count, self.samples_len )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
self.generator_counter = -1
#overridable
def get_total_sample_count(self):
return self.total_samples_count
return self.samples_len
def __iter__(self):
return self
@ -86,9 +89,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
ct_samples_len = len(ct_samples) if ct_samples is not None else 0
if len(samples_idxs) == 0:
raise ValueError('No training data provided.')
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
if all ( [ samples[idx] == None for idx in samples_idxs] ):
raise ValueError('Not enough training data. Gather more faces!')