diff --git a/samples/SampleGeneratorFace.py b/samples/SampleGeneratorFace.py index 6bdf5ea..7f639f0 100644 --- a/samples/SampleGeneratorFace.py +++ b/samples/SampleGeneratorFace.py @@ -18,7 +18,7 @@ output_sample_types = [ ] ''' class SampleGeneratorFace(SampleGeneratorBase): - def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, **kwargs): + def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, generators_random_seed=None, **kwargs): super().__init__(samples_path, debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types @@ -37,6 +37,11 @@ class SampleGeneratorFace(SampleGeneratorBase): self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) + if generators_random_seed is not None and len(generators_random_seed) != generators_count: + raise ValueError("len(generators_random_seed) != generators_count") + + self.generators_random_seed = generators_random_seed + if self.debug: self.generators_count = 1 self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] @@ -65,6 +70,9 @@ class SampleGeneratorFace(SampleGeneratorBase): def batch_func(self, generator_id): gen_sq = self.generators_sq[generator_id] + if self.generators_random_seed is not None: + np.random.seed ( self.generators_random_seed[generator_id] ) + samples = self.samples samples_len = len(samples) samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]