mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-12 08:07:03 -07:00
FacesetRelighter fixes and improvements:
now you have 3 ways: 1) define light directions manually (not for google colab) watch demo https://youtu.be/79xz7yEO5Jw 2) relight faceset with one random direction 3) relight faceset with predefined 8 directions
This commit is contained in:
parent
fe58459f36
commit
05153d9ba5
3 changed files with 235 additions and 36 deletions
|
@ -1,33 +1,55 @@
|
|||
import math
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.linalg as npla
|
||||
|
||||
|
||||
class DeepPortraitRelighting(object):
|
||||
|
||||
def __init__(self):
|
||||
from nnlib import nnlib
|
||||
nnlib.import_torch()
|
||||
|
||||
nnlib.import_torch()
|
||||
self.torch = nnlib.torch
|
||||
self.torch_device = nnlib.torch_device
|
||||
|
||||
self.torch_device = nnlib.torch_device
|
||||
self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device)
|
||||
|
||||
def SH_basis(self, alt, azi):
|
||||
alt = alt * math.pi / 180.0
|
||||
azi = azi * math.pi / 180.0
|
||||
|
||||
self.shs = [
|
||||
[1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02],
|
||||
[1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02],
|
||||
[1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02]
|
||||
]
|
||||
x = math.cos(alt)*math.sin(azi)
|
||||
y = -math.cos(alt)*math.cos(azi)
|
||||
z = math.sin(alt)
|
||||
|
||||
normal = np.array([x,y,z])
|
||||
|
||||
norm_X = normal[0]
|
||||
norm_Y = normal[1]
|
||||
norm_Z = normal[2]
|
||||
|
||||
sh_basis = np.zeros((9))
|
||||
att= np.pi*np.array([1, 2.0/3.0, 1/4.0])
|
||||
sh_basis[0] = 0.5/np.sqrt(np.pi)*att[0]
|
||||
|
||||
sh_basis[1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1]
|
||||
sh_basis[2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1]
|
||||
sh_basis[3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1]
|
||||
|
||||
sh_basis[4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2]
|
||||
sh_basis[5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2]
|
||||
sh_basis[6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2]
|
||||
sh_basis[7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2]
|
||||
sh_basis[8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2]
|
||||
return sh_basis
|
||||
|
||||
#n = [0..8]
|
||||
def relight(self, img, n, lighten=False):
|
||||
def relight(self, img, alt, azi, lighten=False):
|
||||
torch = self.torch
|
||||
|
||||
sh = (np.array (self.shs[np.clip(n, 0,8)]).reshape( (1,9,1,1) )*0.7).astype(np.float32)
|
||||
|
||||
sh = self.SH_basis (alt, azi)
|
||||
sh = (sh.reshape( (1,9,1,1) ) ).astype(np.float32)
|
||||
sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device))
|
||||
|
||||
row, col, _ = img.shape
|
||||
|
@ -54,13 +76,7 @@ class DeepPortraitRelighting(object):
|
|||
result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
|
||||
result = cv2.resize(result, (col, row))
|
||||
return result
|
||||
|
||||
def relight_all(self, img, lighten=False):
|
||||
return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ]
|
||||
|
||||
def relight_random(self, img, lighten=False):
|
||||
return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten ) ]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def build_model(torch, torch_device):
|
||||
nn = torch.nn
|
||||
|
@ -220,4 +236,4 @@ class DeepPortraitRelighting(object):
|
|||
model.load_state_dict(t_dict)
|
||||
model.to( torch_device )
|
||||
model.train(False)
|
||||
return model
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue