mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
Color Transfer mode option
This commit is contained in:
parent
1cd9cf70e6
commit
90ae878b78
3 changed files with 59 additions and 44 deletions
|
@ -7,8 +7,11 @@ from facelib import FaceType
|
|||
from samplelib import *
|
||||
from interact import interact as io
|
||||
|
||||
from samplelib.SampleProcessor import ColorTransferMode
|
||||
|
||||
# SAE - Styled AutoEncoder
|
||||
|
||||
|
||||
class SAEModel(ModelBase):
|
||||
encoderH5 = 'encoder.h5'
|
||||
inter_BH5 = 'inter_B.h5'
|
||||
|
@ -119,8 +122,8 @@ class SAEModel(ModelBase):
|
|||
help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
|
||||
0.0, 100.0)
|
||||
|
||||
default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0)
|
||||
self.options['apply_random_ct'] = io.input_bool(
|
||||
default_apply_random_ct = ColorTransferMode.NONE if is_first_run else self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||
self.options['apply_random_ct'] = io.input_int(
|
||||
"Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (
|
||||
yn_str[default_apply_random_ct]), default_apply_random_ct,
|
||||
help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
|
||||
|
@ -140,7 +143,7 @@ class SAEModel(ModelBase):
|
|||
self.options['pixel_loss'] = self.options.get('pixel_loss', False)
|
||||
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
|
||||
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
|
||||
self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0)
|
||||
self.options['apply_random_ct'] = self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||
self.options['clipgrad'] = self.options.get('clipgrad', False)
|
||||
|
||||
if is_first_run:
|
||||
|
@ -169,7 +172,7 @@ class SAEModel(ModelBase):
|
|||
|
||||
self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
|
||||
|
||||
apply_random_ct = self.options.get('apply_random_ct', 0)
|
||||
apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
|
||||
masked_training = True
|
||||
|
||||
warped_src = Input(bgr_shape)
|
||||
|
@ -456,8 +459,7 @@ class SAEModel(ModelBase):
|
|||
self.set_training_data_generators([
|
||||
SampleGeneratorFace(training_data_src_path,
|
||||
sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
|
||||
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
|
||||
random_ct_type=apply_random_ct,
|
||||
random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
|
||||
debug=self.is_debug(), batch_size=self.batch_size,
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
|
||||
scale_range=np.array([-0.05,
|
||||
|
@ -590,7 +592,7 @@ class SAEModel(ModelBase):
|
|||
face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
|
||||
|
||||
from converters import ConverterMasked
|
||||
# TODO apply_random_ct
|
||||
|
||||
return ConverterMasked(self.predictor_func,
|
||||
predictor_input_size=self.options['resolution'],
|
||||
predictor_masked=self.options['learn_mask'],
|
||||
|
|
|
@ -9,7 +9,6 @@ from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
|
|||
SampleType)
|
||||
from utils import iter_utils
|
||||
|
||||
|
||||
'''
|
||||
arg
|
||||
output_sample_types = [
|
||||
|
@ -17,13 +16,17 @@ output_sample_types = [
|
|||
...
|
||||
]
|
||||
'''
|
||||
|
||||
|
||||
class SampleGeneratorFace(SampleGeneratorBase):
|
||||
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
|
||||
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
|
||||
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
|
||||
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
|
||||
**kwargs):
|
||||
super().__init__(samples_path, debug, batch_size)
|
||||
self.sample_process_options = sample_process_options
|
||||
self.output_sample_types = output_sample_types
|
||||
self.add_sample_idx = add_sample_idx
|
||||
self.apply_ct_type = random_ct_type
|
||||
|
||||
if sort_by_yaw_target_samples_path is not None:
|
||||
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
|
||||
|
@ -39,7 +42,8 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
|
||||
samples = SampleLoader.load(self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)
|
||||
|
||||
ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||
ct_samples = SampleLoader.load(SampleType.FACE,
|
||||
random_ct_samples_path) if random_ct_samples_path is not None else None
|
||||
self.random_ct_sample_chance = 100
|
||||
|
||||
if self.debug:
|
||||
|
@ -47,7 +51,9 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))]
|
||||
else:
|
||||
self.generators_count = min(generators_count, len(samples))
|
||||
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]
|
||||
self.generators = [
|
||||
iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for
|
||||
i in range(self.generators_count)]
|
||||
|
||||
self.generator_counter = -1
|
||||
|
||||
|
@ -120,9 +126,11 @@ class SampleGeneratorFace(SampleGeneratorBase):
|
|||
if np.random.randint(100) < self.random_ct_sample_chance:
|
||||
ct_sample = ct_samples[np.random.randint(ct_samples_len)]
|
||||
|
||||
x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type)
|
||||
x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types,
|
||||
self.debug, ct_sample=ct_sample)
|
||||
except:
|
||||
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
|
||||
raise Exception(
|
||||
"Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc()))
|
||||
|
||||
if type(x) != tuple and type(x) != list:
|
||||
raise Exception('SampleProcessor.process returns NOT tuple/list')
|
||||
|
|
|
@ -42,6 +42,13 @@ opts:
|
|||
"""
|
||||
|
||||
|
||||
class ColorTransferMode(IntEnum):
|
||||
NONE = 0
|
||||
LCT = 1
|
||||
RCT = 2
|
||||
RCT_MASKED = 3
|
||||
|
||||
|
||||
class SampleProcessor(object):
|
||||
class Types(IntEnum):
|
||||
NONE = 0
|
||||
|
@ -82,7 +89,7 @@ class SampleProcessor(object):
|
|||
self.ty_range = ty_range
|
||||
|
||||
@staticmethod
|
||||
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0):
|
||||
def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None):
|
||||
SPTF = SampleProcessor.Types
|
||||
|
||||
sample_bgr = sample.load_bgr()
|
||||
|
@ -120,8 +127,7 @@ class SampleProcessor(object):
|
|||
normalize_std_dev = opts.get('normalize_std_dev', False)
|
||||
normalize_vgg = opts.get('normalize_vgg', False)
|
||||
motion_blur = opts.get('motion_blur', None)
|
||||
apply_ct = opts.get('apply_ct', False)
|
||||
apply_ct_type = opts.get('apply_ct_type', 0)
|
||||
apply_ct = opts.get('apply_ct', ColorTransferMode.NONE)
|
||||
normalize_tanh = opts.get('normalize_tanh', False)
|
||||
|
||||
img_type = SPTF.NONE
|
||||
|
@ -225,16 +231,15 @@ class SampleProcessor(object):
|
|||
if ct_sample_bgr is None:
|
||||
ct_sample_bgr = ct_sample.load_bgr()
|
||||
|
||||
# TODO enum, apply_ct_type
|
||||
if apply_ct_type == 3 and ct_sample_mask is None:
|
||||
if apply_ct == ColorTransferMode.LCT:
|
||||
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
|
||||
|
||||
elif apply_ct == ColorTransferMode.RCT:
|
||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
||||
|
||||
elif apply_ct == ColorTransferMode.RCT_MASKED and ct_sample_mask is None:
|
||||
ct_sample_mask = ct_sample.load_fanseg_mask() or \
|
||||
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
|
||||
|
||||
if apply_ct_type == 1:
|
||||
img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
|
||||
if apply_ct_type == 2:
|
||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
|
||||
if apply_ct_type == 3:
|
||||
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
|
||||
ct_sample_bgr,
|
||||
clip=True,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue