SAE: added test option: 'Apply random color transfer to src faceset'

This commit is contained in:
iperov 2019-05-06 11:34:56 +04:00
parent bde700243c
commit a805f81142
9 changed files with 152 additions and 129 deletions

View file

@ -1,11 +1,14 @@
import traceback
import numpy as np
import cv2
import multiprocessing
from utils import iter_utils
from facelib import LandmarksProcessor
import traceback
import cv2
import numpy as np
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
from utils import iter_utils
from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase
'''
arg
@ -15,7 +18,7 @@ output_sample_types = [
]
'''
class SampleGeneratorFace(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
super().__init__(samples_path, debug, batch_size)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
@ -32,15 +35,17 @@ class SampleGeneratorFace(SampleGeneratorBase):
raise ValueError("len(generators_random_seed) != generators_count")
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )]
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
else:
self.generators_count = min ( generators_count, len(samples) )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count] ) ) for i in range(self.generators_count) ]
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
self.generator_counter = -1
@ -53,14 +58,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
return next(generator)
def batch_func(self, param ):
generator_id, samples = param
generator_id, samples, ct_samples = param
if self.generators_random_seed is not None:
np.random.seed ( self.generators_random_seed[generator_id] )
samples_len = len(samples)
samples_idxs = [*range(samples_len)]
ct_samples_len = len(ct_samples) if ct_samples is not None else 0
if len(samples_idxs) == 0:
raise ValueError('No training data provided.')
@ -106,7 +113,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
if sample is not None:
try:
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug,
ct_sample=ct_samples[np.random.randint(ct_samples_len)] if ct_samples is not None else None )
except:
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )