diff --git a/imagelib/color_transfer.py b/imagelib/color_transfer.py index 093ed3d..a2d4326 100644 --- a/imagelib/color_transfer.py +++ b/imagelib/color_transfer.py @@ -77,11 +77,11 @@ def reinhard_color_transfer(source, target, clip=False, preserve_paper=False, so a += aMeanTar b += bMeanTar - # clip/scale the pixel intensities to [0, 255] if they fall - # outside this range - l = _scale_array(l, 0, 100, clip=clip) - a = _scale_array(a, -127, 127, clip=clip) - b = _scale_array(b, -127, 127, clip=clip) + # clip/scale the pixel intensities if they fall + # outside the ranges for LAB + l = _scale_array(l, 0, 100, clip=clip, mask=source_mask) + a = _scale_array(a, -127, 127, clip=clip, mask=source_mask) + b = _scale_array(b, -127, 127, clip=clip, mask=source_mask) # merge the channels together and convert back to the RGB color transfer = cv2.merge([l, a, b]) @@ -180,7 +180,7 @@ def _min_max_scale(arr, new_range=(0, 255)): return scaled -def _scale_array(arr, mn, mx, clip=True): +def _scale_array(arr, mn, mx, clip=True, mask=None): """ Trim NumPy array values to be in [0, 255] range with option of clipping or scaling. @@ -197,7 +197,10 @@ def _scale_array(arr, mn, mx, clip=True): if clip: scaled = np.clip(arr, mn, mx) else: - scale_range = (max([arr.min(), mn]), min([arr.max(), mx])) + if mask is not None: + scale_range = (max([np.min(mask * arr), mn]), min([np.max(mask * arr), mx])) + else: + scale_range = (max([np.min(arr), mn]), min([np.max(arr), mx])) scaled = _min_max_scale(arr, new_range=scale_range) return scaled diff --git a/imagelib/test/color_transfer.py b/imagelib/test/color_transfer.py index 5b1338f..c570452 100644 --- a/imagelib/test/color_transfer.py +++ b/imagelib/test/color_transfer.py @@ -15,44 +15,44 @@ class ColorTranfer(unittest.TestCase): src_samples = SampleLoader.load(SampleType.FACE, './test_src', None) dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None) - src_sample = src_samples[2] - src_img = src_sample.load_bgr() - src_mask = src_sample.load_mask() + for src_sample in src_samples: + src_img = src_sample.load_bgr() + src_mask = src_sample.load_mask() - # Toggle to see masks - show_masks = False + # Toggle to see masks + show_masks = False - grid = [] - for ct_sample in dst_samples: - print(src_sample.filename, ct_sample.filename) - ct_img = ct_sample.load_bgr() - ct_mask = ct_sample.load_mask() + grid = [] + for ct_sample in dst_samples: + print(src_sample.filename, ct_sample.filename) + ct_img = ct_sample.load_bgr() + ct_mask = ct_sample.load_mask() - lct_img = linear_color_transfer(src_img, ct_img) - rct_img = reinhard_color_transfer(src_img, ct_img) - rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True) - rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True) - rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True) + lct_img = linear_color_transfer(src_img, ct_img) + rct_img = reinhard_color_transfer(src_img, ct_img) + rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True) + rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True) + rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True) - masked_rct_img = reinhard_color_transfer(src_img, ct_img, source_mask=src_mask, target_mask=ct_mask) - masked_rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True, source_mask=src_mask, target_mask=ct_mask) - masked_rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask) - masked_rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img = reinhard_color_transfer(src_img, ct_img, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img_clip = reinhard_color_transfer(src_img, ct_img, clip=True, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img_paper = reinhard_color_transfer(src_img, ct_img, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask) + masked_rct_img_paper_clip = reinhard_color_transfer(src_img, ct_img, clip=True, preserve_paper=True, source_mask=src_mask, target_mask=ct_mask) - results = [lct_img, rct_img, rct_img_clip, rct_img_paper, rct_img_paper_clip, - masked_rct_img, masked_rct_img_clip, masked_rct_img_paper, masked_rct_img_paper_clip] + results = [lct_img, rct_img, rct_img_clip, rct_img_paper, rct_img_paper_clip, + masked_rct_img, masked_rct_img_clip, masked_rct_img_paper, masked_rct_img_paper_clip] - if show_masks: - results = [src_mask * im for im in results] - src_img *= src_mask - ct_img *= ct_mask + if show_masks: + results = [src_mask * im for im in results] + src_img *= src_mask + ct_img *= ct_mask - results = np.concatenate((src_img, ct_img, *results), axis=1) - grid.append(results) + results = np.concatenate((src_img, ct_img, *results), axis=1) + grid.append(results) - cv2.namedWindow('test output', cv2.WINDOW_NORMAL) - cv2.imshow('test output', np.concatenate(grid, axis=0)) - cv2.waitKey(0) + cv2.namedWindow('test output', cv2.WINDOW_NORMAL) + cv2.imshow('test output', np.concatenate(grid, axis=0)) + cv2.waitKey(0) cv2.destroyAllWindows()