mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-10 23:33:30 -07:00
SampleGeneratorFace added ability to specify random seed (currently unused)
This commit is contained in:
parent
24dd4ef000
commit
773e8d80e0
1 changed files with 9 additions and 1 deletions
|
@ -18,7 +18,7 @@ output_sample_types = [
|
||||||
]
|
]
|
||||||
'''
|
'''
|
||||||
class SampleGeneratorFace(SampleGeneratorBase):
|
class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, **kwargs):
|
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, generators_random_seed=None, **kwargs):
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(samples_path, debug, batch_size)
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
|
@ -37,6 +37,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
|
|
||||||
self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||||
|
|
||||||
|
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
|
||||||
|
raise ValueError("len(generators_random_seed) != generators_count")
|
||||||
|
|
||||||
|
self.generators_random_seed = generators_random_seed
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators_count = 1
|
||||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
|
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
|
||||||
|
@ -65,6 +70,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
|
|
||||||
def batch_func(self, generator_id):
|
def batch_func(self, generator_id):
|
||||||
gen_sq = self.generators_sq[generator_id]
|
gen_sq = self.generators_sq[generator_id]
|
||||||
|
if self.generators_random_seed is not None:
|
||||||
|
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||||
|
|
||||||
samples = self.samples
|
samples = self.samples
|
||||||
samples_len = len(samples)
|
samples_len = len(samples)
|
||||||
samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]
|
samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue