refactoring

This commit is contained in:
iperov 2021-11-07 10:03:15 +04:00
commit 30ba51edf7
24 changed files with 663 additions and 459 deletions

View file

@ -2,12 +2,14 @@ import cv2
import numpy as np
import numpy.linalg as npla
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
def sot(src,trg, mask=None, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0, return_diff=False):
"""
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
src - any float range any channel image
dst - any float range any channel image, same shape as src
mask -
steps - number of solver steps
batch_size - solver batch size
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
@ -39,8 +41,8 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
dir = np.random.normal(size=c).astype(src_dtype)
dir /= npla.norm(dir)
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w))
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w))
projsource = np.sum( new_src*dir*mask, axis=-1).reshape ((h*w))
projtarget = np.sum( trg*dir*mask, axis=-1).reshape ((h*w))
idSource = np.argsort (projsource)
idTarget = np.argsort (projtarget)
@ -48,12 +50,18 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
a = projtarget[idTarget]-projsource[idSource]
for i_c in range(c):
advect[idSource,i_c] += a * dir[i_c]
new_src += advect.reshape( (h,w,c) ) / batch_size
new_src += (advect.reshape( (h,w,c) ) * mask) / batch_size
if reg_sigmaXY != 0.0:
src_diff = new_src-src
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
if len(src_diff_filt.shape) == 2:
src_diff_filt = src_diff_filt[...,None]
new_src = src + src_diff_filt
return new_src
if return_diff:
return src_diff_filt
return src + src_diff_filt
else:
if return_diff:
return new_src-src
return new_src