diff --git a/core/joblib/SubprocessGenerator.py b/core/joblib/SubprocessGenerator.py index 82ccbdb..84a5937 100644 --- a/core/joblib/SubprocessGenerator.py +++ b/core/joblib/SubprocessGenerator.py @@ -1,7 +1,28 @@ -import queue as Queue import multiprocessing +import queue as Queue +import threading +import time + class SubprocessGenerator(object): + + @staticmethod + def launch_thread(generator): + generator._start() + + @staticmethod + def start_in_parallel( generator_list ): + """ + Start list of generators in parallel + """ + for generator in generator_list: + thread = threading.Thread(target=SubprocessGenerator.launch_thread, args=(generator,) ) + thread.daemon = True + thread.start() + + while not all ([generator._is_started() for generator in generator_list]): + time.sleep(0.005) + def __init__(self, generator_func, user_param=None, prefetch=2, start_now=True): super().__init__() self.prefetch = prefetch @@ -17,10 +38,14 @@ class SubprocessGenerator(object): if self.p == None: user_param = self.user_param self.user_param = None - self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) ) - self.p.daemon = True - self.p.start() - + p = multiprocessing.Process(target=self.process_func, args=(user_param,) ) + p.daemon = True + p.start() + self.p = p + + def _is_started(self): + return self.p is not None + def process_func(self, user_param): self.generator_func = self.generator_func(user_param) while True: diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 696a4f2..a7e622f 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -60,7 +60,10 @@ class SampleGeneratorFace(SampleGeneratorBase): if self.debug: self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] else: - self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ] + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ + for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) self.generator_counter = -1 diff --git a/samplelib/SampleGeneratorFacePerson.py b/samplelib/SampleGeneratorFacePerson.py index d6be2d8..ba69f44 100644 --- a/samplelib/SampleGeneratorFacePerson.py +++ b/samplelib/SampleGeneratorFacePerson.py @@ -52,7 +52,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase): self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )] else: self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4) - self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),), start_now=True ) for i in range(self.generators_count) ] + self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) ) for i in range(self.generators_count) ] self.generator_counter = -1 diff --git a/samplelib/SampleGeneratorFaceTemporal.py b/samplelib/SampleGeneratorFaceTemporal.py index 7a9215f..5ffb8e4 100644 --- a/samplelib/SampleGeneratorFaceTemporal.py +++ b/samplelib/SampleGeneratorFaceTemporal.py @@ -44,7 +44,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase): if self.debug: self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) )] else: - self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),), start_now=True ) for i in range(self.generators_count) ] + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) ) for i in range(self.generators_count) ] self.generator_counter = -1