Mirror of https://github.com/iperov/DeepFaceLab.git

Commit 1cd9cf70e6 (parent 00893092b1): Select color transfer type
3 changed files with 32 additions and 20 deletions

This commit turns SAEModel's boolean "apply random color transfer" option into an integer selector — (0) None, (1) LCT, (2) RCT, (3) RCT-masked — and threads the chosen type through SampleGeneratorFace into SampleProcessor, which dispatches between linear color transfer and (optionally masked) Reinhard color transfer.
@@ -119,11 +119,14 @@ class SAEModel(ModelBase):
                                                               help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."),
                                                 0.0, 100.0)
 
-            default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
+            default_apply_random_ct = 0 if is_first_run else self.options.get('apply_random_ct', 0)
-            self.options['apply_random_ct'] = io.input_bool("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
+            self.options['apply_random_ct'] = io.input_int("Apply random color transfer to src faceset? (0) None, (1) LCT, (2) RCT, (3) RCT-masked ?:help skip:%s) : " % (default_apply_random_ct), default_apply_random_ct,
+                                                           help_message="Increase variativity of src samples by apply LCT color transfer from random dst "
+                                                                        "samples. It is like 'face_style' learning, but more precise color transfer and without "
+                                                                        "risk of model collapse, also it does not require additional GPU resources, "
+                                                                        "but the training time may be longer, due to the src faceset is becoming more diverse.")
 
         if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
             default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
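The prompt's four choices map directly onto the dispatch added to SampleProcessor in the last hunk below. The in-code comment `# TODO enum, apply_ct_type` there asks for named constants in place of these bare integers; a minimal sketch of what that could look like (hypothetical — the commit itself keeps plain ints):

    class ColorTransferType:
        """Hypothetical named constants for apply_random_ct / apply_ct_type."""
        NONE       = 0  # no color transfer
        LCT        = 1  # linear color transfer: match pixel mean and covariance
        RCT        = 2  # Reinhard color transfer: match LAB-space mean and std
        RCT_MASKED = 3  # Reinhard restricted to the face-mask region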
@@ -137,7 +140,7 @@ class SAEModel(ModelBase):
             self.options['pixel_loss'] = self.options.get('pixel_loss', False)
             self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
-            self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
+            self.options['apply_random_ct'] = self.options.get('apply_random_ct', 0)
             self.options['clipgrad'] = self.options.get('clipgrad', False)
 
         if is_first_run:
@@ -166,7 +169,7 @@ class SAEModel(ModelBase):
 
         self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1
 
-        apply_random_ct = self.options.get('apply_random_ct', False)
+        apply_random_ct = self.options.get('apply_random_ct', 0)
         masked_training = True
 
         warped_src = Input(bgr_shape)
@@ -454,6 +457,7 @@ class SAEModel(ModelBase):
                 SampleGeneratorFace(training_data_src_path,
                                     sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
                                     random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
+                                    random_ct_type=apply_random_ct,
                                     debug=self.is_debug(), batch_size=self.batch_size,
                                     sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
                                                                                    scale_range=np.array([-0.05,
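Because 0 is falsy in Python, the single integer now doubles as enable-flag and type selector: `random_ct_samples_path` stays None (disabling ct sampling) when the option is 0, while the concrete type travels separately through the new `random_ct_type` keyword. An illustrative construction, with placeholder values for everything except the two ct arguments:

    apply_random_ct = 2  # RCT, as chosen at the interactive prompt
    generator = SampleGeneratorFace(training_data_src_path,
                                    random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
                                    random_ct_type=apply_random_ct,
                                    debug=False, batch_size=8)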
@@ -586,6 +590,7 @@ class SAEModel(ModelBase):
         face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
 
         from converters import ConverterMasked
+        # TODO apply_random_ct
         return ConverterMasked(self.predictor_func,
                                predictor_input_size=self.options['resolution'],
                                predictor_masked=self.options['learn_mask'],
@@ -18,11 +18,12 @@ output_sample_types = [
 ]
 '''
 class SampleGeneratorFace(SampleGeneratorBase):
-    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
+    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, random_ct_type=0, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
         super().__init__(samples_path, debug, batch_size)
         self.sample_process_options = sample_process_options
         self.output_sample_types = output_sample_types
         self.add_sample_idx = add_sample_idx
+        self.apply_ct_type = random_ct_type
 
         if sort_by_yaw_target_samples_path is not None:
             self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@@ -119,7 +120,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
                     if np.random.randint(100) < self.random_ct_sample_chance:
                         ct_sample=ct_samples[np.random.randint(ct_samples_len)]
 
-                x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample)
+                x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, apply_ct_type=self.apply_ct_type)
             except:
                 raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
@@ -82,7 +82,7 @@ class SampleProcessor(object):
         self.ty_range = ty_range
 
     @staticmethod
-    def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None):
+    def process(sample, sample_process_options, output_sample_types, debug, ct_sample=None, apply_ct_type=0):
         SPTF = SampleProcessor.Types
 
         sample_bgr = sample.load_bgr()
@@ -121,6 +121,7 @@ class SampleProcessor(object):
             normalize_vgg = opts.get('normalize_vgg', False)
             motion_blur = opts.get('motion_blur', None)
             apply_ct = opts.get('apply_ct', False)
+            apply_ct_type = opts.get('apply_ct_type', 0)
             normalize_tanh = opts.get('normalize_tanh', False)
 
             img_type = SPTF.NONE
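One thing to watch in this hunk: the added line rebinds the name of the new `apply_ct_type` function parameter. Since `opts.get('apply_ct_type', 0)` falls back to a hard-coded 0 rather than to the value passed in from SampleGeneratorFace, the per-generator setting appears to be discarded for any entry of output_sample_types whose options dict lacks an 'apply_ct_type' key. If that is unintended, defaulting to the parameter would preserve both paths:

    # Assumption: prefer the per-sample-type option, else the value given to process().
    apply_ct_type = opts.get('apply_ct_type', apply_ct_type)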
@@ -223,18 +224,23 @@ class SampleProcessor(object):
                 if apply_ct and ct_sample is not None:
                     if ct_sample_bgr is None:
                         ct_sample_bgr = ct_sample.load_bgr()
-                    if ct_sample_mask is None:
-                        ct_sample_mask = ct_sample.load_fanseg_mask() or \
-                                         LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
-
-                    img_bgr = imagelib.reinhard_color_transfer(img_bgr,
-                                                               ct_sample_bgr,
-                                                               clip=True,
-                                                               target_mask=img_mask,
-                                                               source_mask=ct_sample_mask)
+
+                    # TODO enum, apply_ct_type
+                    if apply_ct_type == 3 and ct_sample_mask is None:
+                        ct_sample_mask = ct_sample.load_fanseg_mask() or \
+                                         LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
+
+                    if apply_ct_type == 1:
+                        img_bgr = imagelib.linear_color_transfer(img_bgr, ct_sample_bgr)
+                    if apply_ct_type == 2:
+                        img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True)
+                    if apply_ct_type == 3:
+                        img_bgr = imagelib.reinhard_color_transfer(img_bgr,
+                                                                   ct_sample_bgr,
+                                                                   clip=True,
+                                                                   target_mask=img_mask,
+                                                                   source_mask=ct_sample_mask)
                     # img_bgr = imagelib.reinhard_color_transfer(img_bgr,
                     #                                            ct_sample_bgr,
                     #                                            clip=True)
 
                 if normalize_std_dev:
                     img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))
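For reference, the two transfer families being dispatched here are standard techniques. The sketches below approximate what `imagelib.linear_color_transfer` and `imagelib.reinhard_color_transfer` compute, assuming float32 BGR images in the [0, 1] range (which is how SampleProcessor holds img_bgr); they are illustrations, not DeepFaceLab's exact implementations, which may differ in regularization, clipping, and mask handling:

    import cv2
    import numpy as np

    def linear_color_transfer_sketch(target, source, eps=1e-5):
        # LCT: remap target pixels so their mean and covariance match the source.
        t = target.reshape(-1, 3).astype(np.float64)
        s = source.reshape(-1, 3).astype(np.float64)
        mu_t, mu_s = t.mean(0), s.mean(0)
        cov_t = np.cov(t.T) + eps * np.eye(3)  # regularized for stability
        cov_s = np.cov(s.T) + eps * np.eye(3)
        # Cholesky factors give A such that A @ cov_t @ A.T == cov_s.
        A = np.linalg.cholesky(cov_s) @ np.linalg.inv(np.linalg.cholesky(cov_t))
        return ((t - mu_t) @ A.T + mu_s).reshape(target.shape).astype(target.dtype)

    def reinhard_color_transfer_sketch(target_bgr, source_bgr, clip=True,
                                       target_mask=None, source_mask=None):
        # RCT (Reinhard et al., 2001): match per-channel mean/std in LAB space.
        # Masks, when given, restrict the statistics to the face region (type 3).
        t = cv2.cvtColor(target_bgr.astype(np.float32), cv2.COLOR_BGR2LAB)
        s = cv2.cvtColor(source_bgr.astype(np.float32), cv2.COLOR_BGR2LAB)

        def stats(lab, mask):
            px = lab.reshape(-1, 3)
            if mask is not None:
                px = px[mask.reshape(-1) > 0.5]  # keep only masked pixels
            return px.mean(0), px.std(0) + 1e-6

        t_mean, t_std = stats(t, target_mask)
        s_mean, s_std = stats(s, source_mask)
        out = cv2.cvtColor((t - t_mean) / t_std * s_std + s_mean, cv2.COLOR_LAB2BGR)
        return np.clip(out, 0.0, 1.0) if clip else out

This also makes the intent of option (3) concrete: with `target_mask=img_mask` and `source_mask=ct_sample_mask`, only face pixels drive the statistics, so background color in either sample cannot skew the transfer.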