Adds masking to RCT

This commit is contained in:
Jeremy Hummel 2019-08-11 16:27:56 -07:00
commit 3d4886cd36
2 changed files with 23 additions and 26 deletions

View file

@ -2,7 +2,7 @@ import numpy as np
import cv2 import cv2
def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, source_mask=None, target_mask=None): def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, target_mask=None, source_mask=None):
""" """
Transfers the color distribution from the source to the target Transfers the color distribution from the source to the target
image using the mean and standard deviations of the L*a*b* image using the mean and standard deviations of the L*a*b*
@ -51,8 +51,6 @@ def reinhard_color_transfer(target, source, clip=False, preserve_paper=False, so
target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB) target = cv2.cvtColor(target.astype(np.float32), cv2.COLOR_BGR2LAB)
# compute color statistics for the source and target images # compute color statistics for the source and target images
src_input = source if source_mask is None else source * source_mask
tgt_input = target if target_mask is None else target * target_mask
(lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input) (lMeanSrc, lStdSrc, aMeanSrc, aStdSrc, bMeanSrc, bStdSrc) = lab_image_stats(src_input)
(lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input) (lMeanTar, lStdTar, aMeanTar, aStdTar, bMeanTar, bStdTar) = lab_image_stats(tgt_input)
@ -138,32 +136,20 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
return matched_img return matched_img
def linear_lab_color_transform(target_img, source_img, eps=1e-5, mode='pca'): def lab_image_stats(image, mask=None):
"""doesn't work yet"""
np.clip(source_img, 0, 1, out=source_img)
np.clip(target_img, 0, 1, out=target_img)
# convert the images from the RGB to L*ab* color space, being
# sure to utilizing the floating point data type (note: OpenCV
# expects floats to be 32-bit, so use that instead of 64-bit)
source_img = cv2.cvtColor(source_img.astype(np.float32), cv2.COLOR_BGR2LAB)
target_img = cv2.cvtColor(target_img.astype(np.float32), cv2.COLOR_BGR2LAB)
target_img = linear_color_transfer(target_img, source_img, mode=mode, eps=eps)
target_img = cv2.cvtColor(np.clip(target_img, 0, 1).astype(np.float32), cv2.COLOR_LAB2BGR)
np.clip(target_img, 0, 1, out=target_img)
return target_img
def lab_image_stats(image):
# compute the mean and standard deviation of each channel # compute the mean and standard deviation of each channel
if mask is not None:
mask = np.squeeze(mask)
image = image[mask == 1]
(l, a, b) = cv2.split(image) (l, a, b) = cv2.split(image)
(lMean, lStd) = (l.mean(), l.std())
(aMean, aStd) = (a.mean(), a.std()) l_mean, l_std = np.mean(l), np.std(l)
(bMean, bStd) = (b.mean(), b.std()) a_mean, a_std = np.mean(l), np.std(l)
b_mean, b_std = np.mean(l), np.std(l)
# return the color statistics # return the color statistics
return (lMean, lStd, aMean, aStd, bMean, bStd) return l_mean, l_std, a_mean, a_std, b_mean, b_std
def _scale_array(arr, clip=True): def _scale_array(arr, clip=True):

View file

@ -223,7 +223,18 @@ class SampleProcessor(object):
if apply_ct and ct_sample is not None: if apply_ct and ct_sample is not None:
if ct_sample_bgr is None: if ct_sample_bgr is None:
ct_sample_bgr = ct_sample.load_bgr() ct_sample_bgr = ct_sample.load_bgr()
img_bgr = imagelib.reinhard_color_transfer(img_bgr, ct_sample_bgr, clip=True) if ct_sample_mask is None:
ct_sample_mask = ct_sample.load_fanseg_mask() or \
LandmarksProcessor.get_image_hull_mask(ct_sample_bgr.shape, ct_sample.landmarks)
img_bgr = imagelib.reinhard_color_transfer(img_bgr,
ct_sample_bgr,
clip=True,
target_mask=img_mask,
source_mask=ct_sample_mask)
# img_bgr = imagelib.reinhard_color_transfer(img_bgr,
# ct_sample_bgr,
# clip=True)
if normalize_std_dev: if normalize_std_dev:
img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1)) img_bgr = (img_bgr - img_bgr.mean((0, 1))) / img_bgr.std((0, 1))