mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-08 05:51:40 -07:00
SampleGeneratorFace added ability to specify random seed (currently unused)
This commit is contained in:
parent
24dd4ef000
commit
773e8d80e0
1 changed files with 9 additions and 1 deletions
|
@ -18,7 +18,7 @@ output_sample_types = [
|
|||
]
|
||||
'''
|
||||
class SampleGeneratorFace(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, **kwargs):
|
||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, generators_random_seed=None, **kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
|
@ -37,6 +37,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||
|
||||
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
|
||||
raise ValueError("len(generators_random_seed) != generators_count")
|
||||
|
||||
self.generators_random_seed = generators_random_seed
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
|
||||
|
@ -65,6 +70,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
def batch_func(self, generator_id):
|
||||
gen_sq = self.generators_sq[generator_id]
|
||||
if self.generators_random_seed is not None:
|
||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||
|
||||
samples = self.samples
|
||||
samples_len = len(samples)
|
||||
samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue