Formatting

This commit is contained in:
Jeremy Hummel 2019-08-10 00:47:48 -07:00
commit 62c7be73b6

View file

@ -1,104 +1,105 @@
import numpy as np import numpy as np
import cv2 import cv2
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None): def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
""" """
Transfers the color distribution from the source to the target Transfers the color distribution from the source to the target
image using the mean and standard deviations of the L*a*b* image using the mean and standard deviations of the L*a*b*
color space. color space.
This implementation is (loosely) based on to the "Color Transfer This implementation is (loosely) based on to the "Color Transfer
between Images" paper by Reinhard et al., 2001. between Images" paper by Reinhard et al., 2001.
Parameters: Parameters:
------- -------
source: NumPy array source: NumPy array
OpenCV image in BGR color space (the source image) OpenCV image in BGR color space (the source image)
target: NumPy array target: NumPy array
OpenCV image in BGR color space (the target image) OpenCV image in BGR color space (the target image)
clip: Should components of L*a*b* image be scaled by np.clip before clip: Should components of L*a*b* image be scaled by np.clip before
converting back to BGR color space? converting back to BGR color space?
If False then components will be min-max scaled appropriately. If False then components will be min-max scaled appropriately.
Clipping will keep target image brightness truer to the input. Clipping will keep target image brightness truer to the input.
Scaling will adjust image brightness to avoid washed out portions Scaling will adjust image brightness to avoid washed out portions
in the resulting color transfer that can be caused by clipping. in the resulting color transfer that can be caused by clipping.
preserve_paper: Should color transfer strictly follow methodology preserve_paper: Should color transfer strictly follow methodology
layed out in original paper? The method does not always produce layed out in original paper? The method does not always produce
aesthetically pleasing results. aesthetically pleasing results.
If False then L*a*b* components will scaled using the reciprocal of If False then L*a*b* components will scaled using the reciprocal of
the scaling factor proposed in the paper. This method seems to produce the scaling factor proposed in the paper. This method seems to produce
more consistently aesthetically pleasing results more consistently aesthetically pleasing results
Returns: Returns:
------- -------
transfer: NumPy array transfer: NumPy array
OpenCV image (w, h, 3) NumPy array (uint8) OpenCV image (w, h, 3) NumPy array (uint8)
""" """
# convert the images from the RGB to L*ab* color space, being
# sure to utilizing the floating point data type (note: OpenCV
# expects floats to be 32-bit, so use that instead of 64-bit)
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
# convert the images from the RGB to L*ab* color space, being # compute color statistics for the source and target images
# sure to utilizing the floating point data type (note: OpenCV src_input = source if source_mask is None else source * source_mask
# expects floats to be 32-bit, so use that instead of 64-bit) tgt_input = target if target_mask is None else target * target_mask
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32) (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32) (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
# compute color statistics for the source and target images # subtract the means from the target image
src_input = source if source_mask is None else source*source_mask (l, a, b) = cv2.split(target)
tgt_input = target if target_mask is None else target*target_mask l -= lMeanTar
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input) a -= aMeanTar
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input) b -= bMeanTar
# subtract the means from the target image if preserve_paper:
(l, a, b) = cv2.split(target) # scale by the standard deviations using paper proposed factor
l -= lMeanTar l = (lStdTar / lStdSrc) * l
a -= aMeanTar a = (aStdTar / aStdSrc) * a
b -= bMeanTar b = (bStdTar / bStdSrc) * b
else:
# scale by the standard deviations using reciprocal of paper proposed factor
l = (lStdSrc / lStdTar) * l
a = (aStdSrc / aStdTar) * a
b = (bStdSrc / bStdTar) * b
if preserve_paper: # add in the source mean
# scale by the standard deviations using paper proposed factor l += lMeanSrc
l = (lStdTar / lStdSrc) * l a += aMeanSrc
a = (aStdTar / aStdSrc) * a b += bMeanSrc
b = (bStdTar / bStdSrc) * b
else:
# scale by the standard deviations using reciprocal of paper proposed factor
l = (lStdSrc / lStdTar) * l
a = (aStdSrc / aStdTar) * a
b = (bStdSrc / bStdTar) * b
# add in the source mean # clip/scale the pixel intensities to [0, 255] if they fall
l += lMeanSrc # outside this range
a += aMeanSrc l = _scale_array(l, clip=clip)
b += bMeanSrc a = _scale_array(a, clip=clip)
b = _scale_array(b, clip=clip)
# clip/scale the pixel intensities to [0, 255] if they fall # merge the channels together and convert back to the RGB color
# outside this range # space, being sure to utilize the 8-bit unsigned integer data
l = _scale_array(l, clip=clip) # type
a = _scale_array(a, clip=clip) transfer = cv2.merge([l, a, b])
b = _scale_array(b, clip=clip) transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
# merge the channels together and convert back to the RGB color # return the color transferred image
# space, being sure to utilize the 8-bit unsigned integer data return transfer
# type
transfer = cv2.merge([l, a, b])
transfer = cv2.cvtColor(transfer.astype(np.uint8), cv2.COLOR_LAB2BGR)
# return the color transferred image
return transfer
def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
''' """
Matches the colour distribution of the target image to that of the source image Matches the colour distribution of the target image to that of the source image
using a linear transform. using a linear transform.
Images are expected to be of form (w,h,c) and float in [0,1]. Images are expected to be of form (w,h,c) and float in [0,1].
Modes are chol, pca or sym for different choices of basis. Modes are chol, pca or sym for different choices of basis.
''' """
mu_t = target_img.mean(0).mean(0) mu_t = target_img.mean(0).mean(0)
t = target_img - mu_t t = target_img - mu_t
t = t.transpose(2,0,1).reshape(3,-1) t = t.transpose(2, 0, 1).reshape(3, -1)
Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0]) Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0])
mu_s = source_img.mean(0).mean(0) mu_s = source_img.mean(0).mean(0)
s = source_img - mu_s s = source_img - mu_s
s = s.transpose(2,0,1).reshape(3,-1) s = s.transpose(2, 0, 1).reshape(3, -1)
Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0]) Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0])
if mode == 'chol': if mode == 'chol':
chol_t = np.linalg.cholesky(Ct) chol_t = np.linalg.cholesky(Ct)
@ -117,12 +118,13 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt) eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt)
QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T) QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T)
ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t) ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t)
matched_img = ts.reshape(*target_img.transpose(2,0,1).shape).transpose(1,2,0) matched_img = ts.reshape(*target_img.transpose(2, 0, 1).shape).transpose(1, 2, 0)
matched_img += mu_s matched_img += mu_s
matched_img[matched_img>1] = 1 matched_img[matched_img > 1] = 1
matched_img[matched_img<0] = 0 matched_img[matched_img < 0] = 0
return matched_img return matched_img
def lab_image_stats(image): def lab_image_stats(image):
# compute the mean and standard deviation of each channel # compute the mean and standard deviation of each channel
(l, a, b) = cv2.split(image) (l, a, b) = cv2.split(image)
@ -133,6 +135,7 @@ def lab_image_stats(image):
# return the color statistics # return the color statistics
return (lMean, lStd, aMean, aStd, bMean, bStd) return (lMean, lStd, aMean, aStd, bMean, bStd)
def _scale_array(arr, clip=True): def _scale_array(arr, clip=True):
if clip: if clip:
return np.clip(arr, 0, 255) return np.clip(arr, 0, 255)
@ -146,6 +149,7 @@ def _scale_array(arr, clip=True):
return arr return arr
def channel_hist_match(source, template, hist_match_threshold=255, mask=None): def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
# Code borrowed from: # Code borrowed from:
# https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
@ -165,7 +169,7 @@ def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
return_counts=True) return_counts=True)
t_values, t_counts = np.unique(template, return_counts=True) t_values, t_counts = np.unique(template, return_counts=True)
ms_values, mbin_idx, ms_counts = np.unique(source, return_inverse=True, ms_values, mbin_idx, ms_counts = np.unique(source, return_inverse=True,
return_counts=True) return_counts=True)
mt_values, mt_counts = np.unique(template, return_counts=True) mt_values, mt_counts = np.unique(template, return_counts=True)
s_quantiles = np.cumsum(s_counts).astype(np.float64) s_quantiles = np.cumsum(s_counts).astype(np.float64)
@ -176,16 +180,16 @@ def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
return interp_t_values[bin_idx].reshape(oldshape) return interp_t_values[bin_idx].reshape(oldshape)
def color_hist_match(src_im, tar_im, hist_match_threshold=255): def color_hist_match(src_im, tar_im, hist_match_threshold=255):
h,w,c = src_im.shape h, w, c = src_im.shape
matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], hist_match_threshold, None) matched_R = channel_hist_match(src_im[:, :, 0], tar_im[:, :, 0], hist_match_threshold, None)
matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], hist_match_threshold, None) matched_G = channel_hist_match(src_im[:, :, 1], tar_im[:, :, 1], hist_match_threshold, None)
matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], hist_match_threshold, None) matched_B = channel_hist_match(src_im[:, :, 2], tar_im[:, :, 2], hist_match_threshold, None)
to_stack = (matched_R, matched_G, matched_B) to_stack = (matched_R, matched_G, matched_B)
for i in range(3, c): for i in range(3, c):
to_stack += ( src_im[:,:,i],) to_stack += (src_im[:, :, i],)
matched = np.stack(to_stack, axis=-1).astype(src_im.dtype) matched = np.stack(to_stack, axis=-1).astype(src_im.dtype)
return matched return matched