mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-08-21 05:53:25 -07:00
refactoring
This commit is contained in:
parent
8489949f2c
commit
30ba51edf7
24 changed files with 663 additions and 459 deletions
|
@ -2,12 +2,14 @@ import cv2
|
|||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||
def sot(src,trg, mask=None, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0, return_diff=False):
|
||||
"""
|
||||
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
|
||||
|
||||
src - any float range any channel image
|
||||
dst - any float range any channel image, same shape as src
|
||||
mask -
|
||||
|
||||
steps - number of solver steps
|
||||
batch_size - solver batch size
|
||||
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
|
||||
|
@ -39,8 +41,8 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
|||
dir = np.random.normal(size=c).astype(src_dtype)
|
||||
dir /= npla.norm(dir)
|
||||
|
||||
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w))
|
||||
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w))
|
||||
projsource = np.sum( new_src*dir*mask, axis=-1).reshape ((h*w))
|
||||
projtarget = np.sum( trg*dir*mask, axis=-1).reshape ((h*w))
|
||||
|
||||
idSource = np.argsort (projsource)
|
||||
idTarget = np.argsort (projtarget)
|
||||
|
@ -48,12 +50,18 @@ def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
|||
a = projtarget[idTarget]-projsource[idSource]
|
||||
for i_c in range(c):
|
||||
advect[idSource,i_c] += a * dir[i_c]
|
||||
new_src += advect.reshape( (h,w,c) ) / batch_size
|
||||
|
||||
new_src += (advect.reshape( (h,w,c) ) * mask) / batch_size
|
||||
|
||||
if reg_sigmaXY != 0.0:
|
||||
src_diff = new_src-src
|
||||
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
|
||||
if len(src_diff_filt.shape) == 2:
|
||||
src_diff_filt = src_diff_filt[...,None]
|
||||
new_src = src + src_diff_filt
|
||||
return new_src
|
||||
if return_diff:
|
||||
return src_diff_filt
|
||||
return src + src_diff_filt
|
||||
else:
|
||||
if return_diff:
|
||||
return new_src-src
|
||||
return new_src
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue