mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Color Transfer mode option
This commit is contained in:
parent
1cd9cf70e6
commit
90ae878b78
3 changed files with 59 additions and 44 deletions
|
@ -7,8 +7,11 @@ from facelib import FaceType
|
|||
from samplelib import *
|
||||
from interact import interact as io
|
||||
|
||||
from samplelib.SampleProcessor import ColorTransferMode
|
||||
|
||||
# SAE - Styled AutoEncoder
|
||||
|
||||
|
||||
class SAEModel(ModelBase):
|
||||
encoderH5 = 'encoder.h5'
|
||||
inter_BH5 = 'inter_B.h5'
|
||||
|
@ -119,8 +122,8 @@ class SAEModel(ModelBase):
|
|||
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
|
||||
0.0, 100.0)
|
||||
|
||||
default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0)
|
||||
self.options['apply_random_ct'] = io.input_bool(
|
||||
default_apply_random_ct = ColorTransferMode.NONE if is_first_run else self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||
self.options['apply_random_ct'] = io.input_int(
|
||||
"Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (
|
||||
yn_str[default_apply_random_ct]), default_apply_random_ct,
|
||||
help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
|
||||
|
@ -140,7 +143,7 @@ class SAEModel(ModelBase):
|
|||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
||||
self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0)
|
||||
self.options['apply_random_ct'] = self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||
self.options['clipgrad'] = self.options.get('clipgrad', False)
|
||||
|
||||
if is_first_run:
|
||||
|
@ -169,7 +172,7 @@ class SAEModel(ModelBase):
|
|||
|
||||
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
|
||||
|
||||
apply_random_ct = self.options.get('apply_random_ct', 0)
|
||||
apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||
masked_training = True
|
||||
|
||||
warped_src = Input(bgr_shape)
|
||||
|
@ -456,8 +459,7 @@ class SAEModel(ModelBase):
|
|||
self.set_training_data_generators([
|
||||
SampleGeneratorFace(training_data_src_path,
|
||||
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
|
||||
random_ct_type=apply_random_ct,
|
||||
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||
scale_range=np.array([-0.05,
|
||||
|
@ -590,7 +592,7 @@ class SAEModel(ModelBase):
|
|||
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
||||
|
||||
from converters import ConverterMasked
|
||||
# TODO apply_random_ct
|
||||
|
||||
return ConverterMasked(self.predictor_func,
|
||||
predictor_input_size=self.options['resolution'],
|
||||
predictor_masked=self.options['learn_mask'],
|
||||
|
|
|
@ -9,7 +9,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
|||
SampleType)
|
||||
from utils import iter_utils
|
||||
|
||||
|
||||
'''
|
||||
arg
|
||||
output_sample_types = [
|
||||
|
@ -17,13 +16,17 @@ output_sample_types = [
|
|||
...
|
||||
]
|
||||
'''
|
||||
|
||||
|
||||
class SampleGeneratorFace(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
|
||||
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
|
||||
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
|
||||
**kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
self.apply_ct_type = random_ct_type
|
||||
|
||||
if sort_by_yaw_target_samples_path is not None:
|
||||
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
||||
|
@ -37,17 +40,20 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
self.generators_random_seed = generators_random_seed
|
||||
|
||||
samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||
samples = SampleLoader.load(self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||
|
||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||
ct_samples = SampleLoader.load(SampleType.FACE,
|
||||
random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||
self.random_ct_sample_chance = 100
|
||||
|
||||
if self.debug:
|
||||
self.generators_count = 1
|
||||
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
|
||||
self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))]
|
||||
else:
|
||||
self.generators_count = min ( generators_count, len(samples) )
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
|
||||
self.generators_count = min(generators_count, len(samples))
|
||||
self.generators = [
|
||||
iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for
|
||||
i in range(self.generators_count)]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -56,14 +62,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
def __next__(self):
|
||||
self.generator_counter += 1
|
||||
generator = self.generators[self.generator_counter % len(self.generators) ]
|
||||
generator = self.generators[self.generator_counter % len(self.generators)]
|
||||
return next(generator)
|
||||
|
||||
def batch_func(self, param ):
|
||||
def batch_func(self, param):
|
||||
generator_id, samples, ct_samples = param
|
||||
|
||||
if self.generators_random_seed is not None:
|
||||
np.random.seed ( self.generators_random_seed[generator_id] )
|
||||
np.random.seed(self.generators_random_seed[generator_id])
|
||||
|
||||
samples_len = len(samples)
|
||||
samples_idxs = [*range(samples_len)]
|
||||
|
@ -74,14 +80,14 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
raise ValueError('No training data provided.')
|
||||
|
||||
if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
if all ( [ samples[idx] == None for idx in samples_idxs] ):
|
||||
if all([samples[idx] == None for idx in samples_idxs]):
|
||||
raise ValueError('Not enough training data. Gather more faces!')
|
||||
|
||||
if self.sample_type == SampleType.FACE:
|
||||
shuffle_idxs = []
|
||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
shuffle_idxs = []
|
||||
shuffle_idxs_2D = [[]]*samples_len
|
||||
shuffle_idxs_2D = [[]] * samples_len
|
||||
|
||||
while True:
|
||||
batches = None
|
||||
|
@ -95,7 +101,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
np.random.shuffle(shuffle_idxs)
|
||||
|
||||
idx = shuffle_idxs.pop()
|
||||
sample = samples[ idx ]
|
||||
sample = samples[idx]
|
||||
|
||||
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
|
||||
if len(shuffle_idxs) == 0:
|
||||
|
@ -105,8 +111,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
idx = shuffle_idxs.pop()
|
||||
if samples[idx] != None:
|
||||
if len(shuffle_idxs_2D[idx]) == 0:
|
||||
a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ]
|
||||
np.random.shuffle (a)
|
||||
a = shuffle_idxs_2D[idx] = [*range(len(samples[idx]))]
|
||||
np.random.shuffle(a)
|
||||
|
||||
idx2 = shuffle_idxs_2D[idx].pop()
|
||||
sample = samples[idx][idx2]
|
||||
|
@ -115,29 +121,31 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
if sample is not None:
|
||||
try:
|
||||
ct_sample=None
|
||||
ct_sample = None
|
||||
if ct_samples is not None:
|
||||
if np.random.randint(100) < self.random_ct_sample_chance:
|
||||
ct_sample=ct_samples[np.random.randint(ct_samples_len)]
|
||||
ct_sample = ct_samples[np.random.randint(ct_samples_len)]
|
||||
|
||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type)
|
||||
x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types,
|
||||
self.debug, ct_sample=ct_sample)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
raise Exception(
|
||||
"Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc()))
|
||||
|
||||
if type(x) != tuple and type(x) != list:
|
||||
raise Exception('SampleProcessor.process returns NOT tuple/list')
|
||||
|
||||
if batches is None:
|
||||
batches = [ [] for _ in range(len(x)) ]
|
||||
batches = [[] for _ in range(len(x))]
|
||||
if self.add_sample_idx:
|
||||
batches += [ [] ]
|
||||
i_sample_idx = len(batches)-1
|
||||
batches += [[]]
|
||||
i_sample_idx = len(batches) - 1
|
||||
|
||||
for i in range(len(x)):
|
||||
batches[i].append ( x[i] )
|
||||
batches[i].append(x[i])
|
||||
|
||||
if self.add_sample_idx:
|
||||
batches[i_sample_idx].append (idx)
|
||||
batches[i_sample_idx].append(idx)
|
||||
|
||||
break
|
||||
yield [ np.array(batch) for batch in batches]
|
||||
yield [np.array(batch) for batch in batches]
|
||||
|
|
|
@ -42,6 +42,13 @@ opts:
|
|||
"""
|
||||
|
||||
|
||||
class ColorTransferMode(IntEnum):
|
||||
NONE = 0
|
||||
LCT = 1
|
||||
RCT = 2
|
||||
RCT_MASKED = 3
|
||||
|
||||
|
||||
class SampleProcessor(object):
|
||||
class Types(IntEnum):
|
||||
NONE = 0
|
||||
|
@ -82,7 +89,7 @@ class SampleProcessor(object):
|
|||
self.ty_range = ty_range
|
||||
|
||||
@staticmethod
|
||||
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0):
|
||||
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None):
|
||||
SPTF = SampleProcessor.Types
|
||||
|
||||
sample_bgr = sample.load_bgr()
|
||||
|
@ -120,8 +127,7 @@ class SampleProcessor(object):
|
|||
normalize_std_dev = opts.get('normalize_std_dev', False)
|
||||
normalize_vgg = opts.get('normalize_vgg', False)
|
||||
motion_blur = opts.get('motion_blur', None)
|
||||
apply_ct = opts.get('apply_ct', False)
|
||||
apply_ct_type = opts.get('apply_ct_type', 0)
|
||||
apply_ct = opts.get('apply_ct', ColorTransferMode.NONE)
|
||||
normalize_tanh = opts.get('normalize_tanh', False)
|
||||
|
||||
img_type = SPTF.NONE
|
||||
|
@ -225,16 +231,15 @@ class SampleProcessor(object):
|
|||
if ct_sample_bgr is None:
|
||||
ct_sample_bgr = ct_sample.load_bgr()
|
||||
|
||||
# TODO enum, apply_ct_type
|
||||
if apply_ct_type == 3 and ct_sample_mask is None:
|
||||
if apply_ct == ColorTransferMode.LCT:
|
||||
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
|
||||
|
||||
elif apply_ct == ColorTransferMode.RCT:
|
||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
||||
|
||||
elif apply_ct == ColorTransferMode.RCT_MASKED and ct_sample_mask is None:
|
||||
ct_sample_mask = ct_sample.load_fanseg_mask() or \
|
||||
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
|
||||
|
||||
if apply_ct_type == 1:
|
||||
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
|
||||
if apply_ct_type == 2:
|
||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
||||
if apply_ct_type == 3:
|
||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
|
||||
ct_sample_bgr,
|
||||
clip=True,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue