mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Color Transfer mode option
This commit is contained in:
parent
1cd9cf70e6
commit
90ae878b78
3 changed files with 59 additions and 44 deletions
|
@ -7,8 +7,11 @@ from facelib import FaceType
|
||||||
from samplelib import *
|
from samplelib import *
|
||||||
from interact import interact as io
|
from interact import interact as io
|
||||||
|
|
||||||
|
from samplelib.SampleProcessor import ColorTransferMode
|
||||||
|
|
||||||
# SAE - Styled AutoEncoder
|
# SAE - Styled AutoEncoder
|
||||||
|
|
||||||
|
|
||||||
class SAEModel(ModelBase):
|
class SAEModel(ModelBase):
|
||||||
encoderH5 = 'encoder.h5'
|
encoderH5 = 'encoder.h5'
|
||||||
inter_BH5 = 'inter_B.h5'
|
inter_BH5 = 'inter_B.h5'
|
||||||
|
@ -119,8 +122,8 @@ class SAEModel(ModelBase):
|
||||||
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
|
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
|
||||||
0.0, 100.0)
|
0.0, 100.0)
|
||||||
|
|
||||||
default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0)
|
default_apply_random_ct = ColorTransferMode.NONE if is_first_run else self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||||
self.options['apply_random_ct'] = io.input_bool(
|
self.options['apply_random_ct'] = io.input_int(
|
||||||
"Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (
|
"Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (
|
||||||
yn_str[default_apply_random_ct]), default_apply_random_ct,
|
yn_str[default_apply_random_ct]), default_apply_random_ct,
|
||||||
help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
|
help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
|
||||||
|
@ -140,7 +143,7 @@ class SAEModel(ModelBase):
|
||||||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||||
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
||||||
self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0)
|
self.options['apply_random_ct'] = self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||||
self.options['clipgrad'] = self.options.get('clipgrad', False)
|
self.options['clipgrad'] = self.options.get('clipgrad', False)
|
||||||
|
|
||||||
if is_first_run:
|
if is_first_run:
|
||||||
|
@ -169,7 +172,7 @@ class SAEModel(ModelBase):
|
||||||
|
|
||||||
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
|
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
|
||||||
|
|
||||||
apply_random_ct = self.options.get('apply_random_ct', 0)
|
apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||||
masked_training = True
|
masked_training = True
|
||||||
|
|
||||||
warped_src = Input(bgr_shape)
|
warped_src = Input(bgr_shape)
|
||||||
|
@ -456,8 +459,7 @@ class SAEModel(ModelBase):
|
||||||
self.set_training_data_generators([
|
self.set_training_data_generators([
|
||||||
SampleGeneratorFace(training_data_src_path,
|
SampleGeneratorFace(training_data_src_path,
|
||||||
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||||
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
|
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
||||||
random_ct_type=apply_random_ct,
|
|
||||||
debug=self.is_debug(), batch_size=self.batch_size,
|
debug=self.is_debug(), batch_size=self.batch_size,
|
||||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||||
scale_range=np.array([-0.05,
|
scale_range=np.array([-0.05,
|
||||||
|
@ -590,7 +592,7 @@ class SAEModel(ModelBase):
|
||||||
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
||||||
|
|
||||||
from converters import ConverterMasked
|
from converters import ConverterMasked
|
||||||
# TODO apply_random_ct
|
|
||||||
return ConverterMasked(self.predictor_func,
|
return ConverterMasked(self.predictor_func,
|
||||||
predictor_input_size=self.options['resolution'],
|
predictor_input_size=self.options['resolution'],
|
||||||
predictor_masked=self.options['learn_mask'],
|
predictor_masked=self.options['learn_mask'],
|
||||||
|
|
|
@ -9,7 +9,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
||||||
SampleType)
|
SampleType)
|
||||||
from utils import iter_utils
|
from utils import iter_utils
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
arg
|
arg
|
||||||
output_sample_types = [
|
output_sample_types = [
|
||||||
|
@ -17,13 +16,17 @@ output_sample_types = [
|
||||||
...
|
...
|
||||||
]
|
]
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
class SampleGeneratorFace(SampleGeneratorBase):
|
class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
|
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
|
||||||
|
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
|
||||||
|
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
|
||||||
|
**kwargs):
|
||||||
super().__init__(samples_path, debug, batch_size)
|
super().__init__(samples_path, debug, batch_size)
|
||||||
self.sample_process_options = sample_process_options
|
self.sample_process_options = sample_process_options
|
||||||
self.output_sample_types = output_sample_types
|
self.output_sample_types = output_sample_types
|
||||||
self.add_sample_idx = add_sample_idx
|
self.add_sample_idx = add_sample_idx
|
||||||
self.apply_ct_type = random_ct_type
|
|
||||||
|
|
||||||
if sort_by_yaw_target_samples_path is not None:
|
if sort_by_yaw_target_samples_path is not None:
|
||||||
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
||||||
|
@ -37,17 +40,20 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
|
|
||||||
self.generators_random_seed = generators_random_seed
|
self.generators_random_seed = generators_random_seed
|
||||||
|
|
||||||
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
samples = SampleLoader.load(self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||||
|
|
||||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
|
ct_samples = SampleLoader.load(SampleType.FACE,
|
||||||
|
random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||||
self.random_ct_sample_chance = 100
|
self.random_ct_sample_chance = 100
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators_count = 1
|
||||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
|
self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))]
|
||||||
else:
|
else:
|
||||||
self.generators_count = min ( generators_count, len(samples) )
|
self.generators_count = min(generators_count, len(samples))
|
||||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
|
self.generators = [
|
||||||
|
iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for
|
||||||
|
i in range(self.generators_count)]
|
||||||
|
|
||||||
self.generator_counter = -1
|
self.generator_counter = -1
|
||||||
|
|
||||||
|
@ -56,14 +62,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
self.generator_counter += 1
|
self.generator_counter += 1
|
||||||
generator = self.generators[self.generator_counter % len(self.generators) ]
|
generator = self.generators[self.generator_counter % len(self.generators)]
|
||||||
return next(generator)
|
return next(generator)
|
||||||
|
|
||||||
def batch_func(self, param ):
|
def batch_func(self, param):
|
||||||
generator_id, samples, ct_samples = param
|
generator_id, samples, ct_samples = param
|
||||||
|
|
||||||
if self.generators_random_seed is not None:
|
if self.generators_random_seed is not None:
|
||||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
np.random.seed(self.generators_random_seed[generator_id])
|
||||||
|
|
||||||
samples_len = len(samples)
|
samples_len = len(samples)
|
||||||
samples_idxs = [*range(samples_len)]
|
samples_idxs = [*range(samples_len)]
|
||||||
|
@ -74,14 +80,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
raise ValueError('No training data provided.')
|
raise ValueError('No training data provided.')
|
||||||
|
|
||||||
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||||
if all ( [ samples[idx] == None for idx in samples_idxs] ):
|
if all([samples[idx] == None for idx in samples_idxs]):
|
||||||
raise ValueError('Not enough training data. Gather more faces!')
|
raise ValueError('Not enough training data. Gather more faces!')
|
||||||
|
|
||||||
if self.sample_type == SampleType.FACE:
|
if self.sample_type == SampleType.FACE:
|
||||||
shuffle_idxs = []
|
shuffle_idxs = []
|
||||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||||
shuffle_idxs = []
|
shuffle_idxs = []
|
||||||
shuffle_idxs_2D = [[]]*samples_len
|
shuffle_idxs_2D = [[]] * samples_len
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
batches = None
|
batches = None
|
||||||
|
@ -95,7 +101,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
np.random.shuffle(shuffle_idxs)
|
np.random.shuffle(shuffle_idxs)
|
||||||
|
|
||||||
idx = shuffle_idxs.pop()
|
idx = shuffle_idxs.pop()
|
||||||
sample = samples[ idx ]
|
sample = samples[idx]
|
||||||
|
|
||||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||||
if len(shuffle_idxs) == 0:
|
if len(shuffle_idxs) == 0:
|
||||||
|
@ -105,8 +111,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
idx = shuffle_idxs.pop()
|
idx = shuffle_idxs.pop()
|
||||||
if samples[idx] != None:
|
if samples[idx] != None:
|
||||||
if len(shuffle_idxs_2D[idx]) == 0:
|
if len(shuffle_idxs_2D[idx]) == 0:
|
||||||
a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ]
|
a = shuffle_idxs_2D[idx] = [*range(len(samples[idx]))]
|
||||||
np.random.shuffle (a)
|
np.random.shuffle(a)
|
||||||
|
|
||||||
idx2 = shuffle_idxs_2D[idx].pop()
|
idx2 = shuffle_idxs_2D[idx].pop()
|
||||||
sample = samples[idx][idx2]
|
sample = samples[idx][idx2]
|
||||||
|
@ -115,29 +121,31 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
||||||
|
|
||||||
if sample is not None:
|
if sample is not None:
|
||||||
try:
|
try:
|
||||||
ct_sample=None
|
ct_sample = None
|
||||||
if ct_samples is not None:
|
if ct_samples is not None:
|
||||||
if np.random.randint(100) < self.random_ct_sample_chance:
|
if np.random.randint(100) < self.random_ct_sample_chance:
|
||||||
ct_sample=ct_samples[np.random.randint(ct_samples_len)]
|
ct_sample = ct_samples[np.random.randint(ct_samples_len)]
|
||||||
|
|
||||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type)
|
x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types,
|
||||||
|
self.debug, ct_sample=ct_sample)
|
||||||
except:
|
except:
|
||||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
raise Exception(
|
||||||
|
"Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc()))
|
||||||
|
|
||||||
if type(x) != tuple and type(x) != list:
|
if type(x) != tuple and type(x) != list:
|
||||||
raise Exception('SampleProcessor.process returns NOT tuple/list')
|
raise Exception('SampleProcessor.process returns NOT tuple/list')
|
||||||
|
|
||||||
if batches is None:
|
if batches is None:
|
||||||
batches = [ [] for _ in range(len(x)) ]
|
batches = [[] for _ in range(len(x))]
|
||||||
if self.add_sample_idx:
|
if self.add_sample_idx:
|
||||||
batches += [ [] ]
|
batches += [[]]
|
||||||
i_sample_idx = len(batches)-1
|
i_sample_idx = len(batches) - 1
|
||||||
|
|
||||||
for i in range(len(x)):
|
for i in range(len(x)):
|
||||||
batches[i].append ( x[i] )
|
batches[i].append(x[i])
|
||||||
|
|
||||||
if self.add_sample_idx:
|
if self.add_sample_idx:
|
||||||
batches[i_sample_idx].append (idx)
|
batches[i_sample_idx].append(idx)
|
||||||
|
|
||||||
break
|
break
|
||||||
yield [ np.array(batch) for batch in batches]
|
yield [np.array(batch) for batch in batches]
|
||||||
|
|
|
@ -42,6 +42,13 @@ opts:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ColorTransferMode(IntEnum):
|
||||||
|
NONE = 0
|
||||||
|
LCT = 1
|
||||||
|
RCT = 2
|
||||||
|
RCT_MASKED = 3
|
||||||
|
|
||||||
|
|
||||||
class SampleProcessor(object):
|
class SampleProcessor(object):
|
||||||
class Types(IntEnum):
|
class Types(IntEnum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
|
@ -82,7 +89,7 @@ class SampleProcessor(object):
|
||||||
self.ty_range = ty_range
|
self.ty_range = ty_range
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0):
|
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None):
|
||||||
SPTF = SampleProcessor.Types
|
SPTF = SampleProcessor.Types
|
||||||
|
|
||||||
sample_bgr = sample.load_bgr()
|
sample_bgr = sample.load_bgr()
|
||||||
|
@ -120,8 +127,7 @@ class SampleProcessor(object):
|
||||||
normalize_std_dev = opts.get('normalize_std_dev', False)
|
normalize_std_dev = opts.get('normalize_std_dev', False)
|
||||||
normalize_vgg = opts.get('normalize_vgg', False)
|
normalize_vgg = opts.get('normalize_vgg', False)
|
||||||
motion_blur = opts.get('motion_blur', None)
|
motion_blur = opts.get('motion_blur', None)
|
||||||
apply_ct = opts.get('apply_ct', False)
|
apply_ct = opts.get('apply_ct', ColorTransferMode.NONE)
|
||||||
apply_ct_type = opts.get('apply_ct_type', 0)
|
|
||||||
normalize_tanh = opts.get('normalize_tanh', False)
|
normalize_tanh = opts.get('normalize_tanh', False)
|
||||||
|
|
||||||
img_type = SPTF.NONE
|
img_type = SPTF.NONE
|
||||||
|
@ -225,16 +231,15 @@ class SampleProcessor(object):
|
||||||
if ct_sample_bgr is None:
|
if ct_sample_bgr is None:
|
||||||
ct_sample_bgr = ct_sample.load_bgr()
|
ct_sample_bgr = ct_sample.load_bgr()
|
||||||
|
|
||||||
# TODO enum, apply_ct_type
|
if apply_ct == ColorTransferMode.LCT:
|
||||||
if apply_ct_type == 3 and ct_sample_mask is None:
|
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
|
||||||
|
|
||||||
|
elif apply_ct == ColorTransferMode.RCT:
|
||||||
|
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
||||||
|
|
||||||
|
elif apply_ct == ColorTransferMode.RCT_MASKED and ct_sample_mask is None:
|
||||||
ct_sample_mask = ct_sample.load_fanseg_mask() or \
|
ct_sample_mask = ct_sample.load_fanseg_mask() or \
|
||||||
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
|
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
|
||||||
|
|
||||||
if apply_ct_type == 1:
|
|
||||||
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
|
|
||||||
if apply_ct_type == 2:
|
|
||||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
|
||||||
if apply_ct_type == 3:
|
|
||||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
|
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
|
||||||
ct_sample_bgr,
|
ct_sample_bgr,
|
||||||
clip=True,
|
clip=True,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue