From c0a63addd430c7121247ba4243af506e67831d6e Mon Sep 17 00:00:00 2001 From: iperov Date: Wed, 24 Apr 2019 21:27:05 +0400 Subject: [PATCH] SampleGeneratorFace optimizations --- samplelib/SampleGeneratorFace.py | 80 +++++++++++--------------------- utils/iter_utils.py | 8 ++-- 2 files changed, 33 insertions(+), 55 deletions(-) diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 59f8a58..a781ff7 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -20,7 +20,6 @@ class SampleGeneratorFace(SampleGeneratorBase): self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx - # self.add_pitch_yaw_roll = add_pitch_yaw_roll if sort_by_yaw_target_samples_path is not None: self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET @@ -29,21 +28,19 @@ class SampleGeneratorFace(SampleGeneratorBase): else: self.sample_type = SampleType.FACE - self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) - if generators_random_seed is not None and len(generators_random_seed) != generators_count: raise ValueError("len(generators_random_seed) != generators_count") self.generators_random_seed = generators_random_seed - + + samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) + if self.debug: self.generators_count = 1 - self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] + self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )] else: - self.generators_count = min ( generators_count, len(self.samples) ) - self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ] - - self.generators_sq = [ multiprocessing.Queue() for _ in range(self.generators_count) ] + self.generators_count = min ( generators_count, len(samples) ) + self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count] ) ) for i in range(self.generators_count) ] self.generator_counter = -1 @@ -55,22 +52,14 @@ class SampleGeneratorFace(SampleGeneratorBase): generator = self.generators[self.generator_counter % len(self.generators) ] return next(generator) - #forces to repeat these sample idxs as fast as possible - #currently unused - def repeat_sample_idxs(self, idxs): # [ idx, ... ] - #send idxs list to all sub generators. - for gen_sq in self.generators_sq: - gen_sq.put (idxs) - - def batch_func(self, generator_id): - gen_sq = self.generators_sq[generator_id] + def batch_func(self, param ): + generator_id, samples = param + if self.generators_random_seed is not None: np.random.seed ( self.generators_random_seed[generator_id] ) - samples = self.samples samples_len = len(samples) - samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count] - repeat_samples_idxs = [] + samples_idxs = [*range(samples_len)] if len(samples_idxs) == 0: raise ValueError('No training data provided.') @@ -86,47 +75,34 @@ class SampleGeneratorFace(SampleGeneratorBase): shuffle_idxs_2D = [[]]*samples_len while True: - while not gen_sq.empty(): - idxs = gen_sq.get() - for idx in idxs: - if idx in samples_idxs: - repeat_samples_idxs.append(idx) - batches = None for n_batch in range(self.batch_size): while True: sample = None - if len(repeat_samples_idxs) > 0: - idx = repeat_samples_idxs.pop() - if self.sample_type == SampleType.FACE: - sample = samples[idx] - elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: - sample = samples[(idx >> 16) & 0xFFFF][idx & 0xFFFF] - else: - if self.sample_type == SampleType.FACE: - if len(shuffle_idxs) == 0: - shuffle_idxs = samples_idxs.copy() - np.random.shuffle(shuffle_idxs) + if self.sample_type == SampleType.FACE: + if len(shuffle_idxs) == 0: + shuffle_idxs = samples_idxs.copy() + np.random.shuffle(shuffle_idxs) - idx = shuffle_idxs.pop() - sample = samples[ idx ] + idx = shuffle_idxs.pop() + sample = samples[ idx ] - elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: - if len(shuffle_idxs) == 0: - shuffle_idxs = samples_idxs.copy() - np.random.shuffle(shuffle_idxs) + elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: + if len(shuffle_idxs) == 0: + shuffle_idxs = samples_idxs.copy() + np.random.shuffle(shuffle_idxs) - idx = shuffle_idxs.pop() - if samples[idx] != None: - if len(shuffle_idxs_2D[idx]) == 0: - a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ] - np.random.shuffle (a) + idx = shuffle_idxs.pop() + if samples[idx] != None: + if len(shuffle_idxs_2D[idx]) == 0: + a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ] + np.random.shuffle (a) - idx2 = shuffle_idxs_2D[idx].pop() - sample = samples[idx][idx2] + idx2 = shuffle_idxs_2D[idx].pop() + sample = samples[idx][idx2] - idx = (idx << 16) | (idx2 & 0xFFFF) + idx = (idx << 16) | (idx2 & 0xFFFF) if sample is not None: try: diff --git a/utils/iter_utils.py b/utils/iter_utils.py index b197fc0..75fb21f 100644 --- a/utils/iter_utils.py +++ b/utils/iter_utils.py @@ -31,8 +31,8 @@ class SubprocessGenerator(object): self.cs_queue = multiprocessing.Queue() self.p = None - def process_func(self): - self.generator_func = self.generator_func(self.user_param) + def process_func(self, user_param): + self.generator_func = self.generator_func(user_param) while True: while self.prefetch > -1: try: @@ -50,7 +50,9 @@ class SubprocessGenerator(object): def __next__(self): if self.p == None: - self.p = multiprocessing.Process(target=self.process_func, args=()) + user_param = self.user_param + self.user_param = None + self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) ) self.p.daemon = True self.p.start()