mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
decreased time of training initialization
This commit is contained in:
parent
814da70577
commit
01376fd17c
4 changed files with 36 additions and 8 deletions
|
@ -1,7 +1,28 @@
|
|||
import queue as Queue
|
||||
import multiprocessing
|
||||
import queue as Queue
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class SubprocessGenerator(object):
|
||||
|
||||
@staticmethod
|
||||
def launch_thread(generator):
|
||||
generator._start()
|
||||
|
||||
@staticmethod
|
||||
def start_in_parallel( generator_list ):
|
||||
"""
|
||||
Start list of generators in parallel
|
||||
"""
|
||||
for generator in generator_list:
|
||||
thread = threading.Thread(target=SubprocessGenerator.launch_thread, args=(generator,) )
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
while not all ([generator._is_started() for generator in generator_list]):
|
||||
time.sleep(0.005)
|
||||
|
||||
def __init__(self, generator_func, user_param=None, prefetch=2, start_now=True):
|
||||
super().__init__()
|
||||
self.prefetch = prefetch
|
||||
|
@ -17,9 +38,13 @@ class SubprocessGenerator(object):
|
|||
if self.p == None:
|
||||
user_param = self.user_param
|
||||
self.user_param = None
|
||||
self.p = multiprocessing.Process(target=self.process_func, args=(user_param,) )
|
||||
self.p.daemon = True
|
||||
self.p.start()
|
||||
p = multiprocessing.Process(target=self.process_func, args=(user_param,) )
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.p = p
|
||||
|
||||
def _is_started(self):
|
||||
return self.p is not None
|
||||
|
||||
def process_func(self, user_param):
|
||||
self.generator_func = self.generator_func(user_param)
|
||||
|
|
|
@ -60,7 +60,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
if self.debug:
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
|
||||
else:
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \
|
||||
for i in range(self.generators_count) ]
|
||||
|
||||
SubprocessGenerator.start_in_parallel( self.generators )
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
|
|||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )]
|
||||
else:
|
||||
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),), start_now=True ) for i in range(self.generators_count) ]
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
|
|
@ -44,7 +44,7 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
|
|||
if self.debug:
|
||||
self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) )]
|
||||
else:
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),), start_now=True ) for i in range(self.generators_count) ]
|
||||
self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) ) for i in range(self.generators_count) ]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue