mirror of https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00

Commit e58197ca22 (parent 51a917facc)
initial code to extract umdfaces.io dataset and train pose estimator

18 changed files with 437 additions and 57 deletions
facelib/FANSegmentator.py
@@ -26,7 +26,7 @@ class FANSegmentator(object):
         self.model = FANSegmentator.BuildModel(resolution, ngf=64)
 
-        if weights_file_root:
+        if weights_file_root is not None:
             weights_file_root = Path(weights_file_root)
         else:
             weights_file_root = Path(__file__).parent
facelib/LandmarksProcessor.py
@@ -332,7 +332,7 @@ def calc_face_yaw(landmarks):
    return float(r-l)
 
 #returns pitch,yaw [-1...+1]
-def estimate_pitch_yaw(aligned_256px_landmarks):
+def estimate_pitch_yaw_roll(aligned_256px_landmarks):
     shape = (256,256)
     focal_length = shape[1]
     camera_center = (shape[1] / 2, shape[0] / 2)
@@ -347,7 +347,8 @@ def estimate_pitch_yaw(aligned_256px_landmarks):
         camera_matrix,
         np.zeros((4, 1)) )
 
-    pitch, yaw, _ = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] )
+    pitch, yaw, roll = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] )
     pitch = np.clip ( pitch*1.25, -1.0, 1.0 )
     yaw = np.clip ( yaw*1.25, -1.0, 1.0 )
-    return pitch, yaw
+    roll = np.clip ( roll*1.25, -1.0, 1.0 )
+    return pitch, yaw, roll
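For reference, the pose above comes from cv2.solvePnP followed by a rotation-matrix-to-Euler decomposition, with each angle scaled by 1.25 and clipped to [-1, +1]. A minimal sketch of that conversion is below; it assumes mathlib.rotationMatrixToEulerAngles implements the common XYZ decomposition (the repo's mathlib is not shown in this diff), so treat it as illustrative rather than the project's exact code.

import cv2
import numpy as np

def rotation_vector_to_euler(rotation_vector):
    # rotation vector (e.g. from cv2.solvePnP) -> pitch, yaw, roll in radians
    R, _ = cv2.Rodrigues(rotation_vector)            # 3x3 rotation matrix
    sy = np.sqrt(R[0, 0] ** 2 + R[1, 0] ** 2)
    if sy > 1e-6:                                    # regular case
        pitch = np.arctan2(R[2, 1], R[2, 2])
        yaw   = np.arctan2(-R[2, 0], sy)
        roll  = np.arctan2(R[1, 0], R[0, 0])
    else:                                            # gimbal lock
        pitch = np.arctan2(-R[1, 2], R[1, 1])
        yaw   = np.arctan2(-R[2, 0], sy)
        roll  = 0.0
    return pitch, yaw, roll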
facelib/PoseEstimator.py (new file, 155 lines)
@@ -0,0 +1,155 @@
+import os
+import pickle
+from functools import partial
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+from interact import interact as io
+from nnlib import nnlib
+
+"""
+PoseEstimator estimates pitch, yaw, roll, from FAN aligned face.
+trained on https://www.umdfaces.io
+"""
+
+class PoseEstimator(object):
+    VERSION = 1
+    def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
+        exec( nnlib.import_all(), locals(), globals() )
+
+        self.class_num = 180
+
+        self.model = PoseEstimator.BuildModel(resolution, class_num=self.class_num)
+
+        if weights_file_root is not None:
+            weights_file_root = Path(weights_file_root)
+        else:
+            weights_file_root = Path(__file__).parent
+
+        self.weights_path = weights_file_root / ('PoseEst_%d_%s.h5' % (resolution, face_type_str) )
+
+        if load_weights:
+            self.model.load_weights (str(self.weights_path))
+
+        idx_tensor = np.array([idx for idx in range(self.class_num)], dtype=K.floatx() )
+        idx_tensor = K.constant(idx_tensor)
+
+        #inp_t = Input ( (resolution,resolution,3) )
+
+        inp_t, = self.model.inputs
+        pitch_bins_t, yaw_bins_t, roll_bins_t = self.model.outputs
+
+        pitch_t, yaw_t, roll_t = K.sum ( pitch_bins_t * idx_tensor, 1), K.sum ( yaw_bins_t * idx_tensor, 1), K.sum ( roll_bins_t * idx_tensor, 1)
+
+        inp_pitch_bins_t = Input ( (self.class_num,) )
+        inp_pitch_t = Input ( (1,) )
+
+        inp_yaw_bins_t = Input ( (self.class_num,) )
+        inp_yaw_t = Input ( (1,) )
+
+        inp_roll_bins_t = Input ( (self.class_num,) )
+        inp_roll_t = Input ( (1,) )
+
+        pitch_loss = K.categorical_crossentropy(inp_pitch_bins_t, pitch_bins_t) \
+                     + 0.001 * K.mean(K.square( inp_pitch_t - pitch_t), -1)
+
+        yaw_loss = K.categorical_crossentropy(inp_yaw_bins_t, yaw_bins_t) \
+                   + 0.001 * K.mean(K.square( inp_yaw_t - yaw_t), -1)
+
+        roll_loss = K.categorical_crossentropy(inp_roll_bins_t, roll_bins_t) \
+                    + 0.001 * K.mean(K.square( inp_roll_t - roll_t), -1)
+
+
+        loss = K.mean( pitch_loss + yaw_loss + roll_loss )
+
+        if training:
+            self.train = K.function ([inp_t, inp_pitch_bins_t, inp_pitch_t, inp_yaw_bins_t, inp_yaw_t, inp_roll_bins_t, inp_roll_t],
+                                     [loss], Adam(tf_cpu_mode=2).get_updates(loss, self.model.trainable_weights) )
+
+        self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] )
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
+        return False #pass exception between __enter__ and __exit__ to outter level
+
+    def save_weights(self):
+        self.model.save_weights (str(self.weights_path))
+
+    def train_on_batch(self, imgs, pitch_yaw_roll):
+        c = ( (pitch_yaw_roll+1) * 90.0 ).astype(np.int).astype(K.floatx())
+
+        inp_pitch = c[:,0:1]
+        inp_yaw = c[:,1:2]
+        inp_roll = c[:,2:3]
+
+        inp_pitch_bins = keras.utils.to_categorical(inp_pitch, self.class_num )
+        inp_yaw_bins = keras.utils.to_categorical(inp_yaw, self.class_num )
+        inp_roll_bins = keras.utils.to_categorical(inp_roll, self.class_num )
+
+        loss, = self.train( [imgs, inp_pitch_bins, inp_pitch, inp_yaw_bins, inp_yaw, inp_roll_bins, inp_roll] )
+
+        return loss
+
+    def extract (self, input_image, is_input_tanh=False):
+        if is_input_tanh:
+            raise NotImplemented("is_input_tanh")
+
+        input_shape_len = len(input_image.shape)
+        if input_shape_len == 3:
+            input_image = input_image[np.newaxis,...]
+
+        pitch, yaw, roll = self.view( [input_image] )
+        result = np.concatenate( (pitch[...,np.newaxis], yaw[...,np.newaxis], roll[...,np.newaxis]), -1 )
+        result = np.clip ( result / 90.0 - 1, -1, 1 )
+
+        if input_shape_len == 3:
+            result = result[0]
+
+        return result
+
+    @staticmethod
+    def BuildModel ( resolution, class_num):
+        exec( nnlib.import_all(), locals(), globals() )
+        inp = Input ( (resolution,resolution,3) )
+        x = inp
+        x = PoseEstimator.Flow(class_num=class_num)(x)
+        model = Model(inp,x)
+        return model
+
+    @staticmethod
+    def Flow(class_num):
+        exec( nnlib.import_all(), locals(), globals() )
+
+        def func(input):
+            x = input
+
+            x = Conv2D(64, kernel_size=11, strides=4, padding='same', activation='relu')(x)
+            x = MaxPooling2D( (3,3), strides=2 )(x)
+
+            x = Conv2D(192, kernel_size=5, strides=1, padding='same', activation='relu')(x)
+            x = MaxPooling2D( (3,3), strides=2 )(x)
+
+            x = Conv2D(384, kernel_size=3, strides=1, padding='same', activation='relu')(x)
+            x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
+            x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
+            x = MaxPooling2D( (3,3), strides=2 )(x)
+
+            x = Flatten()(x)
+            x = Dense(4096, activation='relu')(x)
+            x = Dropout(0.5)(x)
+            x = Dense(4096, activation='relu')(x)
+            x = Dropout(0.5)(x)
+            x = Dense(1000, activation='relu')(x)
+
+            pitch = Dense(class_num, activation='softmax', name='pitch')(x)
+            yaw = Dense(class_num, activation='softmax', name='yaw')(x)
+            roll = Dense(class_num, activation='softmax', name='roll')(x)
+
+            return [pitch, yaw, roll]
+
+
+        return func
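The new estimator is a binned regressor: each of pitch, yaw and roll is classified over 180 one-degree bins covering [-90, +90] degrees, the continuous angle is recovered as the expectation over the softmax bins, and the loss is categorical cross-entropy plus a small (0.001-weighted) MSE term on that expectation. A NumPy-only sketch of the encode/decode arithmetic, assuming angles already normalized to [-1, +1] as train_on_batch and extract expect:

import numpy as np

CLASS_NUM = 180  # one bin per degree over [-90, +90)

def encode(angle_norm):
    # [-1..+1] angle -> (bin index, one-hot target), same mapping as train_on_batch
    idx = int((angle_norm + 1.0) * 90.0)
    idx = min(max(idx, 0), CLASS_NUM - 1)
    one_hot = np.zeros(CLASS_NUM, dtype=np.float32)
    one_hot[idx] = 1.0
    return idx, one_hot

def decode(bin_probs):
    # softmax bin probabilities -> [-1..+1] angle, same arithmetic as extract()
    expected_bin = float(np.sum(bin_probs * np.arange(CLASS_NUM)))
    return float(np.clip(expected_bin / 90.0 - 1.0, -1.0, 1.0))

idx, one_hot = encode(0.25)       # roughly +22.5 degrees of yaw
print(idx, decode(one_hot))       # 112, ~0.244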
facelib/__init__.py
@@ -4,3 +4,4 @@ from .MTCExtractor import MTCExtractor
 from .S3FDExtractor import S3FDExtractor
 from .LandmarksExtractor import LandmarksExtractor
 from .FANSegmentator import FANSegmentator
+from .PoseEstimator import PoseEstimator
main.py (15 lines changed)
@@ -49,6 +49,21 @@ if __name__ == "__main__":
     p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU. Forces to use MT extractor.")
     p.set_defaults (func=process_extract)
 
+    def process_dev_extract_umd_csv(arguments):
+        os_utils.set_process_lowest_prio()
+        from mainscripts import Extractor
+        Extractor.extract_umd_csv( arguments.input_csv_file,
+                                   device_args={'cpu_only'  : arguments.cpu_only,
+                                                'multi_gpu' : arguments.multi_gpu,
+                                               }
+                                 )
+
+    p = subparsers.add_parser( "dev_extract_umd_csv", help="")
+    p.add_argument('--input-csv-file', required=True, action=fixPathAction, dest="input_csv_file", help="input_csv_file")
+    p.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="Enables multi GPU.")
+    p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU.")
+    p.set_defaults (func=process_dev_extract_umd_csv)
+
     """
     def process_extract_fanseg(arguments):
         os_utils.set_process_lowest_prio()
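The new dev_extract_umd_csv subcommand is a thin wrapper; from the shell it would be invoked roughly as "python main.py dev_extract_umd_csv --input-csv-file <umdfaces csv> [--multi-gpu | --cpu-only]", and programmatically it reduces to the call below (the csv path is a placeholder, not from this commit):

from mainscripts import Extractor

# hypothetical path to a umdfaces annotation csv
Extractor.extract_umd_csv('/data/umdfaces/umdfaces_batch1_ultraface.csv',
                          device_args={'cpu_only': True, 'multi_gpu': False})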
mainscripts/Extractor.py
@@ -23,12 +23,13 @@ from interact import interact as io
 
 class ExtractSubprocessor(Subprocessor):
     class Data(object):
-        def __init__(self, filename=None, rects=None, landmarks = None, landmarks_accurate=True, final_output_files = None):
+        def __init__(self, filename=None, rects=None, landmarks = None, landmarks_accurate=True, pitch_yaw_roll=None, final_output_files = None):
             self.filename = filename
             self.rects = rects or []
             self.rects_rotation = 0
             self.landmarks_accurate = landmarks_accurate
             self.landmarks = landmarks or []
+            self.pitch_yaw_roll = pitch_yaw_roll
             self.final_output_files = final_output_files or []
             self.faces_detected = 0
@@ -251,7 +252,8 @@ class ExtractSubprocessor(Subprocessor):
                                     source_filename=filename_path.name,
                                     source_rect=rect,
                                     source_landmarks=image_landmarks.tolist(),
-                                    image_to_face_mat=image_to_face_mat
+                                    image_to_face_mat=image_to_face_mat,
+                                    pitch_yaw_roll=data.pitch_yaw_roll
                                     )
 
                     data.final_output_files.append (output_file)
@@ -701,6 +703,82 @@ def extract_fanseg(input_dir, device_args={} ):
         io.log_info ("Performing extract fanseg for %d files..." % (paths_to_extract_len) )
         data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
 
+def extract_umd_csv(input_file_csv,
+                    image_size=256,
+                    face_type='full_face',
+                    device_args={} ):
+
+    #extract faces from umdfaces.io dataset csv file with pitch,yaw,roll info.
+    multi_gpu = device_args.get('multi_gpu', False)
+    cpu_only = device_args.get('cpu_only', False)
+    face_type = FaceType.fromString(face_type)
+
+    input_file_csv_path = Path(input_file_csv)
+    if not input_file_csv_path.exists():
+        raise ValueError('input_file_csv not found. Please ensure it exists.')
+
+    input_file_csv_root_path = input_file_csv_path.parent
+    output_path = input_file_csv_path.parent / ('aligned_' + input_file_csv_path.name)
+
+    io.log_info("Output dir is %s." % (str(output_path)) )
+
+    if output_path.exists():
+        output_images_paths = Path_utils.get_image_paths(output_path)
+        if len(output_images_paths) > 0:
+            io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False )
+            for filename in output_images_paths:
+                Path(filename).unlink()
+    else:
+        output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        with open( str(input_file_csv_path), 'r') as f:
+            csv_file = f.read()
+    except Exception as e:
+        io.log_err("Unable to open or read file " + str(input_file_csv_path) + ": " + str(e) )
+        return
+
+    strings = csv_file.split('\n')
+    keys = strings[0].split(',')
+    keys_len = len(keys)
+    csv_data = []
+    for i in range(1, len(strings)):
+        values = strings[i].split(',')
+        if keys_len != len(values):
+            io.log_err("Wrong string in csv file, skipping.")
+            continue
+
+        csv_data += [ { keys[n] : values[n] for n in range(keys_len) } ]
+
+    data = []
+    for d in csv_data:
+        filename = input_file_csv_root_path / d['FILE']
+
+        pitch, yaw, roll = float(d['PITCH']), float(d['YAW']), float(d['ROLL'])
+        if pitch < -90 or pitch > 90 or yaw < -90 or yaw > 90 or roll < -90 or roll > 90:
+            continue
+
+        pitch_yaw_roll = pitch/90.0, yaw/90.0, roll/90.0
+
+        x,y,w,h = float(d['FACE_X']), float(d['FACE_Y']), float(d['FACE_WIDTH']), float(d['FACE_HEIGHT'])
+
+        data += [ ExtractSubprocessor.Data(filename=filename, rects=[ [x,y,x+w,y+h] ], pitch_yaw_roll=pitch_yaw_roll) ]
+
+    images_found = len(data)
+    faces_detected = 0
+    if len(data) > 0:
+        io.log_info ("Performing 2nd pass from csv file...")
+        data = ExtractSubprocessor (data, 'landmarks', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
+
+        io.log_info ('Performing 3rd pass...')
+        data = ExtractSubprocessor (data, 'final', image_size, face_type, None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run()
+        faces_detected += sum([d.faces_detected for d in data])
+
+
+    io.log_info ('-------------------------')
+    io.log_info ('Images found: %d' % (images_found) )
+    io.log_info ('Faces detected: %d' % (faces_detected) )
+    io.log_info ('-------------------------')
+
 def main(input_dir,
          output_dir,
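extract_umd_csv parses the annotation file by hand with split(','); an equivalent sketch using the standard csv module is below. The column names (FILE, PITCH, YAW, ROLL, FACE_X, FACE_Y, FACE_WIDTH, FACE_HEIGHT) are the ones the function above reads; the helper itself is illustrative and not part of the commit.

import csv
from pathlib import Path

def read_umd_rows(csv_path):
    # yields (image path, [l, t, r, b] rect, normalized (pitch, yaw, roll)) per usable row
    root = Path(csv_path).parent
    with open(csv_path, 'r', newline='') as f:
        for row in csv.DictReader(f):
            pitch, yaw, roll = (float(row[k]) for k in ('PITCH', 'YAW', 'ROLL'))
            if max(abs(pitch), abs(yaw), abs(roll)) > 90:
                continue  # same out-of-range filter as above
            x, y, w, h = (float(row[k]) for k in ('FACE_X', 'FACE_Y', 'FACE_WIDTH', 'FACE_HEIGHT'))
            yield root / row['FILE'], [x, y, x + w, y + h], (pitch / 90.0, yaw / 90.0, roll / 90.0)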
mainscripts/Sorter.py
@@ -202,7 +202,11 @@ def sort_by_face_yaw(input_path):
             trash_img_list.append ( [str(filepath)] )
             continue
 
-        pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
+        pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
+        if pitch_yaw_roll is not None:
+            pitch, yaw, roll = pitch_yaw_roll
+        else:
+            pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
 
         img_list.append( [str(filepath), yaw ] )
@@ -230,7 +234,11 @@ def sort_by_face_pitch(input_path):
             trash_img_list.append ( [str(filepath)] )
             continue
 
-        pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
+        pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
+        if pitch_yaw_roll is not None:
+            pitch, yaw, roll = pitch_yaw_roll
+        else:
+            pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
 
         img_list.append( [str(filepath), pitch ] )
@@ -532,7 +540,7 @@ class FinalLoaderSubprocessor(Subprocessor):
 
             gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
             sharpness = estimate_sharpness(gray) if self.include_by_blur else 0
-            pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
+            pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
 
             hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
         except Exception as e:
models/Model_DEV_POSEEST/Model.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+import numpy as np
+
+from nnlib import nnlib
+from models import ModelBase
+from facelib import FaceType
+from facelib import PoseEstimator
+from samplelib import *
+from interact import interact as io
+import imagelib
+
+class Model(ModelBase):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs,
+                         ask_write_preview_history=False,
+                         ask_target_iter=False,
+                         ask_sort_by_yaw=False,
+                         ask_random_flip=False,
+                         ask_src_scale_mod=False)
+
+    #override
+    def onInitializeOptions(self, is_first_run, ask_override):
+        default_face_type = 'f'
+        if is_first_run:
+            self.options['face_type'] = io.input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
+        else:
+            self.options['face_type'] = self.options.get('face_type', default_face_type)
+
+    #override
+    def onInitialize(self):
+        exec(nnlib.import_all(), locals(), globals())
+        self.set_vram_batch_requirements( {4:64} )
+
+        self.resolution = 227
+        self.face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF
+
+
+        self.pose_est = PoseEstimator(self.resolution,
+                                      FaceType.toString(self.face_type),
+                                      load_weights=not self.is_first_run(),
+                                      weights_file_root=self.get_model_root_path(),
+                                      training=True)
+
+        if self.is_training_mode:
+            f = SampleProcessor.TypeFlags
+            face_type = f.FACE_TYPE_FULL if self.options['face_type'] == 'f' else f.FACE_TYPE_HALF
+
+            self.set_training_data_generators ([
+                    SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
+                                        sample_process_options=SampleProcessor.Options( motion_blur = [25, 1] ), #random_flip=True,
+                                        output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution],
+                                                              [f.PITCH_YAW_ROLL],
+                                                            ]),
+
+                    SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
+                                        sample_process_options=SampleProcessor.Options(), #random_flip=True,
+                                        output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution],
+                                                              [f.PITCH_YAW_ROLL],
+                                                            ])
+                ])
+
+    #override
+    def onSave(self):
+        self.pose_est.save_weights()
+
+    #override
+    def onTrainOneIter(self, generators_samples, generators_list):
+        target_src, pitch_yaw_roll = generators_samples[0]
+
+        loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll )
+
+        return ( ('loss', loss), )
+
+    #override
+    def onGetPreview(self, generators_samples):
+        test_src = generators_samples[0][0][0:4] #first 4 samples
+        test_pyr_src = generators_samples[0][1][0:4]
+        test_dst = generators_samples[1][0][0:4]
+        test_pyr_dst = generators_samples[1][1][0:4]
+
+        h,w,c = self.resolution,self.resolution,3
+        h_line = 13
+
+        result = []
+        for name, img, pyr in [ ['training data', test_src, test_pyr_src], \
+                                ['evaluating data',test_dst, test_pyr_dst] ]:
+            pyr_pred = self.pose_est.extract(img)
+
+            hor_imgs = []
+            for i in range(len(img)):
+                img_info = np.ones ( (h,w,c) ) * 0.1
+                lines = ["%s" % ( str(pyr[i]) ),
+                         "%s" % ( str(pyr_pred[i]) ) ]
+
+                lines_count = len(lines)
+                for ln in range(lines_count):
+                    img_info[ ln*h_line:(ln+1)*h_line, 0:w] += \
+                        imagelib.get_text_image ( (h_line,w,c), lines[ln], color=[0.8]*c )
+
+                hor_imgs.append ( np.concatenate ( (
+                    img[i,:,:,0:3],
+                    img_info
+                    ), axis=1) )
+
+
+            result += [ (name, np.concatenate (hor_imgs, axis=0)) ]
+
+        return result
models/Model_DEV_POSEEST/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .Model import Model
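Outside the ModelBase framework, one iteration of this dev model boils down to feeding a batch of aligned faces plus their normalized pitch/yaw/roll to PoseEstimator.train_on_batch. A self-contained sketch with random placeholder data; the generator wiring above is replaced by a fake batch, and a working nnlib/Keras backend is assumed:

import numpy as np
from facelib import PoseEstimator

pose_est = PoseEstimator(227, 'full_face', load_weights=False, training=True)

def fake_batch(batch_size=4, resolution=227):
    # stand-in for SampleGeneratorFace output: BGR images in [0,1], pyr in [-1,+1]
    imgs = np.random.rand(batch_size, resolution, resolution, 3).astype(np.float32)
    pyr  = np.random.uniform(-1.0, 1.0, size=(batch_size, 3)).astype(np.float32)
    return imgs, pyr

imgs, pyr = fake_batch()
loss = pose_est.train_on_batch(imgs, pyr)   # combined cross-entropy + 0.001*MSE loss
pred = pose_est.extract(imgs)               # (batch, 3) pitch/yaw/roll back in [-1,+1]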
samplelib/Sample.py
@@ -16,26 +16,26 @@ class SampleType(IntEnum):
     FACE = 1 #aligned face unsorted
     FACE_YAW_SORTED = 2 #sorted by yaw
     FACE_YAW_SORTED_AS_TARGET = 3 #sorted by yaw and included only yaws which exist in TARGET also automatic mirrored
-    FACE_END = 3
+    FACE_TEMPORAL_SORTED = 4
+    FACE_END = 4
 
-    QTY = 4
+    QTY = 5
 
 class Sample(object):
-    def __init__(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch=None, yaw=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask_exist=False):
+    def __init__(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask_exist=False):
         self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE
         self.filename = filename
         self.face_type = face_type
         self.shape = shape
         self.landmarks = np.array(landmarks) if landmarks is not None else None
         self.ie_polys = ie_polys
-        self.pitch = pitch
-        self.yaw = yaw
+        self.pitch_yaw_roll = pitch_yaw_roll
        self.source_filename = source_filename
        self.mirror = mirror
        self.close_target_list = close_target_list
        self.fanseg_mask_exist = fanseg_mask_exist
 
-    def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch=None, yaw=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask=None, fanseg_mask_exist=None):
+    def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask=None, fanseg_mask_exist=None):
         return Sample(
             sample_type=sample_type if sample_type is not None else self.sample_type,
             filename=filename if filename is not None else self.filename,
@@ -43,8 +43,7 @@ class Sample(object):
             shape=shape if shape is not None else self.shape,
             landmarks=landmarks if landmarks is not None else self.landmarks.copy(),
             ie_polys=ie_polys if ie_polys is not None else self.ie_polys,
-            pitch=pitch if pitch is not None else self.pitch,
-            yaw=yaw if yaw is not None else self.yaw,
+            pitch_yaw_roll=pitch_yaw_roll if pitch_yaw_roll is not None else self.pitch_yaw_roll,
             source_filename=source_filename if source_filename is not None else self.source_filename,
             mirror=mirror if mirror is not None else self.mirror,
             close_target_list=close_target_list if close_target_list is not None else self.close_target_list,
samplelib/SampleGeneratorFace.py
@@ -15,13 +15,12 @@ output_sample_types = [
 ]
 '''
 class SampleGeneratorFace(SampleGeneratorBase):
-    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, generators_random_seed=None, **kwargs):
+    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
         super().__init__(samples_path, debug, batch_size)
         self.sample_process_options = sample_process_options
         self.output_sample_types = output_sample_types
         self.add_sample_idx = add_sample_idx
-        self.add_pitch = add_pitch
-        self.add_yaw = add_yaw
+        # self.add_pitch_yaw_roll = add_pitch_yaw_roll
 
         if sort_by_yaw_target_samples_path is not None:
             self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@@ -143,12 +142,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
                 if self.add_sample_idx:
                     batches += [ [] ]
                     i_sample_idx = len(batches)-1
-                if self.add_pitch:
-                    batches += [ [] ]
-                    i_pitch = len(batches)-1
-                if self.add_yaw:
-                    batches += [ [] ]
-                    i_yaw = len(batches)-1
 
             for i in range(len(x)):
                 batches[i].append ( x[i] )
@@ -156,14 +149,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
                 if self.add_sample_idx:
                     batches[i_sample_idx].append (idx)
 
-                if self.add_pitch or self.add_yaw:
-                    pitch, yaw = LandmarksProcessor.estimate_pitch_yaw (sample.landmarks)
-
-                if self.add_pitch:
-                    batches[i_pitch].append ([pitch])
-
-                if self.add_yaw:
-                    batches[i_yaw].append ([yaw])
-
                 break
             yield [ np.array(batch) for batch in batches]
samplelib/SampleLoader.py
@@ -35,9 +35,9 @@ class SampleLoader:
             if datas[sample_type] is None:
                 datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
 
-        # elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
-        #     if datas[sample_type] is None:
-        #         datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
+        elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
+            if datas[sample_type] is None:
+                datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
 
         elif sample_type == SampleType.FACE_YAW_SORTED:
             if datas[sample_type] is None:
@@ -69,15 +69,12 @@ class SampleLoader:
                     print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
                     continue
 
-                pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
-
                 sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
                                                    face_type=FaceType.fromString (dflimg.get_face_type()),
                                                    shape=dflimg.get_shape(),
                                                    landmarks=dflimg.get_landmarks(),
                                                    ie_polys=dflimg.get_ie_polys(),
-                                                   pitch=pitch,
-                                                   yaw=yaw,
+                                                   pitch_yaw_roll=dflimg.get_pitch_yaw_roll(),
                                                    source_filename=dflimg.get_source_filename(),
                                                    fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
             except:
@@ -85,12 +82,12 @@ class SampleLoader:
 
        return sample_list
 
-    # @staticmethod
-    # def upgradeToFaceTemporalSortedSamples( samples ):
-    #     new_s = [ (s, s.source_filename) for s in samples]
-    #     new_s = sorted(new_s, key=operator.itemgetter(1))
+    @staticmethod
+    def upgradeToFaceTemporalSortedSamples( samples ):
+        new_s = [ (s, s.source_filename) for s in samples]
+        new_s = sorted(new_s, key=operator.itemgetter(1))
 
-    #     return [ s[0] for s in new_s]
+        return [ s[0] for s in new_s]
 
     @staticmethod
     def upgradeToFaceYawSortedSamples( samples ):
samplelib/SampleProcessor.py
@@ -13,9 +13,10 @@ class SampleProcessor(object):
         WARPED_TRANSFORMED = 0x00000004,
         TRANSFORMED = 0x00000008,
         LANDMARKS_ARRAY = 0x00000010, #currently unused
+        PITCH_YAW_ROLL = 0x00000020,
 
-        RANDOM_CLOSE = 0x00000020, #currently unused
-        MORPH_TO_RANDOM_CLOSE = 0x00000040, #currently unused
+        RANDOM_CLOSE = 0x00000040, #currently unused
+        MORPH_TO_RANDOM_CLOSE = 0x00000080, #currently unused
 
         FACE_TYPE_HALF = 0x00000100,
         FACE_TYPE_FULL = 0x00000200,
@@ -77,7 +78,7 @@ class SampleProcessor(object):
         outputs = []
         for sample_type in output_sample_types:
             f = sample_type[0]
-            size = sample_type[1]
+            size = 0 if len (sample_type) < 2 else sample_type[1]
             random_sub_size = 0 if len (sample_type) < 3 else min( sample_type[2] , size)
 
             if f & SPTF.SOURCE != 0:
@@ -90,6 +91,8 @@ class SampleProcessor(object):
                 img_type = 3
             elif f & SPTF.LANDMARKS_ARRAY != 0:
                 img_type = 4
+            elif f & SPTF.PITCH_YAW_ROLL != 0:
+                img_type = 5
             else:
                 raise ValueError ('expected SampleTypeFlags type')
 
@@ -121,6 +124,16 @@ class SampleProcessor(object):
                     l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
                     l = np.clip(l, 0.0, 1.0)
                     img = l
+                elif img_type == 5:
+                    pitch_yaw_roll = sample.pitch_yaw_roll
+                    if pitch_yaw_roll is not None:
+                        pitch, yaw, roll = pitch_yaw_roll
+                    else:
+                        pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll (sample.landmarks)
+                    if params['flip']:
+                        yaw = -yaw
+
+                    img = (pitch, yaw, roll)
                 else:
                     if images[img_type][face_mask_type] is None:
                         if img_type >= 10 and img_type <= 19: #RANDOM_CLOSE
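The PITCH_YAW_ROLL branch prefers the value stored with the sample, falls back to estimating it from the landmarks, and negates yaw when the sample was mirrored. Isolated as a plain helper for clarity (the helper name is illustrative, not part of the codebase):

def pitch_yaw_roll_for_sample(sample, flipped, estimate_fn):
    # estimate_fn stands in for LandmarksProcessor.estimate_pitch_yaw_roll
    pyr = getattr(sample, 'pitch_yaw_roll', None)
    if pyr is None:
        pyr = estimate_fn(sample.landmarks)
    pitch, yaw, roll = pyr
    if flipped:
        yaw = -yaw          # a horizontal mirror flips the sign of yaw
    return pitch, yaw, roll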
samplelib/__init__.py
@@ -4,4 +4,5 @@ from .SampleLoader import SampleLoader
 from .SampleProcessor import SampleProcessor
 from .SampleGeneratorBase import SampleGeneratorBase
 from .SampleGeneratorFace import SampleGeneratorFace
+from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
 from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal
utils/DFLJPG.py
@@ -167,7 +167,9 @@ class DFLJPG(object):
                        source_rect=None,
                        source_landmarks=None,
                        image_to_face_mat=None,
-                       fanseg_mask=None, **kwargs
+                       fanseg_mask=None,
+                       pitch_yaw_roll=None,
+                       **kwargs
                        ):
 
         if fanseg_mask is not None:
@@ -191,6 +193,7 @@ class DFLJPG(object):
                               'source_landmarks': source_landmarks,
                               'image_to_face_mat': image_to_face_mat,
                               'fanseg_mask' : fanseg_mask,
+                              'pitch_yaw_roll' : pitch_yaw_roll
                               })
 
         try:
@@ -206,7 +209,9 @@ class DFLJPG(object):
                        source_rect=None,
                        source_landmarks=None,
                        image_to_face_mat=None,
-                       fanseg_mask=None, **kwargs
+                       fanseg_mask=None,
+                       pitch_yaw_roll=None,
+                       **kwargs
                        ):
         if face_type is None: face_type = self.get_face_type()
         if landmarks is None: landmarks = self.get_landmarks()
@@ -216,6 +221,7 @@ class DFLJPG(object):
         if source_landmarks is None: source_landmarks = self.get_source_landmarks()
         if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
         if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
+        if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
         DFLJPG.embed_data (filename, face_type=face_type,
                                      landmarks=landmarks,
                                      ie_polys=ie_polys,
@@ -223,7 +229,8 @@ class DFLJPG(object):
                                      source_rect=source_rect,
                                      source_landmarks=source_landmarks,
                                      image_to_face_mat=image_to_face_mat,
-                                     fanseg_mask=fanseg_mask)
+                                     fanseg_mask=fanseg_mask,
+                                     pitch_yaw_roll=pitch_yaw_roll)
 
     def remove_fanseg_mask(self):
         self.dfl_dict['fanseg_mask'] = None
@@ -291,3 +298,6 @@ class DFLJPG(object):
         if fanseg_mask is not None:
             return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
         return None
+
+    def get_pitch_yaw_roll(self):
+        return self.dfl_dict.get ('pitch_yaw_roll', None)
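With the metadata changes above, the pose travels embedded in each aligned face. Reading it back would look roughly like this; the DFLJPG.load entry point and the file name are assumptions for illustration (neither appears in this diff), and faces extracted before this commit simply return None:

from utils.DFLJPG import DFLJPG

dflimg = DFLJPG.load('aligned/00001.jpg')     # hypothetical aligned face
pyr = dflimg.get_pitch_yaw_roll()             # None if the field was never embedded
if pyr is not None:
    pitch, yaw, roll = pyr                    # each normalized to [-1 .. +1]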
utils/DFLPNG.py
@@ -283,7 +283,9 @@ class DFLPNG(object):
                        source_rect=None,
                        source_landmarks=None,
                        image_to_face_mat=None,
-                       fanseg_mask=None, **kwargs
+                       fanseg_mask=None,
+                       pitch_yaw_roll=None,
+                       **kwargs
                        ):
 
         if fanseg_mask is not None:
@@ -307,6 +309,7 @@ class DFLPNG(object):
                             'source_landmarks': source_landmarks,
                             'image_to_face_mat':image_to_face_mat,
                             'fanseg_mask' : fanseg_mask,
+                            'pitch_yaw_roll' : pitch_yaw_roll
                             })
 
         try:
@@ -322,7 +325,9 @@ class DFLPNG(object):
                        source_rect=None,
                        source_landmarks=None,
                        image_to_face_mat=None,
-                       fanseg_mask=None, **kwargs
+                       fanseg_mask=None,
+                       pitch_yaw_roll=None,
+                       **kwargs
                        ):
         if face_type is None: face_type = self.get_face_type()
         if landmarks is None: landmarks = self.get_landmarks()
@@ -332,6 +337,7 @@ class DFLPNG(object):
         if source_landmarks is None: source_landmarks = self.get_source_landmarks()
         if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
         if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
+        if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
         DFLPNG.embed_data (filename, face_type=face_type,
                                      landmarks=landmarks,
                                      ie_polys=ie_polys,
@@ -339,7 +345,8 @@ class DFLPNG(object):
                                      source_rect=source_rect,
                                      source_landmarks=source_landmarks,
                                      image_to_face_mat=image_to_face_mat,
-                                     fanseg_mask=fanseg_mask)
+                                     fanseg_mask=fanseg_mask,
+                                     pitch_yaw_roll=pitch_yaw_roll)
 
     def remove_fanseg_mask(self):
         self.dfl_dict['fanseg_mask'] = None
@@ -397,5 +404,7 @@ class DFLPNG(object):
         if fanseg_mask is not None:
             return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
         return None
+    def get_pitch_yaw_roll(self):
+        return self.dfl_dict.get ('pitch_yaw_roll', None)
     def __str__(self):
         return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)