From 90ae878b78099227d6633dc5fffb80bb061f3f88 Mon Sep 17 00:00:00 2001 From: Jeremy Hummel Date: Mon, 12 Aug 2019 22:36:28 -0700 Subject: [PATCH] Color Transfer mode option --- models/Model_SAE/Model.py | 16 +++++---- samplelib/SampleGeneratorFace.py | 60 ++++++++++++++++++-------------- samplelib/SampleProcessor.py | 27 ++++++++------ 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 2563099..0ce9b12 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -7,8 +7,11 @@ from facelib import FaceType from samplelib import * from interact import interact as io +from samplelib.SampleProcessor import ColorTransferMode # SAE - Styled AutoEncoder + + class SAEModel(ModelBase): encoderH5 = 'encoder.h5' inter_BH5 = 'inter_B.h5' @@ -119,8 +122,8 @@ class SAEModel(ModelBase): help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0) - default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0) - self.options['apply_random_ct'] = io.input_bool( + default_apply_random_ct = ColorTransferMode.NONE if is_first_run else self.options.get('apply_random_ct', ColorTransferMode.NONE) + self.options['apply_random_ct'] = io.input_int( "Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % ( yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst " @@ -140,7 +143,7 @@ class SAEModel(ModelBase): self.options['pixel_loss'] = self.options.get('pixel_loss', False) self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power) - self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0) + self.options['apply_random_ct'] = self.options.get('apply_random_ct', ColorTransferMode.NONE) self.options['clipgrad'] = self.options.get('clipgrad', False) if is_first_run: @@ -169,7 +172,7 @@ class SAEModel(ModelBase): self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1 - apply_random_ct = self.options.get('apply_random_ct', 0) + apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE) masked_training = True warped_src = Input(bgr_shape) @@ -456,8 +459,7 @@ class SAEModel(ModelBase): self.set_training_data_generators([ SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, - random_ct_samples_path=training_data_dst_path if apply_random_ct else None, - random_ct_type=apply_random_ct, + random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, @@ -590,7 +592,7 @@ class SAEModel(ModelBase): face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF from converters import ConverterMasked - # TODO apply_random_ct + return ConverterMasked(self.predictor_func, predictor_input_size=self.options['resolution'], predictor_masked=self.options['learn_mask'], diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index dbaaa2d..fad63ea 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -9,7 +9,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) from utils import iter_utils - ''' arg output_sample_types = [ @@ -17,13 +16,17 @@ output_sample_types = [ ... ] ''' + + class SampleGeneratorFace(SampleGeneratorBase): - def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs): + def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, + random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), + output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, + **kwargs): super().__init__(samples_path, debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx - self.apply_ct_type = random_ct_type if sort_by_yaw_target_samples_path is not None: self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET @@ -37,17 +40,20 @@ class SampleGeneratorFace(SampleGeneratorBase): self.generators_random_seed = generators_random_seed - samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) + samples = SampleLoader.load(self.sample_type, self.samples_path, sort_by_yaw_target_samples_path) - ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None + ct_samples = SampleLoader.load(SampleType.FACE, + random_ct_samples_path) if random_ct_samples_path is not None else None self.random_ct_sample_chance = 100 if self.debug: self.generators_count = 1 - self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )] + self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))] else: - self.generators_count = min ( generators_count, len(samples) ) - self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ] + self.generators_count = min(generators_count, len(samples)) + self.generators = [ + iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for + i in range(self.generators_count)] self.generator_counter = -1 @@ -56,14 +62,14 @@ class SampleGeneratorFace(SampleGeneratorBase): def __next__(self): self.generator_counter += 1 - generator = self.generators[self.generator_counter % len(self.generators) ] + generator = self.generators[self.generator_counter % len(self.generators)] return next(generator) - def batch_func(self, param ): + def batch_func(self, param): generator_id, samples, ct_samples = param if self.generators_random_seed is not None: - np.random.seed ( self.generators_random_seed[generator_id] ) + np.random.seed(self.generators_random_seed[generator_id]) samples_len = len(samples) samples_idxs = [*range(samples_len)] @@ -74,14 +80,14 @@ class SampleGeneratorFace(SampleGeneratorBase): raise ValueError('No training data provided.') if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: - if all ( [ samples[idx] == None for idx in samples_idxs] ): + if all([samples[idx] == None for idx in samples_idxs]): raise ValueError('Not enough training data. Gather more faces!') if self.sample_type == SampleType.FACE: shuffle_idxs = [] elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: shuffle_idxs = [] - shuffle_idxs_2D = [[]]*samples_len + shuffle_idxs_2D = [[]] * samples_len while True: batches = None @@ -95,7 +101,7 @@ class SampleGeneratorFace(SampleGeneratorBase): np.random.shuffle(shuffle_idxs) idx = shuffle_idxs.pop() - sample = samples[ idx ] + sample = samples[idx] elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: if len(shuffle_idxs) == 0: @@ -105,8 +111,8 @@ class SampleGeneratorFace(SampleGeneratorBase): idx = shuffle_idxs.pop() if samples[idx] != None: if len(shuffle_idxs_2D[idx]) == 0: - a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ] - np.random.shuffle (a) + a = shuffle_idxs_2D[idx] = [*range(len(samples[idx]))] + np.random.shuffle(a) idx2 = shuffle_idxs_2D[idx].pop() sample = samples[idx][idx2] @@ -115,29 +121,31 @@ class SampleGeneratorFace(SampleGeneratorBase): if sample is not None: try: - ct_sample=None + ct_sample = None if ct_samples is not None: if np.random.randint(100) < self.random_ct_sample_chance: - ct_sample=ct_samples[np.random.randint(ct_samples_len)] + ct_sample = ct_samples[np.random.randint(ct_samples_len)] - x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type) + x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types, + self.debug, ct_sample=ct_sample) except: - raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + raise Exception( + "Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc())) if type(x) != tuple and type(x) != list: raise Exception('SampleProcessor.process returns NOT tuple/list') if batches is None: - batches = [ [] for _ in range(len(x)) ] + batches = [[] for _ in range(len(x))] if self.add_sample_idx: - batches += [ [] ] - i_sample_idx = len(batches)-1 + batches += [[]] + i_sample_idx = len(batches) - 1 for i in range(len(x)): - batches[i].append ( x[i] ) + batches[i].append(x[i]) if self.add_sample_idx: - batches[i_sample_idx].append (idx) + batches[i_sample_idx].append(idx) break - yield [ np.array(batch) for batch in batches] + yield [np.array(batch) for batch in batches] diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index 42e5964..dab8bc3 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -42,6 +42,13 @@ opts: """ +class ColorTransferMode(IntEnum): + NONE = 0 + LCT = 1 + RCT = 2 + RCT_MASKED = 3 + + class SampleProcessor(object): class Types(IntEnum): NONE = 0 @@ -82,7 +89,7 @@ class SampleProcessor(object): self.ty_range = ty_range @staticmethod - def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0): + def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None): SPTF = SampleProcessor.Types sample_bgr = sample.load_bgr() @@ -120,8 +127,7 @@ class SampleProcessor(object): normalize_std_dev = opts.get('normalize_std_dev', False) normalize_vgg = opts.get('normalize_vgg', False) motion_blur = opts.get('motion_blur', None) - apply_ct = opts.get('apply_ct', False) - apply_ct_type = opts.get('apply_ct_type', 0) + apply_ct = opts.get('apply_ct', ColorTransferMode.NONE) normalize_tanh = opts.get('normalize_tanh', False) img_type = SPTF.NONE @@ -225,16 +231,15 @@ class SampleProcessor(object): if ct_sample_bgr is None: ct_sample_bgr = ct_sample.load_bgr() - # TODO enum, apply_ct_type - if apply_ct_type == 3 and ct_sample_mask is None: + if apply_ct == ColorTransferMode.LCT: + img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr) + + elif apply_ct == ColorTransferMode.RCT: + img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True) + + elif apply_ct == ColorTransferMode.RCT_MASKED and ct_sample_mask is None: ct_sample_mask = ct_sample.load_fanseg_mask() or \ LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks) - - if apply_ct_type == 1: - img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr) - if apply_ct_type == 2: - img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True) - if apply_ct_type == 3: img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True,