mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 20:42:11 -07:00
added FacesetRelighter:
Synthesize new faces from existing ones by relighting them using DeepPortraitRelighter network. With the relighted faces neural network will better reproduce face shadows. Therefore you can synthsize shadowed faces from fully lit faceset. https://i.imgur.com/wxcmQoi.jpg as a result, better fakes on dark faces: https://i.imgur.com/5xXIbz5.jpg in OpenCL build Relighter runs on CPU, install pytorch directly via pip install, look at requirements
This commit is contained in:
parent
b9c0815d17
commit
fe58459f36
12 changed files with 402 additions and 17 deletions
34
main.py
34
main.py
|
@ -63,7 +63,7 @@ if __name__ == "__main__":
|
|||
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
|
||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
|
||||
p.set_defaults (func=process_dev_extract_vggface2_dataset)
|
||||
|
||||
|
||||
def process_dev_extract_umd_csv(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import Extractor
|
||||
|
@ -78,8 +78,8 @@ if __name__ == "__main__":
|
|||
p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
|
||||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
|
||||
p.set_defaults (func=process_dev_extract_umd_csv)
|
||||
|
||||
|
||||
|
||||
|
||||
def process_dev_apply_celebamaskhq(arguments):
|
||||
os_utils.set_process_lowest_prio()
|
||||
from mainscripts import dev_misc
|
||||
|
@ -130,10 +130,10 @@ if __name__ == "__main__":
|
|||
|
||||
#if arguments.remove_fanseg:
|
||||
# Util.remove_fanseg_folder (input_path=arguments.input_dir)
|
||||
|
||||
|
||||
if arguments.remove_ie_polys:
|
||||
Util.remove_ie_polys_folder (input_path=arguments.input_dir)
|
||||
|
||||
|
||||
p = subparsers.add_parser( "util", help="Utilities.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.")
|
||||
|
@ -190,7 +190,7 @@ if __name__ == "__main__":
|
|||
Converter.main (args, device_args)
|
||||
|
||||
p = subparsers.add_parser( "convert", help="Converter")
|
||||
p.add_argument('--training-data-src-dir', action=fixPathAction, dest="training_data_src_dir", help="(optional, may be required by some models) Dir of extracted SRC faceset.")
|
||||
p.add_argument('--training-data-src-dir', action=fixPathAction, dest="training_data_src_dir", help="(optional, may be required by some models) Dir of extracted SRC faceset.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.")
|
||||
p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the converted files will be stored.")
|
||||
p.add_argument('--aligned-dir', action=fixPathAction, dest="aligned_dir", help="Aligned directory. This is where the extracted of dst faces stored.")
|
||||
|
@ -270,9 +270,29 @@ if __name__ == "__main__":
|
|||
p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.")
|
||||
p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the labeled faces will be stored.")
|
||||
p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.")
|
||||
|
||||
|
||||
p.set_defaults(func=process_labelingtool_edit_mask)
|
||||
|
||||
def process_relight_faceset(arguments):
|
||||
from mainscripts import FacesetRelighter
|
||||
FacesetRelighter.relight (arguments.input_dir, arguments.lighten, arguments.random_one)
|
||||
|
||||
def process_delete_relighted(arguments):
|
||||
from mainscripts import FacesetRelighter
|
||||
FacesetRelighter.delete_relighted (arguments.input_dir)
|
||||
|
||||
facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers()
|
||||
|
||||
p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces neural network will better reproduce face shadows.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
|
||||
p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.")
|
||||
p.add_argument('--random-one', action="store_true", dest="random_one", default=None, help="Relight the faces only with one random direction, otherwise relight with all directions.")
|
||||
p.set_defaults(func=process_relight_faceset)
|
||||
|
||||
p = facesettool_parser.add_parser ("delete_relighted", help="Delete relighted faces.")
|
||||
p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.")
|
||||
p.set_defaults(func=process_delete_relighted)
|
||||
|
||||
def bad_args(arguments):
|
||||
parser.print_help()
|
||||
exit(0)
|
||||
|
|
81
mainscripts/FacesetRelighter.py
Normal file
81
mainscripts/FacesetRelighter.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from interact import interact as io
|
||||
from nnlib import DeepPortraitRelighting
|
||||
from utils import Path_utils
|
||||
from utils.cv2_utils import *
|
||||
from utils.DFLJPG import DFLJPG
|
||||
from utils.DFLPNG import DFLPNG
|
||||
|
||||
|
||||
def relight(input_dir, lighten=None, random_one=None):
|
||||
if lighten is None:
|
||||
lighten = io.input_bool ("Lighten the faces? ( y/n default:n ) : ", False)
|
||||
|
||||
if random_one is None:
|
||||
random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ) : ", True)
|
||||
|
||||
input_path = Path(input_dir)
|
||||
|
||||
image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)]
|
||||
|
||||
dpr = DeepPortraitRelighting()
|
||||
|
||||
for filepath in io.progress_bar_generator(image_paths, "Relighting"):
|
||||
try:
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
continue
|
||||
else:
|
||||
if dflimg.get_relighted():
|
||||
io.log_info (f"Skipping already relighted face [{filepath.name}]")
|
||||
continue
|
||||
img = cv2_imread (str(filepath))
|
||||
|
||||
if random_one:
|
||||
relighted_imgs = dpr.relight_random(img,lighten=lighten)
|
||||
else:
|
||||
relighted_imgs = dpr.relight_all(img,lighten=lighten)
|
||||
|
||||
for i,relighted_img in enumerate(relighted_imgs):
|
||||
im_flags = []
|
||||
if filepath.suffix == '.jpg':
|
||||
im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100]
|
||||
|
||||
relighted_filename = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix)
|
||||
|
||||
cv2_imwrite (relighted_filename, relighted_img )
|
||||
dflimg.embed_and_set (relighted_filename, source_filename="_", relighted=True )
|
||||
except:
|
||||
io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}")
|
||||
|
||||
def delete_relighted(input_dir):
|
||||
input_path = Path(input_dir)
|
||||
image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)]
|
||||
|
||||
files_to_delete = []
|
||||
for filepath in io.progress_bar_generator(image_paths, "Loading"):
|
||||
if filepath.suffix == '.png':
|
||||
dflimg = DFLPNG.load( str(filepath) )
|
||||
elif filepath.suffix == '.jpg':
|
||||
dflimg = DFLJPG.load ( str(filepath) )
|
||||
else:
|
||||
dflimg = None
|
||||
|
||||
if dflimg is None:
|
||||
io.log_err ("%s is not a dfl image file" % (filepath.name) )
|
||||
continue
|
||||
else:
|
||||
if dflimg.get_relighted():
|
||||
files_to_delete += [filepath]
|
||||
|
||||
for file in io.progress_bar_generator(files_to_delete, "Deleting"):
|
||||
file.unlink()
|
223
nnlib/DeepPortraitRelighting.py
Normal file
223
nnlib/DeepPortraitRelighting.py
Normal file
|
@ -0,0 +1,223 @@
|
|||
from pathlib import Path
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
class DeepPortraitRelighting(object):
|
||||
|
||||
def __init__(self):
|
||||
from nnlib import nnlib
|
||||
nnlib.import_torch()
|
||||
|
||||
self.torch = nnlib.torch
|
||||
self.torch_device = nnlib.torch_device
|
||||
|
||||
self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device)
|
||||
|
||||
self.shs = [
|
||||
[1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02],
|
||||
[1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01],
|
||||
[1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02],
|
||||
[1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02]
|
||||
]
|
||||
|
||||
#n = [0..8]
|
||||
def relight(self, img, n, lighten=False):
|
||||
torch = self.torch
|
||||
|
||||
sh = (np.array (self.shs[np.clip(n, 0,8)]).reshape( (1,9,1,1) )*0.7).astype(np.float32)
|
||||
sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device))
|
||||
|
||||
row, col, _ = img.shape
|
||||
img = cv2.resize(img, (512, 512))
|
||||
Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
|
||||
|
||||
inputL = Lab[:,:,0]
|
||||
outputImg, outputSH = self.model(torch.autograd.Variable(torch.from_numpy(inputL[None,None,...].astype(np.float32)/255.0).to(self.torch_device)),
|
||||
sh, 0)
|
||||
|
||||
outputImg = outputImg[0].cpu().data.numpy()
|
||||
outputImg = outputImg.transpose((1,2,0))
|
||||
outputImg = np.squeeze(outputImg)
|
||||
outputImg = np.clip (outputImg, 0.0, 1.0)
|
||||
outputImg = cv2.blur(outputImg, (3,3) )
|
||||
|
||||
if not lighten:
|
||||
outputImg = inputL* outputImg
|
||||
else:
|
||||
outputImg = outputImg*255.0
|
||||
outputImg = np.clip(outputImg, 0,255).astype(np.uint8)
|
||||
|
||||
Lab[:,:,0] = outputImg
|
||||
result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR)
|
||||
result = cv2.resize(result, (col, row))
|
||||
return result
|
||||
|
||||
def relight_all(self, img, lighten=False):
|
||||
return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ]
|
||||
|
||||
def relight_random(self, img, lighten=False):
|
||||
return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten ) ]
|
||||
|
||||
@staticmethod
|
||||
def build_model(torch, torch_device):
|
||||
nn = torch.nn
|
||||
F = torch.nn.functional
|
||||
|
||||
def conv3X3(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
# define the network
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
# batchNorm_type 0 means batchnormalization
|
||||
# 1 means instance normalization
|
||||
self.inplanes = inplanes
|
||||
self.outplanes = outplanes
|
||||
self.conv1 = conv3X3(inplanes, outplanes, 1)
|
||||
self.conv2 = conv3X3(outplanes, outplanes, 1)
|
||||
if batchNorm_type == 0:
|
||||
self.bn1 = nn.BatchNorm2d(outplanes)
|
||||
self.bn2 = nn.BatchNorm2d(outplanes)
|
||||
else:
|
||||
self.bn1 = nn.InstanceNorm2d(outplanes)
|
||||
self.bn2 = nn.InstanceNorm2d(outplanes)
|
||||
|
||||
self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = F.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.inplanes != self.outplanes:
|
||||
out += self.shortcuts(x)
|
||||
else:
|
||||
out += x
|
||||
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
class HourglassBlock(nn.Module):
|
||||
def __init__(self, inplane, mid_plane, middleNet, skipLayer=True):
|
||||
super(HourglassBlock, self).__init__()
|
||||
# upper branch
|
||||
self.skipLayer = True
|
||||
self.upper = BasicBlock(inplane, inplane, batchNorm_type=1)
|
||||
|
||||
# lower branch
|
||||
self.downSample = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
|
||||
self.low1 = BasicBlock(inplane, mid_plane)
|
||||
self.middle = middleNet
|
||||
self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1)
|
||||
|
||||
def forward(self, x, light, count, skip_count):
|
||||
# we use count to indicate wich layer we are in
|
||||
# max_count indicates the from which layer, we would use skip connections
|
||||
out_upper = self.upper(x)
|
||||
out_lower = self.downSample(x)
|
||||
out_lower = self.low1(out_lower)
|
||||
out_lower, out_middle = self.middle(out_lower, light, count+1, skip_count)
|
||||
out_lower = self.low2(out_lower)
|
||||
out_lower = self.upSample(out_lower)
|
||||
if count >= skip_count and self.skipLayer:
|
||||
out = out_lower + out_upper
|
||||
else:
|
||||
out = out_lower
|
||||
return out, out_middle
|
||||
|
||||
class lightingNet(nn.Module):
|
||||
def __init__(self, ncInput, ncOutput, ncMiddle):
|
||||
super(lightingNet, self).__init__()
|
||||
self.ncInput = ncInput
|
||||
self.ncOutput = ncOutput
|
||||
self.ncMiddle = ncMiddle
|
||||
self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
|
||||
self.predict_relu1 = nn.PReLU()
|
||||
self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False)
|
||||
|
||||
self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False)
|
||||
self.post_relu1 = nn.PReLU()
|
||||
self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False)
|
||||
self.post_relu2 = nn.ReLU() # to be consistance with the original feature
|
||||
|
||||
def forward(self, innerFeat, target_light, count, skip_count):
|
||||
x = innerFeat[:,0:self.ncInput,:,:] # lighting feature
|
||||
_, _, row, col = x.shape
|
||||
# predict lighting
|
||||
feat = x.mean(dim=(2,3), keepdim=True)
|
||||
light = self.predict_relu1(self.predict_FC1(feat))
|
||||
light = self.predict_FC2(light)
|
||||
upFeat = self.post_relu1(self.post_FC1(target_light))
|
||||
upFeat = self.post_relu2(self.post_FC2(upFeat))
|
||||
upFeat = upFeat.repeat((1,1,row, col))
|
||||
innerFeat[:,0:self.ncInput,:,:] = upFeat
|
||||
return innerFeat, light#light
|
||||
|
||||
|
||||
class HourglassNet(nn.Module):
|
||||
def __init__(self, baseFilter = 16, gray=True):
|
||||
super(HourglassNet, self).__init__()
|
||||
|
||||
self.ncLight = 27 # number of channels for input to lighting network
|
||||
self.baseFilter = baseFilter
|
||||
|
||||
# number of channles for output of lighting network
|
||||
if gray:
|
||||
self.ncOutLight = 9 # gray: channel is 1
|
||||
else:
|
||||
self.ncOutLight = 27 # color: channel is 3
|
||||
|
||||
self.ncPre = self.baseFilter # number of channels for pre-convolution
|
||||
|
||||
# number of channels
|
||||
self.ncHG3 = self.baseFilter
|
||||
self.ncHG2 = 2*self.baseFilter
|
||||
self.ncHG1 = 4*self.baseFilter
|
||||
self.ncHG0 = 8*self.baseFilter + self.ncLight
|
||||
|
||||
self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2)
|
||||
self.pre_bn = nn.BatchNorm2d(self.ncPre)
|
||||
|
||||
self.light = lightingNet(self.ncLight, self.ncOutLight, 128)
|
||||
self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light)
|
||||
self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0)
|
||||
self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1)
|
||||
self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2)
|
||||
|
||||
self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1)
|
||||
self.bn_1 = nn.BatchNorm2d(self.ncPre)
|
||||
self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
|
||||
self.bn_2 = nn.BatchNorm2d(self.ncPre)
|
||||
self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0)
|
||||
self.bn_3 = nn.BatchNorm2d(self.ncPre)
|
||||
|
||||
self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, target_light, skip_count):
|
||||
feat = self.pre_conv(x)
|
||||
|
||||
feat = F.relu(self.pre_bn(feat))
|
||||
# get the inner most features
|
||||
feat, out_light = self.HG3(feat, target_light, 0, skip_count)
|
||||
#return feat, out_light
|
||||
|
||||
feat = F.relu(self.bn_1(self.conv_1(feat)))
|
||||
feat = F.relu(self.bn_2(self.conv_2(feat)))
|
||||
feat = F.relu(self.bn_3(self.conv_3(feat)))
|
||||
out_img = self.output(feat)
|
||||
out_img = torch.sigmoid(out_img)
|
||||
return out_img, out_light
|
||||
|
||||
model = HourglassNet()
|
||||
t_dict = torch.load( Path(__file__).parent / 'DeepPortraitRelighting.t7' )
|
||||
model.load_state_dict(t_dict)
|
||||
model.to( torch_device )
|
||||
model.train(False)
|
||||
return model
|
BIN
nnlib/DeepPortraitRelighting.t7
Normal file
BIN
nnlib/DeepPortraitRelighting.t7
Normal file
Binary file not shown.
|
@ -1,4 +1,5 @@
|
|||
from .nnlib import nnlib
|
||||
from .FUNIT import FUNIT
|
||||
from .TernausNet import TernausNet
|
||||
from .VGGFace import VGGFace
|
||||
from .VGGFace import VGGFace
|
||||
from .DeepPortraitRelighting import DeepPortraitRelighting
|
|
@ -20,6 +20,9 @@ class nnlib(object):
|
|||
|
||||
dlib = None
|
||||
|
||||
torch = None
|
||||
torch_device = None
|
||||
|
||||
keras = None
|
||||
keras_contrib = None
|
||||
|
||||
|
@ -128,7 +131,27 @@ UNet = nnlib.UNet
|
|||
UNetTemporalPredictor = nnlib.UNetTemporalPredictor
|
||||
NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||
"""
|
||||
@staticmethod
|
||||
def import_torch(device_config=None):
|
||||
if nnlib.torch is not None:
|
||||
return
|
||||
|
||||
if device_config is None:
|
||||
device_config = nnlib.active_DeviceConfig
|
||||
else:
|
||||
nnlib.active_DeviceConfig = device_config
|
||||
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
|
||||
os.environ.pop('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
import torch
|
||||
nnlib.torch = torch
|
||||
|
||||
if device_config.cpu_only or device_config.backend == 'plaidML':
|
||||
nnlib.torch_device = torch.device(type='cpu')
|
||||
else:
|
||||
nnlib.torch_device = torch.device(type='cuda', index=device_config.gpu_idxs[0] )
|
||||
torch.cuda.set_device(nnlib.torch_device)
|
||||
|
||||
@staticmethod
|
||||
def _import_tf(device_config):
|
||||
|
@ -634,7 +657,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
reduction_axes = list(range(len(input_shape)))
|
||||
del reduction_axes[self.axis]
|
||||
del reduction_axes[0]
|
||||
|
||||
|
||||
broadcast_shape = [1] * len(input_shape)
|
||||
broadcast_shape[self.axis] = input_shape[self.axis]
|
||||
mean = K.mean(x, reduction_axes, keepdims=True)
|
||||
|
@ -912,7 +935,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
base_config = super(Adam, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.Adam = Adam
|
||||
|
||||
|
||||
class DenseMaxout(keras.layers.Layer):
|
||||
"""A dense maxout layer.
|
||||
A `MaxoutDense` layer takes the element-wise maximum of
|
||||
|
@ -1039,7 +1062,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
base_config = super(DenseMaxout, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.DenseMaxout = DenseMaxout
|
||||
|
||||
|
||||
def CAInitializerMP( conv_weights_list ):
|
||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||
data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ]
|
||||
|
|
|
@ -7,4 +7,10 @@ plaidml-keras==0.5.0
|
|||
scikit-image
|
||||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
|
||||
#
|
||||
# install following packages directly via pip!
|
||||
#
|
||||
# pip install torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# pip install torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
|
@ -7,3 +7,9 @@ scikit-image
|
|||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
|
||||
#
|
||||
# install following packages directly via pip!
|
||||
#
|
||||
# pip install torch===1.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# pip install torchvision===0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
|
|
|
@ -9,3 +9,9 @@ scikit-image
|
|||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
|
||||
#
|
||||
# install following packages directly via pip!
|
||||
#
|
||||
# pip install torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# pip install torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html
|
|
@ -9,3 +9,9 @@ scikit-image
|
|||
tqdm
|
||||
ffmpeg-python==0.1.17
|
||||
git+https://www.github.com/keras-team/keras-contrib.git
|
||||
|
||||
#
|
||||
# install following packages directly via pip!
|
||||
#
|
||||
# pip install torch===1.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# pip install torchvision===0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
|
@ -170,6 +170,7 @@ class DFLJPG(object):
|
|||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
|
@ -195,7 +196,8 @@ class DFLJPG(object):
|
|||
'image_to_face_mat': image_to_face_mat,
|
||||
'fanseg_mask' : fanseg_mask,
|
||||
'pitch_yaw_roll' : pitch_yaw_roll,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
try:
|
||||
|
@ -214,6 +216,7 @@ class DFLJPG(object):
|
|||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
):
|
||||
if face_type is None: face_type = self.get_face_type()
|
||||
|
@ -226,7 +229,7 @@ class DFLJPG(object):
|
|||
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
|
||||
if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
|
||||
if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
|
||||
|
||||
if relighted is None: relighted = self.get_relighted()
|
||||
DFLJPG.embed_data (filename, face_type=face_type,
|
||||
landmarks=landmarks,
|
||||
ie_polys=ie_polys,
|
||||
|
@ -236,7 +239,7 @@ class DFLJPG(object):
|
|||
image_to_face_mat=image_to_face_mat,
|
||||
fanseg_mask=fanseg_mask,
|
||||
pitch_yaw_roll=pitch_yaw_roll,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod)
|
||||
relighted=relighted)
|
||||
|
||||
def remove_ie_polys(self):
|
||||
self.dfl_dict['ie_polys'] = None
|
||||
|
@ -312,4 +315,6 @@ class DFLJPG(object):
|
|||
return self.dfl_dict.get ('pitch_yaw_roll', None)
|
||||
def get_eyebrows_expand_mod(self):
|
||||
return self.dfl_dict.get ('eyebrows_expand_mod', None)
|
||||
def get_relighted(self):
|
||||
return self.dfl_dict.get ('relighted', False)
|
||||
|
||||
|
|
|
@ -286,6 +286,7 @@ class DFLPNG(object):
|
|||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
|
@ -312,6 +313,7 @@ class DFLPNG(object):
|
|||
'fanseg_mask' : fanseg_mask,
|
||||
'pitch_yaw_roll' : pitch_yaw_roll,
|
||||
'eyebrows_expand_mod' : eyebrows_expand_mod,
|
||||
'relighted' : relighted
|
||||
})
|
||||
|
||||
try:
|
||||
|
@ -330,6 +332,7 @@ class DFLPNG(object):
|
|||
fanseg_mask=None,
|
||||
pitch_yaw_roll=None,
|
||||
eyebrows_expand_mod=None,
|
||||
relighted=None,
|
||||
**kwargs
|
||||
):
|
||||
if face_type is None: face_type = self.get_face_type()
|
||||
|
@ -342,7 +345,8 @@ class DFLPNG(object):
|
|||
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
|
||||
if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
|
||||
if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod()
|
||||
|
||||
if relighted is None: relighted = self.get_relighted()
|
||||
|
||||
DFLPNG.embed_data (filename, face_type=face_type,
|
||||
landmarks=landmarks,
|
||||
ie_polys=ie_polys,
|
||||
|
@ -352,7 +356,8 @@ class DFLPNG(object):
|
|||
image_to_face_mat=image_to_face_mat,
|
||||
fanseg_mask=fanseg_mask,
|
||||
pitch_yaw_roll=pitch_yaw_roll,
|
||||
eyebrows_expand_mod=eyebrows_expand_mod)
|
||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
||||
relighted=relighted)
|
||||
|
||||
def remove_ie_polys(self):
|
||||
self.dfl_dict['ie_polys'] = None
|
||||
|
@ -417,6 +422,9 @@ class DFLPNG(object):
|
|||
return self.dfl_dict.get ('pitch_yaw_roll', None)
|
||||
def get_eyebrows_expand_mod(self):
|
||||
return self.dfl_dict.get ('eyebrows_expand_mod', None)
|
||||
def get_relighted(self):
|
||||
return self.dfl_dict.get ('relighted', False)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue