diff --git a/imagelib/color_transfer.py b/imagelib/color_transfer.py index 7d55828..093ed3d 100644 --- a/imagelib/color_transfer.py +++ b/imagelib/color_transfer.py @@ -2,9 +2,9 @@ import numpy as np import cv2 -def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, target_mask=None, source_mask=None): +def reinhard_color_transfer(source, target, clip=False, preserve_paper=False, source_mask=None, target_mask=None): """ - Transfers the color distribution from the source to the target + Transfers the color distribution from the target to the source image using the mean and standard deviations of the L*a*b* color space. @@ -47,18 +47,19 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, ta # convert the images from the RGB to L*ab* color space, being # sure to utilizing the floating point data type (note: OpenCV # expects floats to be 32-bit, so use that instead of 64-bit) - source = cv2.cvtColor(source.astype(np.float32), cv2.COLOR_BGR2LAB) - target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB) + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB) # compute color statistics for the source and target images (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(source, mask=source_mask) (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(target, mask=target_mask) - # subtract the means from the target image - (l, a, b) = cv2.split(target) - l -= lMeanTar - a -= aMeanTar - b -= bMeanTar + + # subtract the means from the source image + (l, a, b) = cv2.split(source) + l -= lMeanSrc + a -= aMeanSrc + b -= bMeanSrc if preserve_paper: # scale by the standard deviations using paper proposed factor @@ -72,9 +73,9 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, ta b = (bStdSrc / bStdTar) * b # add in the source mean - l += lMeanSrc - a += aMeanSrc - b += bMeanSrc + l += lMeanTar + a += aMeanTar + b += bMeanTar # clip/scale the pixel intensities to [0, 255] if they fall # outside this range @@ -144,8 +145,8 @@ def lab_image_stats(image, mask=None): l, a, b = l[mask == 1], a[mask == 1], b[mask == 1] l_mean, l_std = np.mean(l), np.std(l) - a_mean, a_std = np.mean(l), np.std(l) - b_mean, b_std = np.mean(l), np.std(l) + a_mean, a_std = np.mean(a), np.std(a) + b_mean, b_std = np.mean(b), np.std(b) # return the color statistics return l_mean, l_std, a_mean, a_std, b_mean, b_std diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py index cbbbe47..d00b992 100644 --- a/samplelib/SampleProcessor.py +++ b/samplelib/SampleProcessor.py @@ -245,8 +245,8 @@ class SampleProcessor(object): ct_sample_bgr, clip=False, preserve_paper=True, - target_mask=img_mask, - source_mask=ct_sample_mask) + source_mask=img_mask, + target_mask=ct_sample_mask) if normalize_std_dev: