diff --git a/imagelib/color_transfer.py b/imagelib/color_transfer.py index ee778a6..580f680 100644 --- a/imagelib/color_transfer.py +++ b/imagelib/color_transfer.py @@ -140,8 +140,8 @@ def lab_image_stats(image, mask=None): l, a, b = cv2.split(image) if mask is not None: - mask = np.squeeze(mask) - l, a, b = l[mask == 1], a[mask == 1], b[mask == 1] + im_mask = np.squeeze(mask) if len(np.shape(mask)) == 3 else mask + l, a, b = l[im_mask == 1], a[im_mask == 1], b[im_mask == 1] l_mean, l_std = np.mean(l), np.std(l) a_mean, a_std = np.mean(a), np.std(a)