optimized sample generator

This commit is contained in:
Colombo 2020-01-05 11:53:31 +04:00
commit 21b25038ac
6 changed files with 201 additions and 160 deletions

View file

@ -7,8 +7,8 @@ import numpy as np
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
from utils import iter_utils
from utils import mp_utils
from utils import iter_utils, mp_utils
'''
arg
@ -30,8 +30,13 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
self.samples_len = len(samples_host.get_list())
if self.debug:
self.generators_count = 1
else:
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6)
samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count)
self.samples_len = len(samples_clis[0])
if self.samples_len == 0:
raise ValueError('No training data provided.')
@ -39,18 +44,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
index_host = mp_utils.IndexHost(self.samples_len)
if random_ct_samples_path is not None:
ct_samples_host = SampleHost.mp_host (SampleType.FACE, random_ct_samples_path)
ct_index_host = mp_utils.IndexHost( len(ct_samples_host.get_list()) )
ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count)
ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) )
else:
ct_samples_host = None
ct_samples_clis = None
ct_index_host = None
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else:
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4)
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index_host.create_cli(), ct_samples_host.create_cli() if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -72,13 +75,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
while True:
batches = None
indexes = index_host.get(bs)
ct_indexes = ct_index_host.get(bs) if ct_samples is not None else None
indexes = index_host.multi_get(bs)
ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
batch_samples = samples.multi_get (indexes)
batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None
for n_batch in range(bs):
sample_idx = indexes[n_batch]
sample = samples[ sample_idx ]
ct_sample = ct_samples[ ct_indexes[n_batch] ] if ct_samples is not None else None
sample = batch_samples[n_batch]
ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
try:
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)

View file

@ -30,6 +30,7 @@ class SampleGeneratorFacePerson(SampleGeneratorBase):
self.output_sample_types = output_sample_types
self.person_id_mode = person_id_mode
raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.")
samples_host = SampleHost.mp_host (SampleType.FACE, self.samples_path)
samples = samples_host.get_list()

View file

@ -20,14 +20,17 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
else:
self.generators_count = min ( generators_count, len(self.samples) )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]
self.generators_count = generators_count
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -39,8 +42,9 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
generator = self.generators[self.generator_counter % len(self.generators) ]
return next(generator)
def batch_func(self, generator_id):
samples = self.samples
def batch_func(self, param):
generator_id, samples = param
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')
@ -56,10 +60,8 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
shuffle_idxs = []
while True:
batches = None
for n_batch in range(self.batch_size):
if len(shuffle_idxs) == 0:
shuffle_idxs = samples_idxs.copy()
np.random.shuffle (shuffle_idxs)

View file

@ -1,7 +1,5 @@
import gc
import multiprocessing
import operator
import pickle
import traceback
from pathlib import Path
@ -16,9 +14,11 @@ from .Sample import Sample, SampleType
class SampleHost:
samples_cache = dict()
host_cache = dict()
@staticmethod
def get_person_id_max_count(samples_path):
samples = None
@ -35,7 +35,7 @@ class SampleHost:
return len(list(persons_name_idxs.keys()))
@staticmethod
def load(sample_type, samples_path):
def host(sample_type, samples_path, number_of_clis):
samples_cache = SampleHost.samples_cache
if str(samples_path) not in samples_cache.keys():
@ -46,9 +46,11 @@ class SampleHost:
if sample_type == SampleType.IMAGE:
if samples[sample_type] is None:
samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( Path_utils.get_image_paths(samples_path), "Loading") ]
elif sample_type == SampleType.FACE:
if samples[sample_type] is None:
result = None
elif sample_type == SampleType.FACE or \
sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = None
if samples[sample_type] is None:
try:
result = samplelib.PackedFaceset.load(samples_path)
except:
@ -60,33 +62,26 @@ class SampleHost:
if result is None:
result = SampleHost.load_face_samples( Path_utils.get_image_paths(samples_path) )
result_dmp = pickle.dumps(result)
del result
gc.collect()
result = pickle.loads(result_dmp)
samples[sample_type] = result
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
if samples[sample_type] is None:
samples[sample_type] = SampleHost.upgradeToFaceTemporalSortedSamples( SampleHost.load(SampleType.FACE, samples_path) )
samples[sample_type] = mp_utils.ListHost()
if sample_type == SampleType.FACE_TEMPORAL_SORTED:
result = SampleHost.upgradeToFaceTemporalSortedSamples(result)
list_host = samples[sample_type]
clis = [ list_host.create_cli() for _ in range(number_of_clis) ]
if result is not None:
while True:
if len(result) == 0:
break
items = result[0:10000]
del result[0:10000]
clis[0].extend(items)
return clis
return samples[sample_type]
@staticmethod
def mp_host(sample_type, samples_path):
result = SampleHost.load (sample_type, samples_path)
host_cache = SampleHost.host_cache
if str(samples_path) not in host_cache.keys():
host_cache[str(samples_path)] = [None]*SampleType.QTY
hosts = host_cache[str(samples_path)]
if hosts[sample_type] is None:
hosts[sample_type] = mp_utils.ListHost(result)
return hosts[sample_type]
@staticmethod
def load_face_samples ( image_paths):
result = FaceSamplesLoaderSubprocessor(image_paths).run()