From 1cd9cf70e649ee7be8fbfec3648b155f6ea2866d Mon Sep 17 00:00:00 2001 From: Jeremy Hummel Date: Mon, 12 Aug 2019 00:30:55 -0700 Subject: [PATCH] Select color transfer type --- models/Model_SAE/Model.py | 15 ++++++++++----- samplelib/SampleGeneratorFace.py | 11 ++++++----- samplelib/SampleProcessor.py | 26 ++++++++++++++++---------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 3440660..2563099 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -119,11 +119,14 @@ class SAEModel(ModelBase): help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0) - default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False) + default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0) self.options['apply_random_ct'] = io.input_bool( - "Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % ( + "Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % ( yn_str[default_apply_random_ct]), default_apply_random_ct, - help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.") + help_message="Increase variativity of src samples by apply LCT color transfer from random dst " + "samples. It is like 'face_style' learning, but more precise color transfer and without " + "risk of model collapse, also it does not require additional GPU resources, " + "but the training time may be longer, due to the src faceset is becoming more diverse.") if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301 default_clipgrad = False if is_first_run else self.options.get('clipgrad', False) @@ -137,7 +140,7 @@ class SAEModel(ModelBase): self.options['pixel_loss'] = self.options.get('pixel_loss', False) self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power) self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power) - self.options['apply_random_ct'] = self.options.get('apply_random_ct', False) + self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0) self.options['clipgrad'] = self.options.get('clipgrad', False) if is_first_run: @@ -166,7 +169,7 @@ class SAEModel(ModelBase): self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1 - apply_random_ct = self.options.get('apply_random_ct', False) + apply_random_ct = self.options.get('apply_random_ct', 0) masked_training = True warped_src = Input(bgr_shape) @@ -454,6 +457,7 @@ class SAEModel(ModelBase): SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None, random_ct_samples_path=training_data_dst_path if apply_random_ct else None, + random_ct_type=apply_random_ct, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, @@ -586,6 +590,7 @@ class SAEModel(ModelBase): face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF from converters import ConverterMasked + # TODO apply_random_ct return ConverterMasked(self.predictor_func, predictor_input_size=self.options['resolution'], predictor_masked=self.options['learn_mask'], diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 593f8e2..dbaaa2d 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -18,11 +18,12 @@ output_sample_types = [ ] ''' class SampleGeneratorFace(SampleGeneratorBase): - def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs): + def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs): super().__init__(samples_path, debug, batch_size) self.sample_process_options = sample_process_options self.output_sample_types = output_sample_types self.add_sample_idx = add_sample_idx + self.apply_ct_type = random_ct_type if sort_by_yaw_target_samples_path is not None: self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET @@ -114,12 +115,12 @@ class SampleGeneratorFace(SampleGeneratorBase): if sample is not None: try: - ct_sample=None - if ct_samples is not None: + ct_sample=None + if ct_samples is not None: if np.random.randint(100) < self.random_ct_sample_chance: ct_sample=ct_samples[np.random.randint(ct_samples_len)] - - x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) + + x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type) except: raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index 847299b..42e5964 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -82,7 +82,7 @@ class SampleProcessor(object): self.ty_range = ty_range @staticmethod - def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None): + def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0): SPTF = SampleProcessor.Types sample_bgr = sample.load_bgr() @@ -121,6 +121,7 @@ class SampleProcessor(object): normalize_vgg = opts.get('normalize_vgg', False) motion_blur = opts.get('motion_blur', None) apply_ct = opts.get('apply_ct', False) + apply_ct_type = opts.get('apply_ct_type', 0) normalize_tanh = opts.get('normalize_tanh', False) img_type = SPTF.NONE @@ -223,18 +224,23 @@ class SampleProcessor(object): if apply_ct and ct_sample is not None: if ct_sample_bgr is None: ct_sample_bgr = ct_sample.load_bgr() - if ct_sample_mask is None: + + # TODO enum, apply_ct_type + if apply_ct_type == 3 and ct_sample_mask is None: ct_sample_mask = ct_sample.load_fanseg_mask() or \ LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks) - img_bgr = imagelib.reinhard_color_transfer(img_bgr, - ct_sample_bgr, - clip=True, - target_mask=img_mask, - source_mask=ct_sample_mask) - # img_bgr = imagelib.reinhard_color_transfer(img_bgr, - # ct_sample_bgr, - # clip=True) + if apply_ct_type == 1: + img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr) + if apply_ct_type == 2: + img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True) + if apply_ct_type == 3: + img_bgr = imagelib.reinhard_color_transfer(img_bgr, + ct_sample_bgr, + clip=True, + target_mask=img_mask, + source_mask=ct_sample_mask) + if normalize_std_dev: img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))