Update color_transfer.py

Penrose inverse (pinv)  is more performant and accurate for PCA calculations then normal inverse
This commit is contained in:
Jeremy Hummel 2019-08-10 08:02:35 -07:00
commit 858ddf4079

View file

@ -117,20 +117,20 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
if mode == 'chol':
chol_t = np.linalg.cholesky(Ct)
chol_s = np.linalg.cholesky(Cs)
ts = chol_s.dot(np.linalg.inv(chol_t)).dot(t)
ts = chol_s.dot(np.linalg.pinv(chol_t)).dot(t)
if mode == 'pca':
eva_t, eve_t = np.linalg.eigh(Ct)
Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T)
eva_s, eve_s = np.linalg.eigh(Cs)
Qs = eve_s.dot(np.sqrt(np.diag(eva_s))).dot(eve_s.T)
ts = Qs.dot(np.linalg.inv(Qt)).dot(t)
ts = Qs.dot(np.linalg.pinv(Qt)).dot(t)
if mode == 'sym':
eva_t, eve_t = np.linalg.eigh(Ct)
Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T)
Qt_Cs_Qt = Qt.dot(Cs).dot(Qt)
eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt)
QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T)
ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t)
ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.pinv(Qt)).dot(t)
matched_img = ts.reshape(*target_img.transpose(2, 0, 1).shape).transpose(1, 2, 0)
matched_img += mu_s
matched_img[matched_img > 1] = 1
@ -138,7 +138,8 @@ def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5):
return matched_img
def lab_linear_color_transform(target_img, source_img, eps=1e-5, mode='pca'):
def linear_lab_color_transform(target_img, source_img, eps=1e-5, mode='pca'):
"""doesn't work yet"""
np.clip(source_img, 0, 1, out=source_img)
np.clip(target_img, 0, 1, out=target_img)
@ -149,7 +150,7 @@ def lab_linear_color_transform(target_img, source_img, eps=1e-5, mode='pca'):
target_img = cv2.cvtColor(target_img.astype(np.float32), cv2.COLOR_BGR2LAB)
target_img = linear_color_transfer(target_img, source_img, mode=mode, eps=eps)
target_img = cv2.cvtColor(target_img, cv2.COLOR_LAB2BGR)
target_img = cv2.cvtColor(np.clip(target_img, 0, 1).astype(np.float32), cv2.COLOR_LAB2BGR)
np.clip(target_img, 0, 1, out=target_img)
return target_img