import multiprocessing
import time
import traceback

import cv2
import numpy as np

from core import mplib
from core.joblib import SubprocessGenerator, ThisThreadGenerator
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
                       SampleType)


'''
arg
output_sample_types = [
                        [SampleProcessor.TypeFlags, size, (optional) {} opts ] ,
                        ...
                      ]
'''
class SampleGeneratorFace(SampleGeneratorBase):
    """Endless iterator that yields batches of processed FACE samples.

    Samples are loaded once with ``SampleLoader.load(SampleType.FACE, ...)``
    and handed to one or more generators that run ``batch_func``. In debug
    mode a single ``ThisThreadGenerator`` is used; otherwise
    ``generators_count`` ``SubprocessGenerator`` objects are created and
    started together. ``__next__`` takes batches from the generators in
    round-robin order.
    """

    def __init__ (self, samples_path, debug=False, batch_size=1,
                        random_ct_samples_path=None,
                        sample_process_options=None,
                        output_sample_types=None,
                        add_sample_idx=False,
                        generators_count=4,
                        raise_on_no_data=True,
                        **kwargs):
        """Load the samples and create the batch generators.

        Args:
            samples_path: location passed to ``SampleLoader.load``.
            debug: forwarded to the base class; when ``self.debug`` is true
                a single in-thread generator is used.
            batch_size: forwarded to the base class; number of samples per
                yielded batch.
            random_ct_samples_path: optional location of a second FACE sample
                set; one of its samples is passed to the processor as
                ``ct_sample`` for every processed sample.
            sample_process_options: ``SampleProcessor.Options`` instance.
                ``None`` (the default) creates a fresh ``Options()`` for this
                generator instead of sharing one module-wide default object.
            output_sample_types: list forwarded to ``SampleProcessor.process``.
                ``None`` (the default) means a new empty list.
            add_sample_idx: when true, every batch gets one extra trailing
                array holding the indexes of the samples it contains.
            generators_count: number of subprocess generators (at least 1).
            raise_on_no_data: when no samples are found, raise ``ValueError``
                if true, otherwise leave the generator uninitialized.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: no samples were loaded and ``raise_on_no_data`` is true.
        """
        super().__init__(debug, batch_size)

        # Build the defaults per instance: a default evaluated in the
        # signature would be one object shared by every instance.
        if sample_process_options is None:
            sample_process_options = SampleProcessor.Options()
        if output_sample_types is None:
            output_sample_types = []

        self.sample_process_options = sample_process_options
        self.output_sample_types = output_sample_types
        self.add_sample_idx = add_sample_idx

        if self.debug:
            self.generators_count = 1
        else:
            self.generators_count = max(1, generators_count)

        samples = SampleLoader.load (SampleType.FACE, samples_path)
        self.samples_len = len(samples)

        self.initialized = False
        if self.samples_len == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            # Stay uninitialized: __next__ will return an empty list.
            return

        index_host = mplib.IndexHost(self.samples_len)

        if random_ct_samples_path is not None:
            ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path)
            ct_index_host = mplib.IndexHost( len(ct_samples) )
        else:
            ct_samples = None
            ct_index_host = None

        def make_param():
            # Every generator receives its own client objects created from
            # the shared index hosts.
            ct_cli = ct_index_host.create_cli() if ct_index_host is not None else None
            return (samples, index_host.create_cli(), ct_samples, ct_cli)

        if self.debug:
            self.generators = [ThisThreadGenerator ( self.batch_func, make_param() )]
        else:
            self.generators = [SubprocessGenerator ( self.batch_func, make_param(), start_now=False )
                               for _ in range(self.generators_count) ]

            SubprocessGenerator.start_in_parallel( self.generators )

        # Starts at -1 so that the first __next__ call uses generator 0.
        self.generator_counter = -1

        self.initialized = True

    #overridable
    def is_initialized(self):
        """Return True once samples were loaded and the generators created."""
        return self.initialized

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch, cycling through the generators in order.

        Returns an empty list when the generator is not initialized
        (no samples were found and ``raise_on_no_data`` was false).
        """
        if not self.initialized:
            return []

        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        return next(generator)

    def batch_func(self, param ):
        """Generator body: yield batches of processed samples forever.

        Args:
            param: tuple ``(samples, index_host, ct_samples, ct_index_host)``
                as built in ``__init__``; the two hosts are the client
                objects returned by ``create_cli()``.

        Yields:
            A list of ``np.ndarray``, one per output of
            ``SampleProcessor.process``, each stacking ``batch_size`` items.
            When ``add_sample_idx`` is set, a final array with the sample
            indexes is appended.

        Raises:
            Exception: processing a sample failed; the message contains the
                sample filename and the formatted traceback.
        """
        samples, index_host, ct_samples, ct_index_host = param

        bs = self.batch_size
        while True:
            batches = None

            # presumably multi_get(bs) returns bs indexes - confirm in mplib
            indexes = index_host.multi_get(bs)
            ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None

            for n_batch in range(bs):
                sample_idx = indexes[n_batch]
                sample = samples[sample_idx]

                ct_sample = None
                if ct_samples is not None:
                    ct_sample = ct_samples[ct_indexes[n_batch]]

                try:
                    x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
                except Exception as e:
                    # Only real errors are wrapped; KeyboardInterrupt and
                    # SystemExit propagate untouched. The original exception
                    # stays attached as __cause__.
                    raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) from e

                if batches is None:
                    # One list per processor output, sized from the first sample.
                    batches = [ [] for _ in range(len(x)) ]
                    if self.add_sample_idx:
                        # Extra trailing list that collects sample indexes.
                        batches.append ( [] )

                # zip stops after len(x) lists, so the optional trailing
                # index list is not touched here.
                for batch, item in zip(batches, x):
                    batch.append (item)

                if self.add_sample_idx:
                    batches[-1].append (sample_idx)

            yield [ np.array(batch) for batch in batches]