Select color transfer type

This commit is contained in:
Jeremy Hummel 2019-08-12 00:30:55 -07:00
commit 1cd9cf70e6
3 changed files with 32 additions and 20 deletions

View file

@ -119,11 +119,14 @@ class SAEModel(ModelBase):
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
0.0, 100.0)
default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0)
self.options['apply_random_ct'] = io.input_bool(
"Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (
"Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (
yn_str[default_apply_random_ct]), default_apply_random_ct,
help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
"samples. It is like 'face_style' learning, but more precise color transfer and without "
"risk of model collapse, also it does not require additional GPU resources, "
"but the training time may be longer, due to the src faceset is becoming more diverse.")
if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
@ -137,7 +140,7 @@ class SAEModel(ModelBase):
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0)
self.options['clipgrad'] = self.options.get('clipgrad', False)
if is_first_run:
@ -166,7 +169,7 @@ class SAEModel(ModelBase):
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
apply_random_ct = self.options.get('apply_random_ct', False)
apply_random_ct = self.options.get('apply_random_ct', 0)
masked_training = True
warped_src = Input(bgr_shape)
@ -454,6 +457,7 @@ class SAEModel(ModelBase):
SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
random_ct_type=apply_random_ct,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
@ -586,6 +590,7 @@ class SAEModel(ModelBase):
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
from converters import ConverterMasked
# TODO apply_random_ct
return ConverterMasked(self.predictor_func,
predictor_input_size=self.options['resolution'],
predictor_masked=self.options['learn_mask'],

View file

@ -18,11 +18,12 @@ output_sample_types = [
]
'''
class SampleGeneratorFace(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
super().__init__(samples_path, debug, batch_size)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
self.apply_ct_type = random_ct_type
if sort_by_yaw_target_samples_path is not None:
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@ -119,7 +120,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
if np.random.randint(100) < self.random_ct_sample_chance:
ct_sample=ct_samples[np.random.randint(ct_samples_len)]
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type)
except:
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )

View file

@ -82,7 +82,7 @@ class SampleProcessor(object):
self.ty_range = ty_range
@staticmethod
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None):
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0):
SPTF = SampleProcessor.Types
sample_bgr = sample.load_bgr()
@ -121,6 +121,7 @@ class SampleProcessor(object):
normalize_vgg = opts.get('normalize_vgg', False)
motion_blur = opts.get('motion_blur', None)
apply_ct = opts.get('apply_ct', False)
apply_ct_type = opts.get('apply_ct_type', 0)
normalize_tanh = opts.get('normalize_tanh', False)
img_type = SPTF.NONE
@ -223,18 +224,23 @@ class SampleProcessor(object):
if apply_ct and ct_sample is not None:
if ct_sample_bgr is None:
ct_sample_bgr = ct_sample.load_bgr()
if ct_sample_mask is None:
# TODO enum, apply_ct_type
if apply_ct_type == 3 and ct_sample_mask is None:
ct_sample_mask = ct_sample.load_fanseg_mask() or \
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
if apply_ct_type == 1:
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
if apply_ct_type == 2:
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
if apply_ct_type == 3:
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
ct_sample_bgr,
clip=True,
target_mask=img_mask,
source_mask=ct_sample_mask)
# img_bgr = imagelib.reinhard_color_transfer(img_bgr,
# ct_sample_bgr,
# clip=True)
if normalize_std_dev:
img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))