mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-08 05:51:40 -07:00
SampleGeneratorFace optimizations
This commit is contained in:
parent
9535a657d2
commit
c0a63addd4
2 changed files with 33 additions and 55 deletions
|
@ -20,7 +20,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.add_sample_idx = add_sample_idx
|
self.add_sample_idx = add_sample_idx
|
||||||
# self.add_pitch_yaw_roll = add_pitch_yaw_roll
|
|
||||||
|
|
||||||
if sort_by_yaw_target_samples_path is not None:
|
if sort_by_yaw_target_samples_path is not None:
|
||||||
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
||||||
|
@ -29,21 +28,19 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
else:
|
else:
|
||||||
self.sample_type = SampleType.FACE
|
self.sample_type = SampleType.FACE
|
||||||
|
|
||||||
self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
|
||||||
|
|
||||||
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
|
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
|
||||||
raise ValueError("len(generators_random_seed) != generators_count")
|
raise ValueError("len(generators_random_seed) != generators_count")
|
||||||
|
|
||||||
self.generators_random_seed = generators_random_seed
|
self.generators_random_seed = generators_random_seed
|
||||||
|
|
||||||
|
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators_count = 1
|
||||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
|
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )]
|
||||||
else:
|
else:
|
||||||
self.generators_count = min ( generators_count, len(self.samples) )
|
self.generators_count = min ( generators_count, len(samples) )
|
||||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]
|
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count] ) ) for i in range(self.generators_count) ]
|
||||||
|
|
||||||
self.generators_sq = [ multiprocessing.Queue() for _ in range(self.generators_count) ]
|
|
||||||
|
|
||||||
self.generator_counter = -1
|
self.generator_counter = -1
|
||||||
|
|
||||||
|
@ -55,22 +52,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
generator = self.generators[self.generator_counter % len(self.generators) ]
|
generator = self.generators[self.generator_counter % len(self.generators) ]
|
||||||
return next(generator)
|
return next(generator)
|
||||||
|
|
||||||
#forces to repeat these sample idxs as fast as possible
|
def batch_func(self, param ):
|
||||||
#currently unused
|
generator_id, samples = param
|
||||||
def repeat_sample_idxs(self, idxs): # [ idx, ... ]
|
|
||||||
#send idxs list to all sub generators.
|
|
||||||
for gen_sq in self.generators_sq:
|
|
||||||
gen_sq.put (idxs)
|
|
||||||
|
|
||||||
def batch_func(self, generator_id):
|
|
||||||
gen_sq = self.generators_sq[generator_id]
|
|
||||||
if self.generators_random_seed is not None:
|
if self.generators_random_seed is not None:
|
||||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||||
|
|
||||||
samples = self.samples
|
|
||||||
samples_len = len(samples)
|
samples_len = len(samples)
|
||||||
samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]
|
samples_idxs = [*range(samples_len)]
|
||||||
repeat_samples_idxs = []
|
|
||||||
|
|
||||||
if len(samples_idxs) == 0:
|
if len(samples_idxs) == 0:
|
||||||
raise ValueError('No training data provided.')
|
raise ValueError('No training data provided.')
|
||||||
|
@ -86,24 +75,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
shuffle_idxs_2D = [[]]*samples_len
|
shuffle_idxs_2D = [[]]*samples_len
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
while not gen_sq.empty():
|
|
||||||
idxs = gen_sq.get()
|
|
||||||
for idx in idxs:
|
|
||||||
if idx in samples_idxs:
|
|
||||||
repeat_samples_idxs.append(idx)
|
|
||||||
|
|
||||||
batches = None
|
batches = None
|
||||||
for n_batch in range(self.batch_size):
|
for n_batch in range(self.batch_size):
|
||||||
while True:
|
while True:
|
||||||
sample = None
|
sample = None
|
||||||
|
|
||||||
if len(repeat_samples_idxs) > 0:
|
|
||||||
idx = repeat_samples_idxs.pop()
|
|
||||||
if self.sample_type == SampleType.FACE:
|
|
||||||
sample = samples[idx]
|
|
||||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
|
||||||
sample = samples[(idx >> 16) & 0xFFFF][idx & 0xFFFF]
|
|
||||||
else:
|
|
||||||
if self.sample_type == SampleType.FACE:
|
if self.sample_type == SampleType.FACE:
|
||||||
if len(shuffle_idxs) == 0:
|
if len(shuffle_idxs) == 0:
|
||||||
shuffle_idxs = samples_idxs.copy()
|
shuffle_idxs = samples_idxs.copy()
|
||||||
|
|
|
@ -31,8 +31,8 @@ class SubprocessGenerator(object):
|
||||||
self.cs_queue = multiprocessing.Queue()
|
self.cs_queue = multiprocessing.Queue()
|
||||||
self.p = None
|
self.p = None
|
||||||
|
|
||||||
def process_func(self):
|
def process_func(self, user_param):
|
||||||
self.generator_func = self.generator_func(self.user_param)
|
self.generator_func = self.generator_func(user_param)
|
||||||
while True:
|
while True:
|
||||||
while self.prefetch > -1:
|
while self.prefetch > -1:
|
||||||
try:
|
try:
|
||||||
|
@ -50,7 +50,9 @@ class SubprocessGenerator(object):
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
if self.p == None:
|
if self.p == None:
|
||||||
self.p = multiprocessing.Process(target=self.process_func, args=())
|
user_param = self.user_param
|
||||||
|
self.user_param = None
|
||||||
|
self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) )
|
||||||
self.p.daemon = True
|
self.p.daemon = True
|
||||||
self.p.start()
|
self.p.start()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue