Color Transfer mode option

This commit is contained in:
Jeremy Hummel 2019-08-12 22:36:28 -07:00
commit 90ae878b78
3 changed files with 59 additions and 44 deletions

View file

@ -7,8 +7,11 @@ from facelib import FaceType
from samplelib import *
from interact import interact as io
from samplelib.SampleProcessor import ColorTransferMode
# SAE - Styled AutoEncoder
class SAEModel(ModelBase):
encoderH5 = 'encoder.h5'
inter_BH5 = 'inter_B.h5'
@ -119,8 +122,8 @@ class SAEModel(ModelBase):
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
0.0, 100.0)
default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0)
self.options['apply_random_ct'] = io.input_bool(
default_apply_random_ct = ColorTransferMode.NONE if is_first_run else self.options.get('apply_random_ct', ColorTransferMode.NONE)
self.options['apply_random_ct'] = io.input_int(
"Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (
yn_str[default_apply_random_ct]), default_apply_random_ct,
help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
@ -140,7 +143,7 @@ class SAEModel(ModelBase):
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0)
self.options['apply_random_ct'] = self.options.get('apply_random_ct', ColorTransferMode.NONE)
self.options['clipgrad'] = self.options.get('clipgrad', False)
if is_first_run:
@ -169,7 +172,7 @@ class SAEModel(ModelBase):
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
apply_random_ct = self.options.get('apply_random_ct', 0)
apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
masked_training = True
warped_src = Input(bgr_shape)
@ -456,8 +459,7 @@ class SAEModel(ModelBase):
self.set_training_data_generators([
SampleGeneratorFace(training_data_src_path,
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
random_ct_type=apply_random_ct,
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
scale_range=np.array([-0.05,
@ -590,7 +592,7 @@ class SAEModel(ModelBase):
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
from converters import ConverterMasked
# TODO apply_random_ct
return ConverterMasked(self.predictor_func,
predictor_input_size=self.options['resolution'],
predictor_masked=self.options['learn_mask'],

View file

@ -9,7 +9,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
SampleType)
from utils import iter_utils
'''
arg
output_sample_types = [
@ -17,13 +16,17 @@ output_sample_types = [
...
]
'''
class SampleGeneratorFace(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
**kwargs):
super().__init__(samples_path, debug, batch_size)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
self.apply_ct_type = random_ct_type
if sort_by_yaw_target_samples_path is not None:
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@ -37,17 +40,20 @@ class SampleGeneratorFace(SampleGeneratorBase):
self.generators_random_seed = generators_random_seed
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
samples = SampleLoader.load(self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
ct_samples = SampleLoader.load(SampleType.FACE,
random_ct_samples_path) if random_ct_samples_path is not None else None
self.random_ct_sample_chance = 100
if self.debug:
self.generators_count = 1
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))]
else:
self.generators_count = min ( generators_count, len(samples) )
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
self.generators_count = min(generators_count, len(samples))
self.generators = [
iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for
i in range(self.generators_count)]
self.generator_counter = -1
@ -56,14 +62,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
generator = self.generators[self.generator_counter % len(self.generators)]
return next(generator)
def batch_func(self, param ):
def batch_func(self, param):
generator_id, samples, ct_samples = param
if self.generators_random_seed is not None:
np.random.seed ( self.generators_random_seed[generator_id] )
np.random.seed(self.generators_random_seed[generator_id])
samples_len = len(samples)
samples_idxs = [*range(samples_len)]
@ -74,14 +80,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
raise ValueError('No training data provided.')
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
if all ( [ samples[idx] == None for idx in samples_idxs] ):
if all([samples[idx] == None for idx in samples_idxs]):
raise ValueError('Not enough training data. Gather more faces!')
if self.sample_type == SampleType.FACE:
shuffle_idxs = []
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
shuffle_idxs = []
shuffle_idxs_2D = [[]]*samples_len
shuffle_idxs_2D = [[]] * samples_len
while True:
batches = None
@ -95,7 +101,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
np.random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop()
sample = samples[ idx ]
sample = samples[idx]
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
if len(shuffle_idxs) == 0:
@ -105,8 +111,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
idx = shuffle_idxs.pop()
if samples[idx] != None:
if len(shuffle_idxs_2D[idx]) == 0:
a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ]
np.random.shuffle (a)
a = shuffle_idxs_2D[idx] = [*range(len(samples[idx]))]
np.random.shuffle(a)
idx2 = shuffle_idxs_2D[idx].pop()
sample = samples[idx][idx2]
@ -115,29 +121,31 @@ class SampleGeneratorFace(SampleGeneratorBase):
if sample is not None:
try:
ct_sample=None
ct_sample = None
if ct_samples is not None:
if np.random.randint(100) < self.random_ct_sample_chance:
ct_sample=ct_samples[np.random.randint(ct_samples_len)]
ct_sample = ct_samples[np.random.randint(ct_samples_len)]
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type)
x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types,
self.debug, ct_sample=ct_sample)
except:
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
raise Exception(
"Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc()))
if type(x) != tuple and type(x) != list:
raise Exception('SampleProcessor.process returns NOT tuple/list')
if batches is None:
batches = [ [] for _ in range(len(x)) ]
batches = [[] for _ in range(len(x))]
if self.add_sample_idx:
batches += [ [] ]
i_sample_idx = len(batches)-1
batches += [[]]
i_sample_idx = len(batches) - 1
for i in range(len(x)):
batches[i].append ( x[i] )
batches[i].append(x[i])
if self.add_sample_idx:
batches[i_sample_idx].append (idx)
batches[i_sample_idx].append(idx)
break
yield [ np.array(batch) for batch in batches]
yield [np.array(batch) for batch in batches]

View file

@ -42,6 +42,13 @@ opts:
"""
class ColorTransferMode(IntEnum):
NONE = 0
LCT = 1
RCT = 2
RCT_MASKED = 3
class SampleProcessor(object):
class Types(IntEnum):
NONE = 0
@ -82,7 +89,7 @@ class SampleProcessor(object):
self.ty_range = ty_range
@staticmethod
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0):
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None):
SPTF = SampleProcessor.Types
sample_bgr = sample.load_bgr()
@ -120,8 +127,7 @@ class SampleProcessor(object):
normalize_std_dev = opts.get('normalize_std_dev', False)
normalize_vgg = opts.get('normalize_vgg', False)
motion_blur = opts.get('motion_blur', None)
apply_ct = opts.get('apply_ct', False)
apply_ct_type = opts.get('apply_ct_type', 0)
apply_ct = opts.get('apply_ct', ColorTransferMode.NONE)
normalize_tanh = opts.get('normalize_tanh', False)
img_type = SPTF.NONE
@ -225,16 +231,15 @@ class SampleProcessor(object):
if ct_sample_bgr is None:
ct_sample_bgr = ct_sample.load_bgr()
# TODO enum, apply_ct_type
if apply_ct_type == 3 and ct_sample_mask is None:
if apply_ct == ColorTransferMode.LCT:
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
elif apply_ct == ColorTransferMode.RCT:
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
elif apply_ct == ColorTransferMode.RCT_MASKED and ct_sample_mask is None:
ct_sample_mask = ct_sample.load_fanseg_mask() or \
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
if apply_ct_type == 1:
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
if apply_ct_type == 2:
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
if apply_ct_type == 3:
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
ct_sample_bgr,
clip=True,