fixes and optimizations

This commit is contained in:
Colombo 2020-01-07 13:45:54 +04:00
parent 842a48964f
commit d3e6b435aa
6 changed files with 72 additions and 58 deletions

View file

@ -1,9 +1,9 @@
import multiprocessing
import traceback
import pickle
import cv2
import numpy as np
import time
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
@ -23,6 +23,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
sample_process_options=SampleProcessor.Options(),
output_sample_types=[],
add_sample_idx=False,
generators_count=4,
**kwargs):
super().__init__(samples_path, debug, batch_size)
@ -33,10 +34,10 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.debug:
self.generators_count = 1
else:
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 6)
self.generators_count = np.clip(multiprocessing.cpu_count(), 2, generators_count)
samples_clis = SampleHost.host (SampleType.FACE, self.samples_path, number_of_clis=self.generators_count)
self.samples_len = len(samples_clis[0])
samples = SampleHost.load (SampleType.FACE, self.samples_path)
self.samples_len = len(samples)
if self.samples_len == 0:
raise ValueError('No training data provided.')
@ -44,16 +45,19 @@ class SampleGeneratorFace(SampleGeneratorBase):
index_host = mp_utils.IndexHost(self.samples_len)
if random_ct_samples_path is not None:
ct_samples_clis = SampleHost.host (SampleType.FACE, random_ct_samples_path, number_of_clis=self.generators_count)
ct_index_host = mp_utils.IndexHost( len(ct_samples_clis[0]) )
ct_samples = SampleHost.load (SampleType.FACE, random_ct_samples_path)
ct_index_host = mp_utils.IndexHost( len(ct_samples) )
else:
ct_samples_clis = None
ct_samples = None
ct_index_host = None
pickled_samples = pickle.dumps(samples, 4)
ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None
if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_clis[0], index_host.create_cli(), ct_samples_clis[0] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_clis[i], index_host.create_cli(), ct_samples_clis[i] if ct_index_host is not None else None, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(), ct_pickled_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=True ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -70,7 +74,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator)
def batch_func(self, param ):
samples, index_host, ct_samples, ct_index_host = param
pickled_samples, index_host, ct_pickled_samples, ct_index_host = param
samples = pickle.loads(pickled_samples)
ct_samples = pickle.loads(ct_pickled_samples) if ct_pickled_samples is not None else None
bs = self.batch_size
while True:
batches = None
@ -78,13 +86,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
indexes = index_host.multi_get(bs)
ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None
batch_samples = samples.multi_get (indexes)
batch_ct_samples = ct_samples.multi_get (ct_indexes) if ct_samples is not None else None
t = time.time()
for n_batch in range(bs):
sample_idx = indexes[n_batch]
sample = batch_samples[n_batch]
ct_sample = batch_ct_samples[n_batch] if ct_samples is not None else None
sample = samples[sample_idx]
ct_sample = None
if ct_samples is not None:
ct_sample = ct_samples[ct_indexes[n_batch]]
try:
x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
@ -102,4 +111,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.add_sample_idx:
batches[i_sample_idx].append (sample_idx)
yield [ np.array(batch) for batch in batches]