def random_color_transform(image, seed=None):
    """Apply a random hue rotation to a BGR float image.

    The image is converted to CIE-LAB and the chroma (a, b) plane is
    rotated by a random 2-D rotation matrix drawn from SO(2); the
    lightness channel L is left untouched, so only the hue changes
    while brightness is preserved.

    Parameters
    ----------
    image : np.ndarray
        HxWx3 BGR image with float values in [0, 1].
    seed : int, np.random.RandomState or None
        Randomness source forwarded to scipy's SO(2) sampler
        (callers pass a seeded RandomState for reproducibility).

    Returns
    -------
    np.ndarray
        float32 BGR image clipped to [0, 1].
    """
    # cv2.cvtColor accepts uint8/float32 only; cast explicitly so a
    # float64 input does not raise (the reverse conversion below
    # already does this).
    image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_BGR2LAB)

    # Identity on L, random rotation restricted to the (a, b) chroma
    # sub-space.
    M = np.eye(3)
    M[1:, 1:] = special_ortho_group.rvs(2, 1, seed)
    image = image.dot(M)

    # Clamp back into the valid float-LAB ranges: L in [0, 100],
    # a/b in [-127, 127] (the rotation can push chroma outside).
    l, a, b = cv2.split(image)
    l = np.clip(l, 0, 100)
    a = np.clip(a, -127, 127)
    b = np.clip(b, -127, 127)
    image = cv2.merge([l, a, b])

    image = cv2.cvtColor(image.astype(np.float32), cv2.COLOR_LAB2BGR)
    np.clip(image, 0, 1, out=image)
    return image
(y/n, ?:help skip:%s) : " % (yn_str[default_random_color_change]), default_random_color_change, + help_message="") + if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301 default_clipgrad = False if is_first_run else self.options.get('clipgrad', False) self.options['clipgrad'] = io.input_bool( @@ -464,7 +469,9 @@ class SAEModel(ModelBase): face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF global t_mode_bgr - t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE + t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_LAB_RAND_TRANSFORM + if self.options['random_color_change']: + t_mode_bgr = t.MODE_LAB_RAND_TRANSFORM global training_data_src_path training_data_src_path = self.training_data_src_path diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index 7d7c4b5..9232bd8 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -6,7 +6,7 @@ import numpy as np import imagelib from facelib import FaceType, LandmarksProcessor -from imagelib.color_transfer import ColorTransferMode +from imagelib.color_transfer import ColorTransferMode, random_color_transform """ output_sample_types = [ @@ -72,6 +72,7 @@ class SampleProcessor(object): MODE_GGG = 42 #3xGrayscale MODE_M = 43 #mask only MODE_BGR_SHUFFLE = 44 #BGR shuffle + MODE_LAB_RAND_TRANSFORM = 45 #Random transform in LAB space MODE_END = 50 class Options(object): @@ -281,6 +282,9 @@ class SampleProcessor(object): elif mode_type == SPTF.MODE_BGR_SHUFFLE: rnd_state = np.random.RandomState (sample_rnd_seed) img = np.take (img_bgr, rnd_state.permutation(img_bgr.shape[-1]), axis=-1) + elif mode_type == SPTF.MODE_LAB_RAND_TRANSFORM: + rnd_state = np.random.RandomState (sample_rnd_seed) + img = random_color_transform(img_bgr, rnd_state) elif mode_type == SPTF.MODE_G: img = np.concatenate ( (np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1),img_mask) , -1 ) elif mode_type == 
SPTF.MODE_GGG: