Formatting

This commit is contained in:
Jeremy Hummel 2019-08-10 00:47:48 -07:00
commit 62c7be73b6

View file

@ -1,6 +1,7 @@
import numpy as np
import cv2
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
"""
Transfers the color distribution from the source to the target
@ -35,7 +36,6 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
OpenCV image (w, h, 3) NumPy array (uint8)
"""
# convert the images from the RGB to L*ab* color space, being
# sure to utilizing the floating point data type (note: OpenCV
# expects floats to be 32-bit, so use that instead of 64-bit)
@ -43,8 +43,8 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
# compute color statistics for the source and target images
src_input = source if source_mask is None else source*source_mask
tgt_input = target if target_mask is None else target*target_mask
src_input = source if source_mask is None else source * source_mask
tgt_input = target if target_mask is None else target * target_mask
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
@ -85,20 +85,21 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
# return the color transferred image
return transfer
def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
'''
"""
Matches the colour distribution of the target image to that of the source image
using a linear transform.
Images are expected to be of form (w,h,c) and float in [0,1].
Modes are chol, pca or sym for different choices of basis.
'''
"""
mu_t = target_img.mean(0).mean(0)
t = target_img - mu_t
t = t.transpose(2,0,1).reshape(3,-1)
t = t.transpose(2, 0, 1).reshape(3, -1)
Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0])
mu_s = source_img.mean(0).mean(0)
s = source_img - mu_s
s = s.transpose(2,0,1).reshape(3,-1)
s = s.transpose(2, 0, 1).reshape(3, -1)
Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0])
if mode == 'chol':
chol_t = np.linalg.cholesky(Ct)
@ -117,12 +118,13 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt)
QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T)
ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t)
matched_img = ts.reshape(*target_img.transpose(2,0,1).shape).transpose(1,2,0)
matched_img = ts.reshape(*target_img.transpose(2, 0, 1).shape).transpose(1, 2, 0)
matched_img += mu_s
matched_img[matched_img>1] = 1
matched_img[matched_img<0] = 0
matched_img[matched_img > 1] = 1
matched_img[matched_img < 0] = 0
return matched_img
def lab_image_stats(image):
# compute the mean and standard deviation of each channel
(l, a, b) = cv2.split(image)
@ -133,6 +135,7 @@ def lab_image_stats(image):
# return the color statistics
return (lMean, lStd, aMean, aStd, bMean, bStd)
def _scale_array(arr, clip=True):
if clip:
return np.clip(arr, 0, 255)
@ -146,6 +149,7 @@ def _scale_array(arr, clip=True):
return arr
def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
# Code borrowed from:
# https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
@ -176,16 +180,16 @@ def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
return interp_t_values[bin_idx].reshape(oldshape)
def color_hist_match(src_im, tar_im, hist_match_threshold=255):
h,w,c = src_im.shape
matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], hist_match_threshold, None)
matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], hist_match_threshold, None)
matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], hist_match_threshold, None)
h, w, c = src_im.shape
matched_R = channel_hist_match(src_im[:, :, 0], tar_im[:, :, 0], hist_match_threshold, None)
matched_G = channel_hist_match(src_im[:, :, 1], tar_im[:, :, 1], hist_match_threshold, None)
matched_B = channel_hist_match(src_im[:, :, 2], tar_im[:, :, 2], hist_match_threshold, None)
to_stack = (matched_R, matched_G, matched_B)
for i in range(3, c):
to_stack += ( src_im[:,:,i],)
to_stack += (src_im[:, :, i],)
matched = np.stack(to_stack, axis=-1).astype(src_im.dtype)
return matched