Fixes issues in color transfer

This commit is contained in:
Jeremy Hummel 2019-08-13 17:50:52 -07:00
commit 64f7f5ca08
2 changed files with 17 additions and 16 deletions

View file

@ -2,9 +2,9 @@ import numpy as np
import cv2
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, target_mask=None, source_mask=None):
def reinhard_color_transfer(source, target, clip=False, preserve_paper=False, source_mask=None, target_mask=None):
"""
Transfers the color distribution from the source to the target
Transfers the color distribution from the target to the source
image using the mean and standard deviations of the L*a*b*
color space.
@ -47,18 +47,19 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, ta
# convert the images from the RGB to L*ab* color space, being
# sure to utilizing the floating point data type (note: OpenCV
# expects floats to be 32-bit, so use that instead of 64-bit)
source = cv2.cvtColor(source.astype(np.float32), cv2.COLOR_BGR2LAB)
target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB)
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)
# compute color statistics for the source and target images
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(source, mask=source_mask)
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(target, mask=target_mask)
# subtract the means from the target image
(l, a, b) = cv2.split(target)
l -= lMeanTar
a -= aMeanTar
b -= bMeanTar
# subtract the means from the source image
(l, a, b) = cv2.split(source)
l -= lMeanSrc
a -= aMeanSrc
b -= bMeanSrc
if preserve_paper:
# scale by the standard deviations using paper proposed factor
@ -72,9 +73,9 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, ta
b = (bStdSrc / bStdTar) * b
# add in the source mean
l += lMeanSrc
a += aMeanSrc
b += bMeanSrc
l += lMeanTar
a += aMeanTar
b += bMeanTar
# clip/scale the pixel intensities to [0, 255] if they fall
# outside this range
@ -144,8 +145,8 @@ def lab_image_stats(image, mask=None):
l, a, b = l[mask == 1], a[mask == 1], b[mask == 1]
l_mean, l_std = np.mean(l), np.std(l)
a_mean, a_std = np.mean(l), np.std(l)
b_mean, b_std = np.mean(l), np.std(l)
a_mean, a_std = np.mean(a), np.std(a)
b_mean, b_std = np.mean(b), np.std(b)
# return the color statistics
return l_mean, l_std, a_mean, a_std, b_mean, b_std

View file

@ -245,8 +245,8 @@ class SampleProcessor(object):
ct_sample_bgr,
clip=False,
preserve_paper=True,
target_mask=img_mask,
source_mask=ct_sample_mask)
source_mask=img_mask,
target_mask=ct_sample_mask)
if normalize_std_dev: