diff --git a/converters/ConverterMasked.py b/converters/ConverterMasked.py
index 7cc233d..139ab2f 100644
--- a/converters/ConverterMasked.py
+++ b/converters/ConverterMasked.py
@@ -5,6 +5,7 @@
 import cv2
 import numpy as np
 import imagelib
+import imagelib_legacy
 from facelib import FaceType, FANSegmentator, LandmarksProcessor
 from interact import interact as io
 from joblib import SubprocessFunctionCaller
@@ -143,7 +144,7 @@ class ConverterMasked(Converter):
         # ColorTransferMode.NONE, ColorTransferMode.MASKED_RCT_PAPER_CLIP)
 
         self.color_transfer_mode = np.clip(io.input_int(
-            "Apply color transfer to predicted face? (0) None, (1) LCT ?:help skip:%s) : " % default_apply_random_ct,
+            "Apply color transfer to predicted face? (0) None, (1) LCT, (2) RCT-legacy ?:help skip:%s) : " % default_apply_random_ct,
             default_apply_random_ct,
             help_message="Increase variativity of src samples by apply color transfer from random dst "
                          "samples. It is like 'face_style' learning, but more precise color transfer and without "
@@ -152,8 +153,9 @@ class ConverterMasked(Converter):
                          "diverse.\n\n"
                          "===Modes===:\n"
                          "(0) None - No transformation\n"
-                         "(1) LCT - Linear Color Transfer"),
-            ColorTransferMode.NONE, ColorTransferMode.LCT)
+                         "(1) LCT - Linear Color Transfer\n"
+                         "(2) RCT - Reinhard Color Transfer (legacy/original version)"),
+            ColorTransferMode.NONE, ColorTransferMode.RCT)
 
         self.super_resolution = io.input_bool("Apply super resolution? (y/n ?:help skip:n) : ", False,
                                               help_message="Enhance details by applying DCSCN network.")
@@ -369,6 +371,12 @@ class ConverterMasked(Converter):
                 if self.color_transfer_mode == ColorTransferMode.LCT:
                     prd_face_bgr = imagelib.linear_color_transfer(prd_face_bgr, dst_face_bgr)
 
+                if self.color_transfer_mode == ColorTransferMode.RCT:
+                    prd_face_bgr = imagelib_legacy.reinhard_color_transfer(np.clip(prd_face_bgr * 255, 0, 255).astype(np.uint8),
+                                                                           np.clip(dst_face_bgr * 255, 0, 255).astype(np.uint8),
+                                                                           source_mask=prd_face_mask_a, target_mask=prd_face_mask_a)
+                    prd_face_bgr = np.clip(prd_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
+
                 # FIXME
                 # elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP:
                 #     ct_options = {
@@ -489,6 +497,12 @@ class ConverterMasked(Converter):
                         if self.color_transfer_mode == ColorTransferMode.LCT:
                             new_out_face_bgr = imagelib.linear_color_transfer(out_face_bgr, dst_face_bgr)
 
+                        if self.color_transfer_mode == ColorTransferMode.RCT:
+                            new_out_face_bgr = imagelib_legacy.reinhard_color_transfer(np.clip(out_face_bgr * 255, 0, 255).astype(np.uint8),
+                                                                                       np.clip(dst_face_bgr * 255, 0, 255).astype(np.uint8),
+                                                                                       source_mask=face_mask_blurry_aaa, target_mask=face_mask_blurry_aaa)
+                            new_out_face_bgr = np.clip(new_out_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
+
                 # FIXME
                 # elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP:
                 #     ct_options = {
diff --git a/imagelib_legacy/__init__.py b/imagelib_legacy/__init__.py
new file mode 100644
index 0000000..6281e0e
--- /dev/null
+++ b/imagelib_legacy/__init__.py
@@ -0,0 +1,3 @@
+from .color_transfer import color_hist_match
+from .color_transfer import reinhard_color_transfer
+from .color_transfer import linear_color_transfer
diff --git a/imagelib_legacy/color_transfer.py b/imagelib_legacy/color_transfer.py
new file mode 100644
index 0000000..545df3a
--- /dev/null
+++ b/imagelib_legacy/color_transfer.py
@@ -0,0 +1,191 @@
+import numpy as np
+import cv2
+
+def reinhard_color_transfer(target, source, clip=False, preserve_paper=False,
+                            source_mask=None, target_mask=None):
+    """
+    Transfers the color distribution from the source to the target
+    image using the means and standard deviations of the L*a*b*
+    color space.
+
+    This implementation is (loosely) based on the "Color Transfer
+    between Images" paper by Reinhard et al., 2001.
+
+    Parameters:
+    -------
+    target: NumPy array
+        OpenCV image in BGR color space (the image to recolor)
+    source: NumPy array
+        OpenCV image in BGR color space (the image whose colors are copied)
+    clip: Should components of the L*a*b* image be scaled by np.clip before
+        converting back to BGR color space?
+        If False then components will be min-max scaled appropriately.
+        Clipping will keep target image brightness truer to the input.
+        Scaling will adjust image brightness to avoid washed-out portions
+        in the resulting color transfer that can be caused by clipping.
+    preserve_paper: Should color transfer strictly follow the methodology
+        laid out in the original paper? The method does not always produce
+        aesthetically pleasing results.
+        If False then L*a*b* components will be scaled using the reciprocal of
+        the scaling factor proposed in the paper. This method seems to produce
+        more consistently aesthetically pleasing results.
+    source_mask: optional mask multiplied into the source before its
+        statistics are computed
+    target_mask: optional mask multiplied into the target before its
+        statistics are computed
+
+    Returns:
+    -------
+    transfer: NumPy array
+        OpenCV image (w, h, 3) NumPy array (uint8)
+    """
+
+    # convert the images from the BGR to L*a*b* color space, being
+    # sure to utilize the floating point data type (note: OpenCV
+    # expects floats to be 32-bit, so use that instead of 64-bit)
+    source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
+    target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
+
+    # compute color statistics for the source and target images
+    src_input = source if source_mask is None else source * source_mask
+    tgt_input = target if target_mask is None else target * target_mask
+    (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
+    (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
+
+    # subtract the means from the target image
+    (l, a, b) = cv2.split(target)
+    l -= lMeanTar
+    a -= aMeanTar
+    b -= bMeanTar
+
+    if preserve_paper:
+        # scale by the standard deviations using the factor proposed in the paper
+        l = (lStdTar / lStdSrc) * l
+        a = (aStdTar / aStdSrc) * a
+        b = (bStdTar / bStdSrc) * b
+    else:
+        # scale by the standard deviations using the reciprocal of the paper's factor
+        l = (lStdSrc / lStdTar) * l
+        a = (aStdSrc / aStdTar) * a
+        b = (bStdSrc / bStdTar) * b
+
+    # add in the source mean
+    l += lMeanSrc
+    a += aMeanSrc
+    b += bMeanSrc
+
+    # clip/scale the pixel intensities to [0, 255] if they fall
+    # outside this range
+    l = _scale_array(l, clip=clip)
+    a = _scale_array(a, clip=clip)
+    b = _scale_array(b, clip=clip)
+
+    # merge the channels together and convert back to the BGR color
+    # space, being sure to utilize the 8-bit unsigned integer data
+    # type
+    transfer = cv2.merge([l, a, b])
+    transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
+
+    # return the color transferred image
+    return transfer
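+
+# Usage sketch: pull the dst face's color statistics onto the predicted face,
+# mirroring the ConverterMasked call sites above. Inputs must be uint8 BGR;
+# the variable names below are illustrative only.
+#
+#   prd_u8 = np.clip(prd_face_bgr * 255, 0, 255).astype(np.uint8)
+#   dst_u8 = np.clip(dst_face_bgr * 255, 0, 255).astype(np.uint8)
+#   out_u8 = reinhard_color_transfer(prd_u8, dst_u8,
+#                                    source_mask=prd_face_mask_a,
+#                                    target_mask=prd_face_mask_a)
+#   prd_face_bgr = out_u8.astype(np.float32) / 255.0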
+
+def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
+    '''
+    Matches the colour distribution of the target image to that of the source image
+    using a linear transform.
+    Images are expected to be of form (w, h, c) and float in [0, 1].
+    Modes are 'chol', 'pca' or 'sym' for different choices of basis.
+    '''
+    mu_t = target_img.mean(0).mean(0)
+    t = target_img - mu_t
+    t = t.transpose(2, 0, 1).reshape(3, -1)
+    Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0])
+    mu_s = source_img.mean(0).mean(0)
+    s = source_img - mu_s
+    s = s.transpose(2, 0, 1).reshape(3, -1)
+    Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0])
+    if mode == 'chol':
+        chol_t = np.linalg.cholesky(Ct)
+        chol_s = np.linalg.cholesky(Cs)
+        ts = chol_s.dot(np.linalg.inv(chol_t)).dot(t)
+    if mode == 'pca':
+        eva_t, eve_t = np.linalg.eigh(Ct)
+        Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T)
+        eva_s, eve_s = np.linalg.eigh(Cs)
+        Qs = eve_s.dot(np.sqrt(np.diag(eva_s))).dot(eve_s.T)
+        ts = Qs.dot(np.linalg.inv(Qt)).dot(t)
+    if mode == 'sym':
+        eva_t, eve_t = np.linalg.eigh(Ct)
+        Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T)
+        Qt_Cs_Qt = Qt.dot(Cs).dot(Qt)
+        eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt)
+        QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T)
+        ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t)
+    matched_img = ts.reshape(*target_img.transpose(2, 0, 1).shape).transpose(1, 2, 0)
+    matched_img += mu_s
+    matched_img[matched_img > 1] = 1
+    matched_img[matched_img < 0] = 0
+    return matched_img
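+
+# Usage sketch: unlike reinhard_color_transfer, this expects float BGR images
+# in [0, 1], as in the LCT branch of ConverterMasked; 'pca' is the default
+# basis (names are illustrative only):
+#
+#   prd_face_bgr = linear_color_transfer(prd_face_bgr, dst_face_bgr)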
+
+def lab_image_stats(image):
+    # compute the mean and standard deviation of each channel
+    (l, a, b) = cv2.split(image)
+    (lMean, lStd) = (l.mean(), l.std())
+    (aMean, aStd) = (a.mean(), a.std())
+    (bMean, bStd) = (b.mean(), b.std())
+
+    # return the color statistics
+    return (lMean, lStd, aMean, aStd, bMean, bStd)
+
+def _scale_array(arr, clip=True):
+    if clip:
+        return np.clip(arr, 0, 255)
+
+    mn = arr.min()
+    mx = arr.max()
+    scale_range = (max([mn, 0]), min([mx, 255]))
+
+    if mn < scale_range[0] or mx > scale_range[1]:
+        return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0]
+
+    return arr
+
+def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
+    # Code borrowed from:
+    # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
+    masked_source = source
+    masked_template = template
+
+    if mask is not None:
+        masked_source = source * mask
+        masked_template = template * mask
+
+    oldshape = source.shape
+    source = source.ravel()
+    template = template.ravel()
+    masked_source = masked_source.ravel()
+    masked_template = masked_template.ravel()
+    s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
+                                            return_counts=True)
+    t_values, t_counts = np.unique(template, return_counts=True)
+    # note: the masked statistics below are computed but not yet used by the
+    # matching itself, which currently operates on the full channels
+    ms_values, mbin_idx, ms_counts = np.unique(masked_source, return_inverse=True,
+                                               return_counts=True)
+    mt_values, mt_counts = np.unique(masked_template, return_counts=True)
+
+    s_quantiles = np.cumsum(s_counts).astype(np.float64)
+    s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1]
+    t_quantiles = np.cumsum(t_counts).astype(np.float64)
+    t_quantiles = 255 * t_quantiles / t_quantiles[-1]
+    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
+
+    return interp_t_values[bin_idx].reshape(oldshape)
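+
+# Usage sketch: match a single channel, optionally restricting the statistics
+# to a 0/1 mask of the same shape (hypothetical arrays):
+#
+#   matched_ch = channel_hist_match(src_im[:, :, 0], tar_im[:, :, 0],
+#                                   hist_match_threshold=255, mask=hull_mask)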
+
+def color_hist_match(src_im, tar_im, hist_match_threshold=255):
+    h, w, c = src_im.shape
+    # channels are matched by index; for BGR input, index 0 is blue and 2 is red
+    matched_R = channel_hist_match(src_im[:, :, 0], tar_im[:, :, 0], hist_match_threshold, None)
+    matched_G = channel_hist_match(src_im[:, :, 1], tar_im[:, :, 1], hist_match_threshold, None)
+    matched_B = channel_hist_match(src_im[:, :, 2], tar_im[:, :, 2], hist_match_threshold, None)
+
+    # pass any extra channels (e.g. alpha) through unchanged
+    to_stack = (matched_R, matched_G, matched_B)
+    for i in range(3, c):
+        to_stack += (src_im[:, :, i],)
+
+    matched = np.stack(to_stack, axis=-1).astype(src_im.dtype)
+    return matched
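+
+# Minimal smoke test (a sketch; not used by the converter pipeline): run the
+# three public transfer functions on random data to check shapes and dtypes.
+if __name__ == "__main__":
+    src = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
+    tgt = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
+
+    out = reinhard_color_transfer(tgt, src)
+    assert out.shape == tgt.shape and out.dtype == np.uint8
+
+    src_f = src.astype(np.float32) / 255.0
+    tgt_f = tgt.astype(np.float32) / 255.0
+    out_f = linear_color_transfer(tgt_f, src_f)
+    assert out_f.shape == tgt_f.shape
+
+    out_h = color_hist_match(src_f, tgt_f)
+    assert out_h.shape == src_f.shape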