diff --git a/main.py b/main.py index b9d5d33..62eff02 100644 --- a/main.py +++ b/main.py @@ -63,7 +63,7 @@ if __name__ == "__main__": p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.") p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.") p.set_defaults (func=process_dev_extract_vggface2_dataset) - + def process_dev_extract_umd_csv(arguments): os_utils.set_process_lowest_prio() from mainscripts import Extractor @@ -78,8 +78,8 @@ if __name__ == "__main__": p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.") p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.") p.set_defaults (func=process_dev_extract_umd_csv) - - + + def process_dev_apply_celebamaskhq(arguments): os_utils.set_process_lowest_prio() from mainscripts import dev_misc @@ -130,10 +130,10 @@ if __name__ == "__main__": #if arguments.remove_fanseg: # Util.remove_fanseg_folder (input_path=arguments.input_dir) - + if arguments.remove_ie_polys: Util.remove_ie_polys_folder (input_path=arguments.input_dir) - + p = subparsers.add_parser( "util", help="Utilities.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") p.add_argument('--convert-png-to-jpg', action="store_true", dest="convert_png_to_jpg", default=False, help="Convert DeepFaceLAB PNG files to JPEG.") @@ -190,7 +190,7 @@ if __name__ == "__main__": Converter.main (args, device_args) p = subparsers.add_parser( "convert", help="Converter") - p.add_argument('--training-data-src-dir', action=fixPathAction, dest="training_data_src_dir", help="(optional, may be required by some models) Dir of extracted SRC faceset.") + p.add_argument('--training-data-src-dir', action=fixPathAction, dest="training_data_src_dir", help="(optional, may be required by some models) Dir of extracted SRC faceset.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the converted files will be stored.") p.add_argument('--aligned-dir', action=fixPathAction, dest="aligned_dir", help="Aligned directory. This is where the extracted of dst faces stored.") @@ -270,9 +270,29 @@ if __name__ == "__main__": p.add_argument('--confirmed-dir', required=True, action=fixPathAction, dest="confirmed_dir", help="This is where the labeled faces will be stored.") p.add_argument('--skipped-dir', required=True, action=fixPathAction, dest="skipped_dir", help="This is where the labeled faces will be stored.") p.add_argument('--no-default-mask', action="store_true", dest="no_default_mask", default=False, help="Don't use default mask.") - + p.set_defaults(func=process_labelingtool_edit_mask) + def process_relight_faceset(arguments): + from mainscripts import FacesetRelighter + FacesetRelighter.relight (arguments.input_dir, arguments.lighten, arguments.random_one) + + def process_delete_relighted(arguments): + from mainscripts import FacesetRelighter + FacesetRelighter.delete_relighted (arguments.input_dir) + + facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers() + + p = facesettool_parser.add_parser ("relight", help="Synthesize new faces from existing ones by relighting them. With the relighted faces neural network will better reproduce face shadows.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") + p.add_argument('--lighten', action="store_true", dest="lighten", default=None, help="Lighten the faces.") + p.add_argument('--random-one', action="store_true", dest="random_one", default=None, help="Relight the faces only with one random direction, otherwise relight with all directions.") + p.set_defaults(func=process_relight_faceset) + + p = facesettool_parser.add_parser ("delete_relighted", help="Delete relighted faces.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") + p.set_defaults(func=process_delete_relighted) + def bad_args(arguments): parser.print_help() exit(0) diff --git a/mainscripts/FacesetRelighter.py b/mainscripts/FacesetRelighter.py new file mode 100644 index 0000000..69fc5cf --- /dev/null +++ b/mainscripts/FacesetRelighter.py @@ -0,0 +1,81 @@ +import traceback +from pathlib import Path + +from interact import interact as io +from nnlib import DeepPortraitRelighting +from utils import Path_utils +from utils.cv2_utils import * +from utils.DFLJPG import DFLJPG +from utils.DFLPNG import DFLPNG + + +def relight(input_dir, lighten=None, random_one=None): + if lighten is None: + lighten = io.input_bool ("Lighten the faces? ( y/n default:n ) : ", False) + + if random_one is None: + random_one = io.input_bool ("Relight the faces only with one random direction? ( y/n default:y ) : ", True) + + input_path = Path(input_dir) + + image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)] + + dpr = DeepPortraitRelighting() + + for filepath in io.progress_bar_generator(image_paths, "Relighting"): + try: + if filepath.suffix == '.png': + dflimg = DFLPNG.load( str(filepath) ) + elif filepath.suffix == '.jpg': + dflimg = DFLJPG.load ( str(filepath) ) + else: + dflimg = None + + if dflimg is None: + io.log_err ("%s is not a dfl image file" % (filepath.name) ) + continue + else: + if dflimg.get_relighted(): + io.log_info (f"Skipping already relighted face [{filepath.name}]") + continue + img = cv2_imread (str(filepath)) + + if random_one: + relighted_imgs = dpr.relight_random(img,lighten=lighten) + else: + relighted_imgs = dpr.relight_all(img,lighten=lighten) + + for i,relighted_img in enumerate(relighted_imgs): + im_flags = [] + if filepath.suffix == '.jpg': + im_flags += [int(cv2.IMWRITE_JPEG_QUALITY), 100] + + relighted_filename = filepath.parent / (filepath.stem+f'_relighted_{i}'+filepath.suffix) + + cv2_imwrite (relighted_filename, relighted_img ) + dflimg.embed_and_set (relighted_filename, source_filename="_", relighted=True ) + except: + io.log_err (f"Exception occured while processing file {filepath.name}. Error: {traceback.format_exc()}") + +def delete_relighted(input_dir): + input_path = Path(input_dir) + image_paths = [Path(x) for x in Path_utils.get_image_paths(input_path)] + + files_to_delete = [] + for filepath in io.progress_bar_generator(image_paths, "Loading"): + if filepath.suffix == '.png': + dflimg = DFLPNG.load( str(filepath) ) + elif filepath.suffix == '.jpg': + dflimg = DFLJPG.load ( str(filepath) ) + else: + dflimg = None + + if dflimg is None: + io.log_err ("%s is not a dfl image file" % (filepath.name) ) + continue + else: + if dflimg.get_relighted(): + files_to_delete += [filepath] + + for file in io.progress_bar_generator(files_to_delete, "Deleting"): + file.unlink() diff --git a/nnlib/DeepPortraitRelighting.py b/nnlib/DeepPortraitRelighting.py new file mode 100644 index 0000000..9ee3ab8 --- /dev/null +++ b/nnlib/DeepPortraitRelighting.py @@ -0,0 +1,223 @@ +from pathlib import Path +import numpy as np +import cv2 + +class DeepPortraitRelighting(object): + + def __init__(self): + from nnlib import nnlib + nnlib.import_torch() + + self.torch = nnlib.torch + self.torch_device = nnlib.torch_device + + self.model = DeepPortraitRelighting.build_model(self.torch, self.torch_device) + + self.shs = [ + [1.084125496282453138e+00,-4.642676300617166185e-01,2.837846795150648915e-02,6.765292733937575687e-01,-3.594067725393816914e-01,4.790996460111427574e-02,-2.280054643781863066e-01,-8.125983081159608712e-02,2.881082012687687932e-01], + [1.084125496282453138e+00,-4.642676300617170626e-01,5.466255701105990905e-01,3.996219229512094628e-01,-2.615439760463462715e-01,-2.511241554473071513e-01,6.495694866016435420e-02,3.510322039081858470e-01,1.189662732386344152e-01], + [1.084125496282453138e+00,-4.642676300617179508e-01,6.532524688468428486e-01,-1.782088862752457814e-01,3.326676893441832261e-02,-3.610566644446819295e-01,3.647561777790956361e-01,-7.496419691318900735e-02,-5.412289239602386531e-02], + [1.084125496282453138e+00,-4.642676300617186724e-01,2.679669346194941126e-01,-6.218447693376460972e-01,3.030269583891490037e-01,-1.991061409014726058e-01,-6.162944418511027977e-02,-3.176699976873690878e-01,1.920509612235956343e-01], + [1.084125496282453138e+00,-4.642676300617186724e-01,-3.191031669056417219e-01,-5.972188577671910803e-01,3.446016675533919993e-01,1.127753677656503223e-01,-1.716692196540034188e-01,2.163406460637767315e-01,2.555824552121269688e-01], + [1.084125496282453138e+00,-4.642676300617178398e-01,-6.658820752324799974e-01,-1.228749652534838893e-01,1.266842924569576145e-01,3.397347243069742673e-01,3.036887095295650041e-01,2.213893524577207617e-01,-1.886557316342868038e-02], + [1.084125496282453138e+00,-4.642676300617169516e-01,-5.112381993903207800e-01,4.439962822886048266e-01,-1.866289387481862572e-01,3.108669041197227867e-01,2.021743042675238355e-01,-3.148681770175290051e-01,3.974379604123656762e-02] + ] + + #n = [0..8] + def relight(self, img, n, lighten=False): + torch = self.torch + + sh = (np.array (self.shs[np.clip(n, 0,8)]).reshape( (1,9,1,1) )*0.7).astype(np.float32) + sh = torch.autograd.Variable(torch.from_numpy(sh).to(self.torch_device)) + + row, col, _ = img.shape + img = cv2.resize(img, (512, 512)) + Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) + + inputL = Lab[:,:,0] + outputImg, outputSH = self.model(torch.autograd.Variable(torch.from_numpy(inputL[None,None,...].astype(np.float32)/255.0).to(self.torch_device)), + sh, 0) + + outputImg = outputImg[0].cpu().data.numpy() + outputImg = outputImg.transpose((1,2,0)) + outputImg = np.squeeze(outputImg) + outputImg = np.clip (outputImg, 0.0, 1.0) + outputImg = cv2.blur(outputImg, (3,3) ) + + if not lighten: + outputImg = inputL* outputImg + else: + outputImg = outputImg*255.0 + outputImg = np.clip(outputImg, 0,255).astype(np.uint8) + + Lab[:,:,0] = outputImg + result = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) + result = cv2.resize(result, (col, row)) + return result + + def relight_all(self, img, lighten=False): + return [ self.relight(img, n, lighten=lighten) for n in range( len(self.shs) ) ] + + def relight_random(self, img, lighten=False): + return [ self.relight(img, np.random.randint(len(self.shs)), lighten=lighten ) ] + + @staticmethod + def build_model(torch, torch_device): + nn = torch.nn + F = torch.nn.functional + + def conv3X3(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + # define the network + class BasicBlock(nn.Module): + def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None): + super(BasicBlock, self).__init__() + # batchNorm_type 0 means batchnormalization + # 1 means instance normalization + self.inplanes = inplanes + self.outplanes = outplanes + self.conv1 = conv3X3(inplanes, outplanes, 1) + self.conv2 = conv3X3(outplanes, outplanes, 1) + if batchNorm_type == 0: + self.bn1 = nn.BatchNorm2d(outplanes) + self.bn2 = nn.BatchNorm2d(outplanes) + else: + self.bn1 = nn.InstanceNorm2d(outplanes) + self.bn2 = nn.InstanceNorm2d(outplanes) + + self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.inplanes != self.outplanes: + out += self.shortcuts(x) + else: + out += x + + out = F.relu(out) + return out + + class HourglassBlock(nn.Module): + def __init__(self, inplane, mid_plane, middleNet, skipLayer=True): + super(HourglassBlock, self).__init__() + # upper branch + self.skipLayer = True + self.upper = BasicBlock(inplane, inplane, batchNorm_type=1) + + # lower branch + self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) + self.upSample = nn.Upsample(scale_factor=2, mode='nearest') + self.low1 = BasicBlock(inplane, mid_plane) + self.middle = middleNet + self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1) + + def forward(self, x, light, count, skip_count): + # we use count to indicate wich layer we are in + # max_count indicates the from which layer, we would use skip connections + out_upper = self.upper(x) + out_lower = self.downSample(x) + out_lower = self.low1(out_lower) + out_lower, out_middle = self.middle(out_lower, light, count+1, skip_count) + out_lower = self.low2(out_lower) + out_lower = self.upSample(out_lower) + if count >= skip_count and self.skipLayer: + out = out_lower + out_upper + else: + out = out_lower + return out, out_middle + + class lightingNet(nn.Module): + def __init__(self, ncInput, ncOutput, ncMiddle): + super(lightingNet, self).__init__() + self.ncInput = ncInput + self.ncOutput = ncOutput + self.ncMiddle = ncMiddle + self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False) + self.predict_relu1 = nn.PReLU() + self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False) + + self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False) + self.post_relu1 = nn.PReLU() + self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False) + self.post_relu2 = nn.ReLU() # to be consistance with the original feature + + def forward(self, innerFeat, target_light, count, skip_count): + x = innerFeat[:,0:self.ncInput,:,:] # lighting feature + _, _, row, col = x.shape + # predict lighting + feat = x.mean(dim=(2,3), keepdim=True) + light = self.predict_relu1(self.predict_FC1(feat)) + light = self.predict_FC2(light) + upFeat = self.post_relu1(self.post_FC1(target_light)) + upFeat = self.post_relu2(self.post_FC2(upFeat)) + upFeat = upFeat.repeat((1,1,row, col)) + innerFeat[:,0:self.ncInput,:,:] = upFeat + return innerFeat, light#light + + + class HourglassNet(nn.Module): + def __init__(self, baseFilter = 16, gray=True): + super(HourglassNet, self).__init__() + + self.ncLight = 27 # number of channels for input to lighting network + self.baseFilter = baseFilter + + # number of channles for output of lighting network + if gray: + self.ncOutLight = 9 # gray: channel is 1 + else: + self.ncOutLight = 27 # color: channel is 3 + + self.ncPre = self.baseFilter # number of channels for pre-convolution + + # number of channels + self.ncHG3 = self.baseFilter + self.ncHG2 = 2*self.baseFilter + self.ncHG1 = 4*self.baseFilter + self.ncHG0 = 8*self.baseFilter + self.ncLight + + self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2) + self.pre_bn = nn.BatchNorm2d(self.ncPre) + + self.light = lightingNet(self.ncLight, self.ncOutLight, 128) + self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light) + self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0) + self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1) + self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2) + + self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1) + self.bn_1 = nn.BatchNorm2d(self.ncPre) + self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) + self.bn_2 = nn.BatchNorm2d(self.ncPre) + self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) + self.bn_3 = nn.BatchNorm2d(self.ncPre) + + self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0) + + def forward(self, x, target_light, skip_count): + feat = self.pre_conv(x) + + feat = F.relu(self.pre_bn(feat)) + # get the inner most features + feat, out_light = self.HG3(feat, target_light, 0, skip_count) + #return feat, out_light + + feat = F.relu(self.bn_1(self.conv_1(feat))) + feat = F.relu(self.bn_2(self.conv_2(feat))) + feat = F.relu(self.bn_3(self.conv_3(feat))) + out_img = self.output(feat) + out_img = torch.sigmoid(out_img) + return out_img, out_light + + model = HourglassNet() + t_dict = torch.load( Path(__file__).parent / 'DeepPortraitRelighting.t7' ) + model.load_state_dict(t_dict) + model.to( torch_device ) + model.train(False) + return model \ No newline at end of file diff --git a/nnlib/DeepPortraitRelighting.t7 b/nnlib/DeepPortraitRelighting.t7 new file mode 100644 index 0000000..943b172 Binary files /dev/null and b/nnlib/DeepPortraitRelighting.t7 differ diff --git a/nnlib/__init__.py b/nnlib/__init__.py index e30936a..6876185 100644 --- a/nnlib/__init__.py +++ b/nnlib/__init__.py @@ -1,4 +1,5 @@ from .nnlib import nnlib from .FUNIT import FUNIT from .TernausNet import TernausNet -from .VGGFace import VGGFace \ No newline at end of file +from .VGGFace import VGGFace +from .DeepPortraitRelighting import DeepPortraitRelighting \ No newline at end of file diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index eafdade..5841bb9 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -20,6 +20,9 @@ class nnlib(object): dlib = None + torch = None + torch_device = None + keras = None keras_contrib = None @@ -128,7 +131,27 @@ UNet = nnlib.UNet UNetTemporalPredictor = nnlib.UNetTemporalPredictor NLayerDiscriminator = nnlib.NLayerDiscriminator """ + @staticmethod + def import_torch(device_config=None): + if nnlib.torch is not None: + return + if device_config is None: + device_config = nnlib.active_DeviceConfig + else: + nnlib.active_DeviceConfig = device_config + + if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): + os.environ.pop('CUDA_VISIBLE_DEVICES') + + import torch + nnlib.torch = torch + + if device_config.cpu_only or device_config.backend == 'plaidML': + nnlib.torch_device = torch.device(type='cpu') + else: + nnlib.torch_device = torch.device(type='cuda', index=device_config.gpu_idxs[0] ) + torch.cuda.set_device(nnlib.torch_device) @staticmethod def _import_tf(device_config): @@ -634,7 +657,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator reduction_axes = list(range(len(input_shape))) del reduction_axes[self.axis] del reduction_axes[0] - + broadcast_shape = [1] * len(input_shape) broadcast_shape[self.axis] = input_shape[self.axis] mean = K.mean(x, reduction_axes, keepdims=True) @@ -912,7 +935,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator base_config = super(Adam, self).get_config() return dict(list(base_config.items()) + list(config.items())) nnlib.Adam = Adam - + class DenseMaxout(keras.layers.Layer): """A dense maxout layer. A `MaxoutDense` layer takes the element-wise maximum of @@ -1039,7 +1062,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator base_config = super(DenseMaxout, self).get_config() return dict(list(base_config.items()) + list(config.items())) nnlib.DenseMaxout = DenseMaxout - + def CAInitializerMP( conv_weights_list ): #Convolution Aware Initialization https://arxiv.org/abs/1702.06295 data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ] diff --git a/requirements-colab.txt b/requirements-colab.txt index 6ed6edc..ccdf38b 100644 --- a/requirements-colab.txt +++ b/requirements-colab.txt @@ -7,4 +7,10 @@ plaidml-keras==0.5.0 scikit-image tqdm ffmpeg-python==0.1.17 -git+https://www.github.com/keras-team/keras-contrib.git \ No newline at end of file +git+https://www.github.com/keras-team/keras-contrib.git + +# +# install following packages directly via pip! +# +# pip install torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html +# pip install torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html \ No newline at end of file diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 1152b18..d44148c 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -7,3 +7,9 @@ scikit-image tqdm ffmpeg-python==0.1.17 git+https://www.github.com/keras-team/keras-contrib.git + +# +# install following packages directly via pip! +# +# pip install torch===1.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html +# pip install torchvision===0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 0d6ab03..edfa576 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -9,3 +9,9 @@ scikit-image tqdm ffmpeg-python==0.1.17 git+https://www.github.com/keras-team/keras-contrib.git + +# +# install following packages directly via pip! +# +# pip install torch===1.3.1 -f https://download.pytorch.org/whl/torch_stable.html +# pip install torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html \ No newline at end of file diff --git a/requirements-opencl.txt b/requirements-opencl.txt index 2ac2908..44b0b00 100644 --- a/requirements-opencl.txt +++ b/requirements-opencl.txt @@ -9,3 +9,9 @@ scikit-image tqdm ffmpeg-python==0.1.17 git+https://www.github.com/keras-team/keras-contrib.git + +# +# install following packages directly via pip! +# +# pip install torch===1.3.1+cpu -f https://download.pytorch.org/whl/torch_stable.html +# pip install torchvision===0.4.0+cpu -f https://download.pytorch.org/whl/torch_stable.html \ No newline at end of file diff --git a/utils/DFLJPG.py b/utils/DFLJPG.py index 888ff98..9cb2c11 100644 --- a/utils/DFLJPG.py +++ b/utils/DFLJPG.py @@ -170,6 +170,7 @@ class DFLJPG(object): fanseg_mask=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, + relighted=None, **kwargs ): @@ -195,7 +196,8 @@ class DFLJPG(object): 'image_to_face_mat': image_to_face_mat, 'fanseg_mask' : fanseg_mask, 'pitch_yaw_roll' : pitch_yaw_roll, - 'eyebrows_expand_mod' : eyebrows_expand_mod + 'eyebrows_expand_mod' : eyebrows_expand_mod, + 'relighted' : relighted }) try: @@ -214,6 +216,7 @@ class DFLJPG(object): fanseg_mask=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, + relighted=None, **kwargs ): if face_type is None: face_type = self.get_face_type() @@ -226,7 +229,7 @@ class DFLJPG(object): if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask() if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll() if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod() - + if relighted is None: relighted = self.get_relighted() DFLJPG.embed_data (filename, face_type=face_type, landmarks=landmarks, ie_polys=ie_polys, @@ -236,7 +239,7 @@ class DFLJPG(object): image_to_face_mat=image_to_face_mat, fanseg_mask=fanseg_mask, pitch_yaw_roll=pitch_yaw_roll, - eyebrows_expand_mod=eyebrows_expand_mod) + relighted=relighted) def remove_ie_polys(self): self.dfl_dict['ie_polys'] = None @@ -312,4 +315,6 @@ class DFLJPG(object): return self.dfl_dict.get ('pitch_yaw_roll', None) def get_eyebrows_expand_mod(self): return self.dfl_dict.get ('eyebrows_expand_mod', None) + def get_relighted(self): + return self.dfl_dict.get ('relighted', False) diff --git a/utils/DFLPNG.py b/utils/DFLPNG.py index 7650f91..7e84616 100644 --- a/utils/DFLPNG.py +++ b/utils/DFLPNG.py @@ -286,6 +286,7 @@ class DFLPNG(object): fanseg_mask=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, + relighted=None, **kwargs ): @@ -312,6 +313,7 @@ class DFLPNG(object): 'fanseg_mask' : fanseg_mask, 'pitch_yaw_roll' : pitch_yaw_roll, 'eyebrows_expand_mod' : eyebrows_expand_mod, + 'relighted' : relighted }) try: @@ -330,6 +332,7 @@ class DFLPNG(object): fanseg_mask=None, pitch_yaw_roll=None, eyebrows_expand_mod=None, + relighted=None, **kwargs ): if face_type is None: face_type = self.get_face_type() @@ -342,7 +345,8 @@ class DFLPNG(object): if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask() if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll() if eyebrows_expand_mod is None: eyebrows_expand_mod = self.get_eyebrows_expand_mod() - + if relighted is None: relighted = self.get_relighted() + DFLPNG.embed_data (filename, face_type=face_type, landmarks=landmarks, ie_polys=ie_polys, @@ -352,7 +356,8 @@ class DFLPNG(object): image_to_face_mat=image_to_face_mat, fanseg_mask=fanseg_mask, pitch_yaw_roll=pitch_yaw_roll, - eyebrows_expand_mod=eyebrows_expand_mod) + eyebrows_expand_mod=eyebrows_expand_mod, + relighted=relighted) def remove_ie_polys(self): self.dfl_dict['ie_polys'] = None @@ -417,6 +422,9 @@ class DFLPNG(object): return self.dfl_dict.get ('pitch_yaw_roll', None) def get_eyebrows_expand_mod(self): return self.dfl_dict.get ('eyebrows_expand_mod', None) + def get_relighted(self): + return self.dfl_dict.get ('relighted', False) + def __str__(self): return "".format(len(self.chunks), **self.__dict__)