fixes and optimizations

This commit is contained in:
Colombo 2020-01-07 13:45:54 +04:00
parent 842a48964f
commit d3e6b435aa
6 changed files with 72 additions and 58 deletions

View file

@ -1,10 +1,13 @@
import pickle
import traceback
import numpy as np
import cv2
import cv2
import numpy as np
from samplelib import (SampleGeneratorBase, SampleHost, SampleProcessor,
SampleType)
from utils import iter_utils
from samplelib import SampleType, SampleProcessor, SampleHost, SampleGeneratorBase
'''
output_sample_types = [
@ -24,14 +27,18 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
self.generators_count = 1
else:
self.generators_count = generators_count
samples_clis = SampleHost.host (SampleType.FACE_TEMPORAL_SORTED, self.samples_path, number_of_clis=self.generators_count)
samples = SampleHost.load (SampleType.FACE_TEMPORAL_SORTED, self.samples_path)
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')
pickled_samples = pickle.dumps(samples, 4)
if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples_clis[0]) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, pickled_samples) )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples_clis[i]) ) for i in range(self.generators_count) ]
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, pickled_samples) ) for i in range(self.generators_count) ]
self.generator_counter = -1
def __iter__(self):
@ -43,22 +50,20 @@ class SampleGeneratorFaceTemporal(SampleGeneratorBase):
return next(generator)
def batch_func(self, param):
generator_id, samples = param
generator_id, pickled_samples = param
samples = pickle.loads(pickled_samples)
samples_len = len(samples)
if samples_len == 0:
raise ValueError('No training data provided.')
mult_max = 1
l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) )
samples_idxs = [ *range(l+1) ] [generator_id::self.generators_count]
samples_idxs = [ *range(l+1) ]
if len(samples_idxs) - self.temporal_image_count < 0:
raise ValueError('Not enough samples to fit temporal line.')
shuffle_idxs = []
while True:
batches = None
for n_batch in range(self.batch_size):