If using mask, scale the masked portion of image

This commit is contained in:
Jeremy Hummel 2019-08-13 23:45:29 -07:00
commit 390a9638d5
2 changed files with 40 additions and 37 deletions

View file

@ -77,11 +77,11 @@ def reinhard_color_transfer(source, target, clip=False, preserve_paper=False, so
a += aMeanTar a += aMeanTar
b += bMeanTar b += bMeanTar
# clip/scale the pixel intensities to [0, 255] if they fall # clip/scale the pixel intensities if they fall
# outside this range # outside the ranges for LAB
l = _scale_array(l, 0, 100, clip=clip) l = _scale_array(l, 0, 100, clip=clip, mask=source_mask)
a = _scale_array(a, -127, 127, clip=clip) a = _scale_array(a, -127, 127, clip=clip, mask=source_mask)
b = _scale_array(b, -127, 127, clip=clip) b = _scale_array(b, -127, 127, clip=clip, mask=source_mask)
# merge the channels together and convert back to the RGB color # merge the channels together and convert back to the RGB color
transfer = cv2.merge([l, a, b]) transfer = cv2.merge([l, a, b])
@ -180,7 +180,7 @@ def _min_max_scale(arr, new_range=(0, 255)):
return scaled return scaled
def _scale_array(arr, mn, mx, clip=True): def _scale_array(arr, mn, mx, clip=True, mask=None):
""" """
Trim NumPy array values to be in [0, 255] range with option of Trim NumPy array values to be in [0, 255] range with option of
clipping or scaling. clipping or scaling.
@ -197,7 +197,10 @@ def _scale_array(arr, mn, mx, clip=True):
if clip: if clip:
scaled = np.clip(arr, mn, mx) scaled = np.clip(arr, mn, mx)
else: else:
scale_range = (max([arr.min(), mn]), min([arr.max(), mx])) if mask is not None:
scale_range = (max([np.min(mask * arr), mn]), min([np.max(mask * arr), mx]))
else:
scale_range = (max([np.min(arr), mn]), min([np.max(arr), mx]))
scaled = _min_max_scale(arr, new_range=scale_range) scaled = _min_max_scale(arr, new_range=scale_range)
return scaled return scaled

View file

@ -15,7 +15,7 @@ class ColorTranfer(unittest.TestCase):
src_samples = SampleLoader.load(SampleType.FACE, './test_src', None) src_samples = SampleLoader.load(SampleType.FACE, './test_src', None)
dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None) dst_samples = SampleLoader.load(SampleType.FACE, './test_dst', None)
src_sample = src_samples[2] for src_sample in src_samples:
src_img = src_sample.load_bgr() src_img = src_sample.load_bgr()
src_mask = src_sample.load_mask() src_mask = src_sample.load_mask()