mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
upd SampleGenerator
This commit is contained in:
parent
1898bd6881
commit
9860a38907
4 changed files with 42 additions and 33 deletions
|
@ -27,12 +27,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
output_sample_types=[],
|
||||
add_sample_idx=False,
|
||||
generators_count=4,
|
||||
rnd_seed=None,
|
||||
**kwargs):
|
||||
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
|
||||
if rnd_seed is None:
|
||||
rnd_seed = np.random.randint(0x80000000)
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
|
@ -45,11 +49,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
if self.samples_len == 0:
|
||||
raise ValueError('No training data provided.')
|
||||
|
||||
index_host = mplib.IndexHost(self.samples_len)
|
||||
index_host = mplib.IndexHost(self.samples_len, rnd_seed=rnd_seed)
|
||||
|
||||
if random_ct_samples_path is not None:
|
||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
|
||||
ct_index_host = mplib.IndexHost( len(ct_samples) )
|
||||
ct_index_host = mplib.IndexHost( len(ct_samples), rnd_seed=rnd_seed )
|
||||
else:
|
||||
ct_samples = None
|
||||
ct_index_host = None
|
||||
|
@ -58,9 +62,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
|
||||
|
||||
if self.debug:
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed) )]
|
||||
else:
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None, rnd_seed), start_now=False ) \
|
||||
for i in range(self.generators_count) ]
|
||||
|
||||
SubprocessGenerator.start_in_parallel( self.generators )
|
||||
|
@ -76,7 +80,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
return next(generator)
|
||||
|
||||
def batch_func(self, param ):
|
||||
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
|
||||
pickled_samples, index_host, ct_pickled_samples, ct_index_host, rnd_seed = param
|
||||
|
||||
rnd_state = np.random.RandomState(rnd_seed)
|
||||
|
||||
samples = pickle.loads(pickled_samples)
|
||||
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
|
||||
|
@ -98,7 +104,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
ct_sample = ct_samples[ct_indexes[n_batch]]
|
||||
|
||||
try:
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
|
||||
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, rnd_state=rnd_state)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue