SampleGeneratorFace added ability to specify random seed (currently unused)

This commit is contained in:
iperov 2019-03-27 10:39:21 +04:00
parent 24dd4ef000
commit 773e8d80e0

View file

@ -18,7 +18,7 @@ output_sample_types = [
]
'''
class SampleGeneratorFace(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, **kwargs):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, generators_random_seed=None, **kwargs):
super().__init__(samples_path, debug, batch_size)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
@ -37,6 +37,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
raise ValueError("len(generators_random_seed) != generators_count")
self.generators_random_seed = generators_random_seed
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
@ -65,6 +70,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
def batch_func(self, generator_id):
gen_sq = self.generators_sq[generator_id]
if self.generators_random_seed is not None:
np.random.seed ( self.generators_random_seed[generator_id] )
samples = self.samples
samples_len = len(samples)
samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]