Synthesize new faces from existing ones by relighting them with the DeepPortraitRelighter network. With the relighted faces, the neural network will better reproduce face shadows, so you can synthesize shadowed faces from a fully lit faceset: https://i.imgur.com/wxcmQoi.jpg As a result, better fakes on dark faces: https://i.imgur.com/5xXIbz5.jpg In the OpenCL build the Relighter runs on CPU; install pytorch directly via pip (see requirements).
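
A minimal usage sketch (the import path is an assumption; in DeepFaceLab this class lives in imagelib):

    import cv2
    from imagelib import DeepPortraitRelighting  # assumed import path

    dpr = DeepPortraitRelighting()
    img = cv2.imread('face.jpg')                    # BGR face image
    relit = dpr.relight_all(img)                    # one image per preset lighting
    one = dpr.relight(img, 3, lighten=True)         # single preset, allow lightening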
from pathlib import Path

import cv2
import numpy as np


class DeepPortraitRelighting(object):

    def __init__(self):
        from nnlib import nnlib
        nnlib.import_torch()

        self.torch = nnlib.torch
        self.torch_device = nnlib.torch_device

        self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device)

        self.shs = [
            [1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01],
            [1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01],
            [1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02],
            [1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01],
            [1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01],
            [1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02],
            [1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02]
        ]
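
    # Each row of self.shs holds the 9 coefficients of a second-order
    # spherical-harmonics (SH) lighting environment, as used by the Deep
    # Portrait Relighting model; the seven presets appear to differ mainly
    # in light direction. relight() scales them by 0.7 before use.
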
    # n selects a preset lighting: valid range is [0 .. len(self.shs)-1]
    def relight(self, img, n, lighten=False):
        torch = self.torch

        # clamp n into range and scale the SH coefficients; shape (1, 9, 1, 1)
        sh = (np.array(self.shs[np.clip(n, 0, len(self.shs) - 1)]).reshape((1, 9, 1, 1)) * 0.7).astype(np.float32)
        sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device))

        row, col, _ = img.shape
        img = cv2.resize(img, (512, 512))
        Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

        # the network relights only the luminance (L) channel
        inputL = Lab[:, :, 0]
        outputImg, outputSH = self.model(torch.autograd.Variable(torch.from_numpy(inputL[None, None, ...].astype(np.float32) / 255.0).to(self.torch_device)),
                                         sh, 0)

        outputImg = outputImg[0].cpu().data.numpy()
        outputImg = outputImg.transpose((1, 2, 0))
        outputImg = np.squeeze(outputImg)
        outputImg = np.clip(outputImg, 0.0, 1.0)
        outputImg = cv2.blur(outputImg, (3, 3))

        if not lighten:
            # apply the prediction as a multiplicative ratio on the input L,
            # which can only darken the image
            outputImg = inputL * outputImg
        else:
            # use the predicted luminance directly, which can also lighten
            outputImg = outputImg * 255.0
        outputImg = np.clip(outputImg, 0, 255).astype(np.uint8)

        Lab[:, :, 0] = outputImg
        result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
        result = cv2.resize(result, (col, row))
        return result
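
    # Note: only the L channel of Lab is modified, so chrominance is
    # preserved; the model itself is grayscale (HourglassNet(gray=True)).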

    def relight_all(self, img, lighten=False):
        return [self.relight(img, n, lighten=lighten) for n in range(len(self.shs))]

    def relight_random(self, img, lighten=False):
        return [self.relight(img, np.random.randint(len(self.shs)), lighten=lighten)]

    @staticmethod
    def build_model(torch, torch_device):
        nn = torch.nn
        F = torch.nn.functional

        def conv3X3(in_planes, out_planes, stride=1):
            return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

        # define the network
        class BasicBlock(nn.Module):
            def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None):
                super(BasicBlock, self).__init__()
                # batchNorm_type 0 means batch normalization,
                # 1 means instance normalization
                self.inplanes = inplanes
                self.outplanes = outplanes
                self.conv1 = conv3X3(inplanes, outplanes, 1)
                self.conv2 = conv3X3(outplanes, outplanes, 1)
                if batchNorm_type == 0:
                    self.bn1 = nn.BatchNorm2d(outplanes)
                    self.bn2 = nn.BatchNorm2d(outplanes)
                else:
                    self.bn1 = nn.InstanceNorm2d(outplanes)
                    self.bn2 = nn.InstanceNorm2d(outplanes)

                # 1x1 projection used when the channel count changes
                self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False)

            def forward(self, x):
                out = self.conv1(x)
                out = self.bn1(out)
                out = F.relu(out)
                out = self.conv2(out)
                out = self.bn2(out)

                # residual connection, projected if channel counts differ
                if self.inplanes != self.outplanes:
                    out += self.shortcuts(x)
                else:
                    out += x

                out = F.relu(out)
                return out
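
        # HourglassBlock wraps an inner module (middleNet) between an
        # identity-resolution "upper" branch and a downsample/upsample
        # "lower" branch; nesting four of them (HG3..HG0 below) forms the
        # hourglass encoder-decoder, with the lighting network innermost.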
        class HourglassBlock(nn.Module):
            def __init__(self, inplane, mid_plane, middleNet, skipLayer=True):
                super(HourglassBlock, self).__init__()
                # upper branch
                self.skipLayer = skipLayer
                self.upper = BasicBlock(inplane, inplane, batchNorm_type=1)

                # lower branch
                self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
                self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
                self.low1 = BasicBlock(inplane, mid_plane)
                self.middle = middleNet
                self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1)

            def forward(self, x, light, count, skip_count):
                # count indicates which nesting level we are in; skip
                # connections are enabled for levels where count >= skip_count
                out_upper = self.upper(x)
                out_lower = self.downSample(x)
                out_lower = self.low1(out_lower)
                out_lower, out_middle = self.middle(out_lower, light, count + 1, skip_count)
                out_lower = self.low2(out_lower)
                out_lower = self.upSample(out_lower)
                if count >= skip_count and self.skipLayer:
                    out = out_lower + out_upper
                else:
                    out = out_lower
                return out, out_middle
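
        # lightingNet sits at the hourglass bottleneck: it predicts the
        # source image's SH lighting from the bottleneck features, then
        # overwrites those feature channels with the target lighting.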
        class lightingNet(nn.Module):
            def __init__(self, ncInput, ncOutput, ncMiddle):
                super(lightingNet, self).__init__()
                self.ncInput = ncInput
                self.ncOutput = ncOutput
                self.ncMiddle = ncMiddle
                self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
                self.predict_relu1 = nn.PReLU()
                self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False)

                self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
                self.post_relu1 = nn.PReLU()
                self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False)
                self.post_relu2 = nn.ReLU()  # to be consistent with the original feature

            def forward(self, innerFeat, target_light, count, skip_count):
                x = innerFeat[:, 0:self.ncInput, :, :]  # lighting feature
                _, _, row, col = x.shape

                # predict the source lighting from the spatially averaged feature
                feat = x.mean(dim=(2, 3), keepdim=True)
                light = self.predict_relu1(self.predict_FC1(feat))
                light = self.predict_FC2(light)

                # map the target lighting into feature space and broadcast it
                # over the spatial dimensions
                upFeat = self.post_relu1(self.post_FC1(target_light))
                upFeat = self.post_relu2(self.post_FC2(upFeat))
                upFeat = upFeat.repeat((1, 1, row, col))
                innerFeat[:, 0:self.ncInput, :, :] = upFeat
                return innerFeat, light
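
        # The returned `light` is the estimated source lighting; relight()
        # receives it as outputSH but does not use it.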

        class HourglassNet(nn.Module):
            def __init__(self, baseFilter=16, gray=True):
                super(HourglassNet, self).__init__()

                self.ncLight = 27  # number of channels for input to lighting network
                self.baseFilter = baseFilter

                # number of channels for output of lighting network
                if gray:
                    self.ncOutLight = 9  # gray: channel is 1
                else:
                    self.ncOutLight = 27  # color: channel is 3

                self.ncPre = self.baseFilter  # number of channels for pre-convolution

                # number of channels at each hourglass level
                self.ncHG3 = self.baseFilter
                self.ncHG2 = 2 * self.baseFilter
                self.ncHG1 = 4 * self.baseFilter
                self.ncHG0 = 8 * self.baseFilter + self.ncLight

                self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2)
                self.pre_bn = nn.BatchNorm2d(self.ncPre)

                self.light = lightingNet(self.ncLight, self.ncOutLight, 128)
                self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light)
                self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0)
                self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1)
                self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2)

                self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1)
                self.bn_1 = nn.BatchNorm2d(self.ncPre)
                self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
                self.bn_2 = nn.BatchNorm2d(self.ncPre)
                self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
                self.bn_3 = nn.BatchNorm2d(self.ncPre)

                self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0)

            def forward(self, x, target_light, skip_count):
                feat = self.pre_conv(x)
                feat = F.relu(self.pre_bn(feat))

                # run the nested hourglass; out_light is the estimated source lighting
                feat, out_light = self.HG3(feat, target_light, 0, skip_count)

                feat = F.relu(self.bn_1(self.conv_1(feat)))
                feat = F.relu(self.bn_2(self.conv_2(feat)))
                feat = F.relu(self.bn_3(self.conv_3(feat)))
                out_img = self.output(feat)
                out_img = torch.sigmoid(out_img)
                return out_img, out_light

        model = HourglassNet()
        t_dict = torch.load(Path(__file__).parent / 'DeepPortraitRelighting.t7')
        model.load_state_dict(t_dict)
        model.to(torch_device)
        model.train(False)
        return model
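
# A minimal direct-invocation sketch (hypothetical usage), showing the tensor
# shapes the model expects; `dpr` is a DeepPortraitRelighting instance:
#
#   dpr = DeepPortraitRelighting()
#   torch = dpr.torch
#   L = np.random.rand(1, 1, 512, 512).astype(np.float32)  # luminance in [0, 1]
#   sh = np.zeros((1, 9, 1, 1), dtype=np.float32)          # target SH coefficients
#   with torch.no_grad():
#       out_img, out_light = dpr.model(torch.from_numpy(L).to(dpr.torch_device),
#                                      torch.from_numpy(sh).to(dpr.torch_device), 0)
#   # out_img: (1, 1, 512, 512) relit luminance; out_light: estimated source SH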