Revert Converted RCT to legacy mode

This commit is contained in:
Jeremy Hummel 2019-08-14 13:18:35 -07:00
commit 89327ae49d
3 changed files with 211 additions and 3 deletions

View file

@ -5,6 +5,7 @@ import cv2
import numpy as np import numpy as np
import imagelib import imagelib
import imagelib_legacy
from facelib import FaceType, FANSegmentator, LandmarksProcessor from facelib import FaceType, FANSegmentator, LandmarksProcessor
from interact import interact as io from interact import interact as io
from joblib import SubprocessFunctionCaller from joblib import SubprocessFunctionCaller
@ -143,7 +144,7 @@ class ConverterMasked(Converter):
# ColorTransferMode.NONE, ColorTransferMode.MASKED_RCT_PAPER_CLIP) # ColorTransferMode.NONE, ColorTransferMode.MASKED_RCT_PAPER_CLIP)
self.color_transfer_mode = np.clip(io.input_int( self.color_transfer_mode = np.clip(io.input_int(
"Apply color transfer to predicted face? (0) None, (1) LCT ?:help skip:%s) : " % default_apply_random_ct, "Apply color transfer to predicted face? (0) None, (1) LCT (2) RCT-legacy?:help skip:%s) : " % default_apply_random_ct,
default_apply_random_ct, default_apply_random_ct,
help_message="Increase variativity of src samples by apply color transfer from random dst " help_message="Increase variativity of src samples by apply color transfer from random dst "
"samples. It is like 'face_style' learning, but more precise color transfer and without " "samples. It is like 'face_style' learning, but more precise color transfer and without "
@ -152,8 +153,9 @@ class ConverterMasked(Converter):
"diverse.\n\n" "diverse.\n\n"
"===Modes===:\n" "===Modes===:\n"
"(0) None - No transformation\n" "(0) None - No transformation\n"
"(1) LCT - Linear Color Transfer"), "(1) LCT - Linear Color Transfer\n"
ColorTransferMode.NONE, ColorTransferMode.LCT) "(2) RCT - Reinhard Color Transfer (legacy/original version)"),
ColorTransferMode.NONE, ColorTransferMode.RCT)
self.super_resolution = io.input_bool("Apply super resolution? (y/n ?:help skip:n) : ", False, self.super_resolution = io.input_bool("Apply super resolution? (y/n ?:help skip:n) : ", False,
help_message="Enhance details by applying DCSCN network.") help_message="Enhance details by applying DCSCN network.")
@ -369,6 +371,12 @@ class ConverterMasked(Converter):
if self.color_transfer_mode == ColorTransferMode.LCT: if self.color_transfer_mode == ColorTransferMode.LCT:
prd_face_bgr = imagelib.linear_color_transfer(prd_face_bgr, dst_face_bgr) prd_face_bgr = imagelib.linear_color_transfer(prd_face_bgr, dst_face_bgr)
if self.color_transfer_mode == ColorTransferMode.RCT:
prd_face_bgr = imagelib_legacy.reinhard_color_transfer ( np.clip( (prd_face_bgr*255).astype(np.uint8), 0, 255),
np.clip( (dst_face_bgr*255).astype(np.uint8), 0, 255),
source_mask=prd_face_mask_a, target_mask=prd_face_mask_a)
prd_face_bgr = np.clip( prd_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
# FIXME # FIXME
# elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP: # elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP:
# ct_options = { # ct_options = {
@ -489,6 +497,12 @@ class ConverterMasked(Converter):
if self.color_transfer_mode == ColorTransferMode.LCT: if self.color_transfer_mode == ColorTransferMode.LCT:
new_out_face_bgr = imagelib.linear_color_transfer(out_face_bgr, dst_face_bgr) new_out_face_bgr = imagelib.linear_color_transfer(out_face_bgr, dst_face_bgr)
if self.color_transfer_mode == ColorTransferMode.RCT:
new_out_face_bgr = imagelib_legacy.reinhard_color_transfer ( np.clip( (out_face_bgr*255).astype(np.uint8), 0, 255),
np.clip( (dst_face_bgr*255).astype(np.uint8), 0, 255),
source_mask=face_mask_blurry_aaa, target_mask=face_mask_blurry_aaa)
new_out_face_bgr = np.clip( new_out_face_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
# FIXME # FIXME
# elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP: # elif ColorTransferMode.RCT <= self.color_transfer_mode <= ColorTransferMode.MASKED_RCT_PAPER_CLIP:
# ct_options = { # ct_options = {

View file

@ -0,0 +1,3 @@
from .color_transfer import color_hist_match
from .color_transfer import reinhard_color_transfer
from .color_transfer import linear_color_transfer

View file

@ -0,0 +1,191 @@
import numpy as np
import cv2
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
"""
Transfers the color distribution from the source to the target
image using the mean and standard deviations of the L*a*b*
color space.
This implementation is (loosely) based on to the "Color Transfer
between Images" paper by Reinhard et al., 2001.
Parameters:
-------
source: NumPy array
OpenCV image in BGR color space (the source image)
target: NumPy array
OpenCV image in BGR color space (the target image)
clip: Should components of L*a*b* image be scaled by np.clip before
converting back to BGR color space?
If False then components will be min-max scaled appropriately.
Clipping will keep target image brightness truer to the input.
Scaling will adjust image brightness to avoid washed out portions
in the resulting color transfer that can be caused by clipping.
preserve_paper: Should color transfer strictly follow methodology
layed out in original paper? The method does not always produce
aesthetically pleasing results.
If False then L*a*b* components will scaled using the reciprocal of
the scaling factor proposed in the paper. This method seems to produce
more consistently aesthetically pleasing results
Returns:
-------
transfer: NumPy array
OpenCV image (w, h, 3) NumPy array (uint8)
"""
# convert the images from the RGB to L*ab* color space, being
# sure to utilizing the floating point data type (note: OpenCV
# expects floats to be 32-bit, so use that instead of 64-bit)
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
# compute color statistics for the source and target images
src_input = source if source_mask is None else source*source_mask
tgt_input = target if target_mask is None else target*target_mask
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
# subtract the means from the target image
(l, a, b) = cv2.split(target)
l -= lMeanTar
a -= aMeanTar
b -= bMeanTar
if preserve_paper:
# scale by the standard deviations using paper proposed factor
l = (lStdTar / lStdSrc) * l
a = (aStdTar / aStdSrc) * a
b = (bStdTar / bStdSrc) * b
else:
# scale by the standard deviations using reciprocal of paper proposed factor
l = (lStdSrc / lStdTar) * l
a = (aStdSrc / aStdTar) * a
b = (bStdSrc / bStdTar) * b
# add in the source mean
l += lMeanSrc
a += aMeanSrc
b += bMeanSrc
# clip/scale the pixel intensities to [0, 255] if they fall
# outside this range
l = _scale_array(l, clip=clip)
a = _scale_array(a, clip=clip)
b = _scale_array(b, clip=clip)
# merge the channels together and convert back to the RGB color
# space, being sure to utilize the 8-bit unsigned integer data
# type
transfer = cv2.merge([l, a, b])
transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
# return the color transferred image
return transfer
def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
'''
Matches the colour distribution of the target image to that of the source image
using a linear transform.
Images are expected to be of form (w,h,c) and float in [0,1].
Modes are chol, pca or sym for different choices of basis.
'''
mu_t = target_img.mean(0).mean(0)
t = target_img - mu_t
t = t.transpose(2,0,1).reshape(3,-1)
Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0])
mu_s = source_img.mean(0).mean(0)
s = source_img - mu_s
s = s.transpose(2,0,1).reshape(3,-1)
Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0])
if mode == 'chol':
chol_t = np.linalg.cholesky(Ct)
chol_s = np.linalg.cholesky(Cs)
ts = chol_s.dot(np.linalg.inv(chol_t)).dot(t)
if mode == 'pca':
eva_t, eve_t = np.linalg.eigh(Ct)
Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T)
eva_s, eve_s = np.linalg.eigh(Cs)
Qs = eve_s.dot(np.sqrt(np.diag(eva_s))).dot(eve_s.T)
ts = Qs.dot(np.linalg.inv(Qt)).dot(t)
if mode == 'sym':
eva_t, eve_t = np.linalg.eigh(Ct)
Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T)
Qt_Cs_Qt = Qt.dot(Cs).dot(Qt)
eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt)
QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T)
ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t)
matched_img = ts.reshape(*target_img.transpose(2,0,1).shape).transpose(1,2,0)
matched_img += mu_s
matched_img[matched_img>1] = 1
matched_img[matched_img<0] = 0
return matched_img
def lab_image_stats(image):
# compute the mean and standard deviation of each channel
(l, a, b) = cv2.split(image)
(lMean, lStd) = (l.mean(), l.std())
(aMean, aStd) = (a.mean(), a.std())
(bMean, bStd) = (b.mean(), b.std())
# return the color statistics
return (lMean, lStd, aMean, aStd, bMean, bStd)
def _scale_array(arr, clip=True):
if clip:
return np.clip(arr, 0, 255)
mn = arr.min()
mx = arr.max()
scale_range = (max([mn, 0]), min([mx, 255]))
if mn < scale_range[0] or mx > scale_range[1]:
return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0]
return arr
def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
# Code borrowed from:
# https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
masked_source = source
masked_template = template
if mask is not None:
masked_source = source * mask
masked_template = template * mask
oldshape = source.shape
source = source.ravel()
template = template.ravel()
masked_source = masked_source.ravel()
masked_template = masked_template.ravel()
s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
return_counts=True)
t_values, t_counts = np.unique(template, return_counts=True)
ms_values, mbin_idx, ms_counts = np.unique(source, return_inverse=True,
return_counts=True)
mt_values, mt_counts = np.unique(template, return_counts=True)
s_quantiles = np.cumsum(s_counts).astype(np.float64)
s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1]
t_quantiles = np.cumsum(t_counts).astype(np.float64)
t_quantiles = 255 * t_quantiles / t_quantiles[-1]
interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
return interp_t_values[bin_idx].reshape(oldshape)
def color_hist_match(src_im, tar_im, hist_match_threshold=255):
h,w,c = src_im.shape
matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], hist_match_threshold, None)
matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], hist_match_threshold, None)
matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], hist_match_threshold, None)
to_stack = (matched_R, matched_G, matched_B)
for i in range(3, c):
to_stack += ( src_im[:,:,i],)
matched = np.stack(to_stack, axis=-1).astype(src_im.dtype)
return matched