diff --git a/imagelib/color_transfer.py b/imagelib/color_transfer.py index 67a580d..2b085d5 100644 --- a/imagelib/color_transfer.py +++ b/imagelib/color_transfer.py @@ -138,11 +138,11 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): def lab_image_stats(image, mask=None): # compute the mean and standard deviation of each channel + l, a, b = cv2.split(image) + if mask is not None: mask = np.squeeze(mask) - image = image[mask == 1] - - (l, a, b) = cv2.split(image) + l, a, b = l[mask == 1], a[mask == 1], b[mask == 1] l_mean, l_std = np.mean(l), np.std(l) a_mean, a_std = np.mean(l), np.std(l)