SampleGeneratorFace optimizations

This commit is contained in:
iperov 2019-04-24 21:27:05 +04:00
parent 9535a657d2
commit c0a63addd4
2 changed files with 33 additions and 55 deletions

View file

@ -20,7 +20,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
# self.add_pitch_yaw_roll = add_pitch_yaw_roll
if sort_by_yaw_target_samples_path is not None:
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@ -29,21 +28,19 @@ class SampleGeneratorFace(SampleGeneratorBase):
else:
self.sample_type = SampleType.FACE
self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
if generators_random_seed is not None and len(generators_random_seed) != generators_count:
raise ValueError("len(generators_random_seed) != generators_count")
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )]
else:
self.generators_count = min ( generators_count, len(self.samples) )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]
self.generators_sq = [ multiprocessing.Queue() for _ in range(self.generators_count) ]
self.generators_count = min ( generators_count, len(samples) )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count] ) ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -55,22 +52,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
#forces to repeat these sample idxs as fast as possible
#currently unused
def repeat_sample_idxs(self, idxs): # [ idx, ... ]
#send idxs list to all sub generators.
for gen_sq in self.generators_sq:
gen_sq.put (idxs)
def batch_func(self, param ):
generator_id, samples = param
def batch_func(self, generator_id):
gen_sq = self.generators_sq[generator_id]
if self.generators_random_seed is not None:
np.random.seed ( self.generators_random_seed[generator_id] )
samples = self.samples
samples_len = len(samples)
samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]
repeat_samples_idxs = []
samples_idxs = [*range(samples_len)]
if len(samples_idxs) == 0:
raise ValueError('No training data provided.')
@ -86,24 +75,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
shuffle_idxs_2D = [[]]*samples_len
while True:
while not gen_sq.empty():
idxs = gen_sq.get()
for idx in idxs:
if idx in samples_idxs:
repeat_samples_idxs.append(idx)
batches = None
for n_batch in range(self.batch_size):
while True:
sample = None
if len(repeat_samples_idxs) > 0:
idx = repeat_samples_idxs.pop()
if self.sample_type == SampleType.FACE:
sample = samples[idx]
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
sample = samples[(idx >> 16) & 0xFFFF][idx & 0xFFFF]
else:
if self.sample_type == SampleType.FACE:
if len(shuffle_idxs) == 0:
shuffle_idxs = samples_idxs.copy()

View file

@ -31,8 +31,8 @@ class SubprocessGenerator(object):
self.cs_queue = multiprocessing.Queue()
self.p = None
def process_func(self):
self.generator_func = self.generator_func(self.user_param)
def process_func(self, user_param):
self.generator_func = self.generator_func(user_param)
while True:
while self.prefetch > -1:
try:
@ -50,7 +50,9 @@ class SubprocessGenerator(object):
def __next__(self):
if self.p == None:
self.p = multiprocessing.Process(target=self.process_func, args=())
user_param = self.user_param
self.user_param = None
self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) )
self.p.daemon = True
self.p.start()