mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-21 22:13:20 -07:00
feature: random LAB rotation
This commit is contained in:
parent
79a233fc55
commit
fcde64fab5
3 changed files with 85 additions and 56 deletions
|
@ -381,6 +381,25 @@ def color_augmentation(img):
|
|||
return (face / 255.0).astype(np.float32)
|
||||
|
||||
|
||||
def random_lab_rotation(image, seed=None):
    """
    Rotate the colors of a BGR image by a random angle around the L axis
    of the L*a*b* colorspace.

    The image is converted to LAB, its (a, b) plane is multiplied by a
    random 2x2 special-orthogonal matrix while the L channel is passed
    through untouched, and the result is converted back to BGR.

    :param image: BGR image array; it is cast to float32 before conversion
        (presumably values are expected in [0, 1] -- confirm against callers).
    :param seed: forwarded to ``special_ortho_group.rvs`` as its random
        state; with ``None`` every call draws a new matrix.
    :return: float32 BGR image with values clipped to [0, 1].
    """
    lab = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2LAB)

    # Identity on L (row/column 0), random orthogonal block on (a, b).
    rotation = np.eye(3)
    rotation[1:, 1:] = special_ortho_group.rvs(dim=2, size=1, random_state=seed)

    # The rotated chroma may fall outside the range the code allows
    # (L: 0..100, a/b: -127..127), so clamp per channel before converting back.
    lower_bounds = np.array([0, -127, -127])
    upper_bounds = np.array([100, 127, 127])
    lab = np.clip(lab.dot(rotation), lower_bounds, upper_bounds)

    bgr = cv2.cvtColor(lab.astype(np.float32), cv2.COLOR_LAB2BGR)
    np.clip(bgr, 0, 1, out=bgr)
    return bgr
|
||||
|
||||
|
||||
def random_lab(image):
|
||||
""" Perform random color/lightness adjustment in L*a*b* colorspace """
|
||||
amount_l = 30 / 100
|
||||
|
|
|
@ -58,6 +58,7 @@ class SAEHDModel(ModelBase):
|
|||
default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0)
|
||||
default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0)
|
||||
default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none')
|
||||
default_random_color = self.options['random_color'] = self.load_or_def_option('random_color', False)
|
||||
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
|
||||
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
|
||||
|
||||
|
@ -167,6 +168,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 )
|
||||
|
||||
self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot', 'fs-aug'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best. FS aug adds random color to dst and src")
|
||||
self.options['random_color'] = io.input_bool ("Random color", default_random_color, help_message="Samples are randomly rotated around the L axis in LAB colorspace, helps generalize training")
|
||||
self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
|
||||
|
||||
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly.")
|
||||
|
@ -650,11 +652,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
if ct_mode == 'fs-aug':
|
||||
fs_aug = 'fs-aug'
|
||||
|
||||
channel_type = SampleProcessor.ChannelType.LAB_RAND_TRANSFORM if self.options['random_color'] else SampleProcessor.ChannelType.BGR
|
||||
|
||||
|
||||
self.set_training_data_generators ([
|
||||
SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
|
@ -663,8 +668,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : channel_type, 'ct_mode': fs_aug, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
|
|
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
from core import imagelib
|
||||
from core.cv2ex import *
|
||||
from core.imagelib import sd
|
||||
from core.imagelib.color_transfer import random_lab_rotation
|
||||
from facelib import FaceType, LandmarksProcessor
|
||||
|
||||
|
||||
|
@ -26,6 +27,8 @@ class SampleProcessor(object):
|
|||
BGR = 1 #BGR
|
||||
G = 2 #Grayscale
|
||||
GGG = 3 #3xGrayscale
|
||||
LAB_RAND_TRANSFORM = 4 # LAB random transform
|
||||
|
||||
|
||||
class FaceMaskType(IntEnum):
|
||||
NONE = 0
|
||||
|
@ -231,6 +234,8 @@ class SampleProcessor(object):
|
|||
# Transform from BGR to desired channel_type
|
||||
if channel_type == SPCT.BGR:
|
||||
out_sample = img
|
||||
elif channel_type == SPCT.LAB_RAND_TRANSFORM:
|
||||
out_sample = random_lab_rotation(img)
|
||||
elif channel_type == SPCT.G:
|
||||
out_sample = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[...,None]
|
||||
elif channel_type == SPCT.GGG:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue