initial code to extract umdfaces.io dataset and train pose estimator

This commit is contained in:
iperov 2019-04-23 08:14:09 +04:00
parent 51a917facc
commit e58197ca22
18 changed files with 437 additions and 57 deletions

View file

@ -26,7 +26,7 @@ class FANSegmentator(object):
self.model = FANSegmentator.BuildModel(resolution, ngf=64) self.model = FANSegmentator.BuildModel(resolution, ngf=64)
if weights_file_root: if weights_file_root is not None:
weights_file_root = Path(weights_file_root) weights_file_root = Path(weights_file_root)
else: else:
weights_file_root = Path(__file__).parent weights_file_root = Path(__file__).parent

View file

@ -332,7 +332,7 @@ def calc_face_yaw(landmarks):
return float(r-l) return float(r-l)
#returns pitch,yaw [-1...+1] #returns pitch,yaw [-1...+1]
def estimate_pitch_yaw(aligned_256px_landmarks): def estimate_pitch_yaw_roll(aligned_256px_landmarks):
shape = (256,256) shape = (256,256)
focal_length = shape[1] focal_length = shape[1]
camera_center = (shape[1] / 2, shape[0] / 2) camera_center = (shape[1] / 2, shape[0] / 2)
@ -347,7 +347,8 @@ def estimate_pitch_yaw(aligned_256px_landmarks):
camera_matrix, camera_matrix,
np.zeros((4, 1)) ) np.zeros((4, 1)) )
pitch, yaw, _ = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] ) pitch, yaw, roll = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] )
pitch = np.clip ( pitch*1.25, -1.0, 1.0 ) pitch = np.clip ( pitch*1.25, -1.0, 1.0 )
yaw = np.clip ( yaw*1.25, -1.0, 1.0 ) yaw = np.clip ( yaw*1.25, -1.0, 1.0 )
return pitch, yaw roll = np.clip ( roll*1.25, -1.0, 1.0 )
return pitch, yaw, roll

155
facelib/PoseEstimator.py Normal file
View file

@ -0,0 +1,155 @@
import os
import pickle
from functools import partial
from pathlib import Path
import cv2
import numpy as np
from interact import interact as io
from nnlib import nnlib
"""
PoseEstimator estimates pitch, yaw, roll, from FAN aligned face.
trained on https://www.umdfaces.io
"""
class PoseEstimator(object):
    """Estimate head pose (pitch, yaw, roll) from an aligned face image.

    Each angle is predicted as a softmax distribution over ``class_num`` bins.
    The continuous value is the expectation of the bin index under that
    distribution, rescaled to [-1, +1].
    """
    VERSION = 1

    def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
        """Build the network and the backend functions used to run/train it.

        resolution        -- side length of the square 3-channel input image.
        face_type_str     -- face type name; only used to form the weights file name.
        load_weights      -- load 'PoseEst_<resolution>_<face_type_str>.h5' when True.
        weights_file_root -- directory holding the weights file; defaults to
                             the directory of this module.
        training          -- also build the training function (self.train).
        """
        # Injects the Keras names used below (K, Input, Adam, ...) into module globals.
        exec( nnlib.import_all(), locals(), globals() )

        self.class_num = 180
        self.model = PoseEstimator.BuildModel(resolution, class_num=self.class_num)

        if weights_file_root is not None:
            weights_file_root = Path(weights_file_root)
        else:
            weights_file_root = Path(__file__).parent

        self.weights_path = weights_file_root / ('PoseEst_%d_%s.h5' % (resolution, face_type_str) )

        if load_weights:
            self.model.load_weights (str(self.weights_path))

        # Bin indices 0..class_num-1, used to turn a softmax into an expected bin.
        idx_tensor = K.constant( np.arange(self.class_num, dtype=K.floatx()) )

        inp_t, = self.model.inputs
        pitch_bins_t, yaw_bins_t, roll_bins_t = self.model.outputs

        def expected_bin(bins_t):
            # Expectation of the bin index; result has shape (batch,).
            return K.sum ( bins_t * idx_tensor, 1)

        pitch_t, yaw_t, roll_t = expected_bin(pitch_bins_t), expected_bin(yaw_bins_t), expected_bin(roll_bins_t)

        inp_pitch_bins_t = Input ( (self.class_num,) )
        inp_pitch_t = Input ( (1,) )

        inp_yaw_bins_t = Input ( (self.class_num,) )
        inp_yaw_t = Input ( (1,) )

        inp_roll_bins_t = Input ( (self.class_num,) )
        inp_roll_t = Input ( (1,) )

        def angle_loss(true_bins_t, true_idx_t, pred_bins_t, pred_idx_t):
            # Classification loss plus a small regression term on the expected bin.
            # true_idx_t is (batch,1) while pred_idx_t is (batch,): flatten the
            # target first, otherwise the subtraction broadcasts to (batch,batch)
            # and every sample is compared against every prediction in the batch.
            return K.categorical_crossentropy(true_bins_t, pred_bins_t) \
                   + 0.001 * K.square( K.flatten(true_idx_t) - pred_idx_t )

        pitch_loss = angle_loss(inp_pitch_bins_t, inp_pitch_t, pitch_bins_t, pitch_t)
        yaw_loss   = angle_loss(inp_yaw_bins_t,   inp_yaw_t,   yaw_bins_t,   yaw_t)
        roll_loss  = angle_loss(inp_roll_bins_t,  inp_roll_t,  roll_bins_t,  roll_t)

        loss = K.mean( pitch_loss + yaw_loss + roll_loss )

        if training:
            # NOTE(review): the learning phase is not fed to K.function, so the
            # Dropout layers may run in inference mode while training - confirm.
            self.train = K.function ([inp_t, inp_pitch_bins_t, inp_pitch_t, inp_yaw_bins_t, inp_yaw_t, inp_roll_bins_t, inp_roll_t],
                                     [loss], Adam(tf_cpu_mode=2).get_updates(loss, self.model.trainable_weights) )

        self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] )

    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        return False #pass exception between __enter__ and __exit__ to outer level

    def save_weights(self):
        """Write the model weights to self.weights_path."""
        self.model.save_weights (str(self.weights_path))

    @staticmethod
    def _angles_to_bins(pitch_yaw_roll, class_num):
        """Map angles in [-1, +1] to integer bin indices in [0, class_num-1].

        The result is clipped so that an angle of exactly +1 lands in the last
        bin instead of the out-of-range index ``class_num``.
        """
        bins = ( (np.asarray(pitch_yaw_roll) + 1) * (class_num / 2.0) ).astype(np.int32)
        return np.clip(bins, 0, class_num - 1)

    def train_on_batch(self, imgs, pitch_yaw_roll):
        """Run one optimizer step and return the loss.

        imgs           -- batch of input images.
        pitch_yaw_roll -- array whose columns 0,1,2 are pitch, yaw, roll in [-1, +1].

        Requires the instance to have been created with training=True.
        """
        c = PoseEstimator._angles_to_bins(pitch_yaw_roll, self.class_num).astype(K.floatx())

        inp_pitch = c[:,0:1]
        inp_yaw = c[:,1:2]
        inp_roll = c[:,2:3]

        inp_pitch_bins = keras.utils.to_categorical(inp_pitch, self.class_num )
        inp_yaw_bins = keras.utils.to_categorical(inp_yaw, self.class_num )
        inp_roll_bins = keras.utils.to_categorical(inp_roll, self.class_num )

        loss, = self.train( [imgs, inp_pitch_bins, inp_pitch, inp_yaw_bins, inp_yaw, inp_roll_bins, inp_roll] )

        return loss

    def extract (self, input_image, is_input_tanh=False):
        """Predict (pitch, yaw, roll), each clipped to [-1, +1].

        A 3-dimensional input is treated as a single image and yields one
        (pitch, yaw, roll) triple; otherwise the input is treated as a batch
        and one triple per item is returned along the last axis.

        Raises NotImplementedError when is_input_tanh is True.
        """
        if is_input_tanh:
            raise NotImplementedError("is_input_tanh")

        input_shape_len = len(input_image.shape)
        if input_shape_len == 3:
            input_image = input_image[np.newaxis,...]

        pitch, yaw, roll = self.view( [input_image] )
        result = np.concatenate( (pitch[...,np.newaxis], yaw[...,np.newaxis], roll[...,np.newaxis]), -1 )
        # Inverse of the binning used for training: bin index -> [-1, +1].
        result = np.clip ( result / (self.class_num / 2.0) - 1, -1, 1 )

        if input_shape_len == 3:
            result = result[0]

        return result

    @staticmethod
    def BuildModel ( resolution, class_num):
        """Return a Keras Model mapping a (resolution, resolution, 3) image to
        the three softmax outputs [pitch, yaw, roll]."""
        exec( nnlib.import_all(), locals(), globals() )
        inp = Input ( (resolution,resolution,3) )
        x = inp
        x = PoseEstimator.Flow(class_num=class_num)(x)
        model = Model(inp,x)
        return model

    @staticmethod
    def Flow(class_num):
        """Return a function that applies the network layers to an input tensor
        and returns the list [pitch, yaw, roll] of softmax tensors."""
        exec( nnlib.import_all(), locals(), globals() )

        def func(x):
            # Convolutional feature extractor.
            x = Conv2D(64, kernel_size=11, strides=4, padding='same', activation='relu')(x)
            x = MaxPooling2D( (3,3), strides=2 )(x)

            x = Conv2D(192, kernel_size=5, strides=1, padding='same', activation='relu')(x)
            x = MaxPooling2D( (3,3), strides=2 )(x)

            x = Conv2D(384, kernel_size=3, strides=1, padding='same', activation='relu')(x)
            x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
            x = Conv2D(256, kernel_size=3, strides=1, padding='same', activation='relu')(x)
            x = MaxPooling2D( (3,3), strides=2 )(x)

            # Dense trunk shared by the three angle heads.
            x = Flatten()(x)
            x = Dense(4096, activation='relu')(x)
            x = Dropout(0.5)(x)
            x = Dense(4096, activation='relu')(x)
            x = Dropout(0.5)(x)
            x = Dense(1000, activation='relu')(x)

            # One softmax classifier over class_num bins per angle.
            pitch = Dense(class_num, activation='softmax', name='pitch')(x)
            yaw = Dense(class_num, activation='softmax', name='yaw')(x)
            roll = Dense(class_num, activation='softmax', name='roll')(x)

            return [pitch, yaw, roll]

        return func

View file

@ -4,3 +4,4 @@ from .MTCExtractor import MTCExtractor
from .S3FDExtractor import S3FDExtractor from .S3FDExtractor import S3FDExtractor
from .LandmarksExtractor import LandmarksExtractor from .LandmarksExtractor import LandmarksExtractor
from .FANSegmentator import FANSegmentator from .FANSegmentator import FANSegmentator
from .PoseEstimator import PoseEstimator

15
main.py
View file

@ -49,6 +49,21 @@ if __name__ == "__main__":
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU. Forces to use MT extractor.") p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU. Forces to use MT extractor.")
p.set_defaults (func=process_extract) p.set_defaults (func=process_extract)
def process_dev_extract_umd_csv(arguments):
    """Handler for the 'dev_extract_umd_csv' sub-command.

    Lowers the process priority, then forwards the parsed command-line
    arguments to Extractor.extract_umd_csv.
    """
    os_utils.set_process_lowest_prio()
    from mainscripts import Extractor

    device_args = {'cpu_only'  : arguments.cpu_only,
                   'multi_gpu' : arguments.multi_gpu,
                  }
    Extractor.extract_umd_csv( arguments.input_csv_file, device_args=device_args )

# Register the sub-command and its options.
p = subparsers.add_parser( "dev_extract_umd_csv", help="")
p.add_argument('--input-csv-file',
               required=True,
               action=fixPathAction,
               dest="input_csv_file",
               help="input_csv_file")
p.add_argument('--multi-gpu',
               action="store_true",
               dest="multi_gpu",
               default=False,
               help="Enables multi GPU.")
p.add_argument('--cpu-only',
               action="store_true",
               dest="cpu_only",
               default=False,
               help="Extract on CPU.")
p.set_defaults (func=process_dev_extract_umd_csv)
""" """
def process_extract_fanseg(arguments): def process_extract_fanseg(arguments):
os_utils.set_process_lowest_prio() os_utils.set_process_lowest_prio()

View file

@ -23,12 +23,13 @@ from interact import interact as io
class ExtractSubprocessor(Subprocessor): class ExtractSubprocessor(Subprocessor):
class Data(object): class Data(object):
def __init__(self, filename=None, rects=None, landmarks = None, landmarks_accurate=True, final_output_files = None): def __init__(self, filename=None, rects=None, landmarks = None, landmarks_accurate=True, pitch_yaw_roll=None, final_output_files = None):
self.filename = filename self.filename = filename
self.rects = rects or [] self.rects = rects or []
self.rects_rotation = 0 self.rects_rotation = 0
self.landmarks_accurate = landmarks_accurate self.landmarks_accurate = landmarks_accurate
self.landmarks = landmarks or [] self.landmarks = landmarks or []
self.pitch_yaw_roll = pitch_yaw_roll
self.final_output_files = final_output_files or [] self.final_output_files = final_output_files or []
self.faces_detected = 0 self.faces_detected = 0
@ -251,7 +252,8 @@ class ExtractSubprocessor(Subprocessor):
source_filename=filename_path.name, source_filename=filename_path.name,
source_rect=rect, source_rect=rect,
source_landmarks=image_landmarks.tolist(), source_landmarks=image_landmarks.tolist(),
image_to_face_mat=image_to_face_mat image_to_face_mat=image_to_face_mat,
pitch_yaw_roll=data.pitch_yaw_roll
) )
data.final_output_files.append (output_file) data.final_output_files.append (output_file)
@ -701,6 +703,82 @@ def extract_fanseg(input_dir, device_args={} ):
io.log_info ("Performing extract fanseg for %d files..." % (paths_to_extract_len) ) io.log_info ("Performing extract fanseg for %d files..." % (paths_to_extract_len) )
data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run() data = ExtractSubprocessor ([ ExtractSubprocessor.Data(filename) for filename in paths_to_extract ], 'fanseg', multi_gpu=multi_gpu, cpu_only=cpu_only).run()
def extract_umd_csv(input_file_csv,
                    image_size=256,
                    face_type='full_face',
                    device_args=None ):
    """Extract faces from an umdfaces.io dataset csv file with pitch,yaw,roll info.

    input_file_csv -- path to the csv file. Image paths in the 'FILE' column
                      are resolved relative to the csv's directory.
    image_size     -- passed to the final extraction pass.
    face_type      -- face type name, converted with FaceType.fromString.
    device_args    -- optional dict with 'multi_gpu' and 'cpu_only' flags.

    Aligned faces are written to 'aligned_<csv file name>' next to the csv.
    Image files already in that directory are deleted after a confirmation
    prompt. Rows with a wrong number of fields or non-numeric values are
    logged and skipped. Rows whose pitch, yaw or roll is outside [-90, 90]
    degrees (including NaN) are skipped silently.

    Raises ValueError when the csv file does not exist.
    """
    # Local import: the file-level import block is outside this block.
    import csv

    device_args = device_args or {}
    multi_gpu = device_args.get('multi_gpu', False)
    cpu_only = device_args.get('cpu_only', False)
    face_type = FaceType.fromString(face_type)

    input_file_csv_path = Path(input_file_csv)
    if not input_file_csv_path.exists():
        raise ValueError('input_file_csv not found. Please ensure it exists.')

    input_file_csv_root_path = input_file_csv_path.parent
    output_path = input_file_csv_path.parent / ('aligned_' + input_file_csv_path.name)

    io.log_info("Output dir is %s." % (str(output_path)) )

    if output_path.exists():
        output_images_paths = Path_utils.get_image_paths(output_path)
        if len(output_images_paths) > 0:
            io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False )
            for filename in output_images_paths:
                Path(filename).unlink()
    else:
        output_path.mkdir(parents=True, exist_ok=True)

    try:
        # newline='' is what the csv module expects from its input file.
        with open( str(input_file_csv_path), 'r', newline='') as f:
            rows = list(csv.reader(f))
    except Exception as e:
        io.log_err("Unable to open or read file " + str(input_file_csv_path) + ": " + str(e) )
        return

    keys = rows[0] if rows else []
    keys_len = len(keys)

    data = []
    for values in rows[1:]:
        if not values:
            # Blank line (typically the trailing newline): nothing to report.
            continue

        if keys_len != len(values):
            io.log_err("Wrong string in csv file, skipping.")
            continue

        d = dict(zip(keys, values))
        filename = input_file_csv_root_path / d['FILE']

        try:
            pitch, yaw, roll = float(d['PITCH']), float(d['YAW']), float(d['ROLL'])
            x,y,w,h = float(d['FACE_X']), float(d['FACE_Y']), float(d['FACE_WIDTH']), float(d['FACE_HEIGHT'])
        except ValueError:
            io.log_err("Wrong string in csv file, skipping.")
            continue

        # Only poses within +-90 degrees are kept; they are normalized to [-1, +1].
        if not all( -90 <= angle <= 90 for angle in (pitch, yaw, roll) ):
            continue

        pitch_yaw_roll = pitch/90.0, yaw/90.0, roll/90.0

        data += [ ExtractSubprocessor.Data(filename=filename, rects=[ [x,y,x+w,y+h] ], pitch_yaw_roll=pitch_yaw_roll) ]

    images_found = len(data)
    faces_detected = 0
    if len(data) > 0:
        # The csv already provides face rectangles, so detection is skipped.
        io.log_info ("Performing 2nd pass from csv file...")
        data = ExtractSubprocessor (data, 'landmarks', multi_gpu=multi_gpu, cpu_only=cpu_only).run()

        io.log_info ('Performing 3rd pass...')
        data = ExtractSubprocessor (data, 'final', image_size, face_type, None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run()
        faces_detected += sum([d.faces_detected for d in data])

    io.log_info ('-------------------------')
    io.log_info ('Images found: %d' % (images_found) )
    io.log_info ('Faces detected: %d' % (faces_detected) )
    io.log_info ('-------------------------')
def main(input_dir, def main(input_dir,
output_dir, output_dir,

View file

@ -202,7 +202,11 @@ def sort_by_face_yaw(input_path):
trash_img_list.append ( [str(filepath)] ) trash_img_list.append ( [str(filepath)] )
continue continue
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() ) pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
if pitch_yaw_roll is not None:
pitch, yaw, roll = pitch_yaw_roll
else:
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
img_list.append( [str(filepath), yaw ] ) img_list.append( [str(filepath), yaw ] )
@ -230,7 +234,11 @@ def sort_by_face_pitch(input_path):
trash_img_list.append ( [str(filepath)] ) trash_img_list.append ( [str(filepath)] )
continue continue
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() ) pitch_yaw_roll = dflimg.get_pitch_yaw_roll()
if pitch_yaw_roll is not None:
pitch, yaw, roll = pitch_yaw_roll
else:
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
img_list.append( [str(filepath), pitch ] ) img_list.append( [str(filepath), pitch ] )
@ -532,7 +540,7 @@ class FinalLoaderSubprocessor(Subprocessor):
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
sharpness = estimate_sharpness(gray) if self.include_by_blur else 0 sharpness = estimate_sharpness(gray) if self.include_by_blur else 0
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() ) pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks() )
hist = cv2.calcHist([gray], [0], None, [256], [0, 256]) hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
except Exception as e: except Exception as e:

View file

@ -0,0 +1,108 @@
import numpy as np
from nnlib import nnlib
from models import ModelBase
from facelib import FaceType
from facelib import PoseEstimator
from samplelib import *
from interact import interact as io
import imagelib
class Model(ModelBase):
    """Trainer model that fits a PoseEstimator on samples carrying pitch/yaw/roll."""

    def __init__(self, *args, **kwargs):
        # None of the generic training prompts are asked for this model.
        disabled_prompts = dict(ask_write_preview_history=False,
                                ask_target_iter=False,
                                ask_sort_by_yaw=False,
                                ask_random_flip=False,
                                ask_src_scale_mod=False)
        super().__init__(*args, **kwargs, **disabled_prompts)

    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        """Ask for the face type on the first run, otherwise reuse the stored option."""
        fallback_face_type = 'f'
        if not is_first_run:
            chosen = self.options.get('face_type', fallback_face_type)
        else:
            chosen = io.input_str ("Half or Full face? (h/f, ?:help skip:f) : ", fallback_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
        self.options['face_type'] = chosen

    #override
    def onInitialize(self):
        """Create the pose estimator and, in training mode, the sample generators."""
        exec(nnlib.import_all(), locals(), globals())
        self.set_vram_batch_requirements( {4:64} )

        self.resolution = 227

        if self.options['face_type'] == 'f':
            self.face_type = FaceType.FULL
        else:
            self.face_type = FaceType.HALF

        self.pose_est = PoseEstimator(self.resolution,
                                      FaceType.toString(self.face_type),
                                      load_weights=not self.is_first_run(),
                                      weights_file_root=self.get_model_root_path(),
                                      training=True)

        if not self.is_training_mode:
            return

        f = SampleProcessor.TypeFlags
        if self.options['face_type'] == 'f':
            face_type = f.FACE_TYPE_FULL
        else:
            face_type = f.FACE_TYPE_HALF

        # Source samples are additionally motion-blurred; destination samples are not.
        src_generator = SampleGeneratorFace(
            self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
            sample_process_options=SampleProcessor.Options( motion_blur = [25, 1] ),
            output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution],
                                  [f.PITCH_YAW_ROLL],
                                ])

        dst_generator = SampleGeneratorFace(
            self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
            sample_process_options=SampleProcessor.Options(),
            output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution],
                                  [f.PITCH_YAW_ROLL],
                                ])

        self.set_training_data_generators ([ src_generator, dst_generator ])

    #override
    def onSave(self):
        """Persist the pose estimator weights."""
        self.pose_est.save_weights()

    #override
    def onTrainOneIter(self, generators_samples, generators_list):
        """Train on one batch from the first generator and report its loss."""
        images, angles = generators_samples[0]

        batch_loss = self.pose_est.train_on_batch( images, angles )

        return ( ('loss', batch_loss), )

    #override
    def onGetPreview(self, generators_samples):
        """Build preview images: each sample next to its target and predicted angles.

        Uses the first 4 samples of each generator. Returns a list of
        (name, image) pairs, one per generator.
        """
        preview_count = 4
        src_batch = generators_samples[0]
        dst_batch = generators_samples[1]

        test_src = src_batch[0][0:preview_count]
        test_pyr_src = src_batch[1][0:preview_count]
        test_dst = dst_batch[0][0:preview_count]
        test_pyr_dst = dst_batch[1][0:preview_count]

        h, w, c = self.resolution, self.resolution, 3
        h_line = 13  # pixel height of one text line in the info panel

        previews = []
        for name, img, pyr in ( ('training data', test_src, test_pyr_src),
                                ('evaluating data', test_dst, test_pyr_dst) ):
            pyr_pred = self.pose_est.extract(img)

            rows = []
            for i in range(len(img)):
                # Dark panel with the target angles on line 0 and the prediction on line 1.
                img_info = np.ones ( (h,w,c) ) * 0.1
                captions = [ str(pyr[i]), str(pyr_pred[i]) ]

                for line_idx, text in enumerate(captions):
                    top = line_idx * h_line
                    img_info[ top:top+h_line, 0:w ] += \
                        imagelib.get_text_image ( (h_line,w,c), text, color=[0.8]*c )

                sample_rgb = img[i,:,:,0:3]
                rows.append ( np.concatenate ( (sample_rgb, img_info), axis=1) )

            previews += [ (name, np.concatenate (rows, axis=0)) ]

        return previews

View file

@ -0,0 +1 @@
from .Model import Model

View file

@ -16,26 +16,26 @@ class SampleType(IntEnum):
FACE = 1 #aligned face unsorted FACE = 1 #aligned face unsorted
FACE_YAW_SORTED = 2 #sorted by yaw FACE_YAW_SORTED = 2 #sorted by yaw
FACE_YAW_SORTED_AS_TARGET = 3 #sorted by yaw and included only yaws which exist in TARGET also automatic mirrored FACE_YAW_SORTED_AS_TARGET = 3 #sorted by yaw and included only yaws which exist in TARGET also automatic mirrored
FACE_END = 3 FACE_TEMPORAL_SORTED = 4
FACE_END = 4
QTY = 4 QTY = 5
class Sample(object): class Sample(object):
def __init__(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch=None, yaw=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask_exist=False): def __init__(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask_exist=False):
self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE
self.filename = filename self.filename = filename
self.face_type = face_type self.face_type = face_type
self.shape = shape self.shape = shape
self.landmarks = np.array(landmarks) if landmarks is not None else None self.landmarks = np.array(landmarks) if landmarks is not None else None
self.ie_polys = ie_polys self.ie_polys = ie_polys
self.pitch = pitch self.pitch_yaw_roll = pitch_yaw_roll
self.yaw = yaw
self.source_filename = source_filename self.source_filename = source_filename
self.mirror = mirror self.mirror = mirror
self.close_target_list = close_target_list self.close_target_list = close_target_list
self.fanseg_mask_exist = fanseg_mask_exist self.fanseg_mask_exist = fanseg_mask_exist
def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch=None, yaw=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask=None, fanseg_mask_exist=None): def copy_and_set(self, sample_type=None, filename=None, face_type=None, shape=None, landmarks=None, ie_polys=None, pitch_yaw_roll=None, source_filename=None, mirror=None, close_target_list=None, fanseg_mask=None, fanseg_mask_exist=None):
return Sample( return Sample(
sample_type=sample_type if sample_type is not None else self.sample_type, sample_type=sample_type if sample_type is not None else self.sample_type,
filename=filename if filename is not None else self.filename, filename=filename if filename is not None else self.filename,
@ -43,8 +43,7 @@ class Sample(object):
shape=shape if shape is not None else self.shape, shape=shape if shape is not None else self.shape,
landmarks=landmarks if landmarks is not None else self.landmarks.copy(), landmarks=landmarks if landmarks is not None else self.landmarks.copy(),
ie_polys=ie_polys if ie_polys is not None else self.ie_polys, ie_polys=ie_polys if ie_polys is not None else self.ie_polys,
pitch=pitch if pitch is not None else self.pitch, pitch_yaw_roll=pitch_yaw_roll if pitch_yaw_roll is not None else self.pitch_yaw_roll,
yaw=yaw if yaw is not None else self.yaw,
source_filename=source_filename if source_filename is not None else self.source_filename, source_filename=source_filename if source_filename is not None else self.source_filename,
mirror=mirror if mirror is not None else self.mirror, mirror=mirror if mirror is not None else self.mirror,
close_target_list=close_target_list if close_target_list is not None else self.close_target_list, close_target_list=close_target_list if close_target_list is not None else self.close_target_list,

View file

@ -15,13 +15,12 @@ output_sample_types = [
] ]
''' '''
class SampleGeneratorFace(SampleGeneratorBase): class SampleGeneratorFace(SampleGeneratorBase):
def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, add_pitch=False, add_yaw=False, generators_count=2, generators_random_seed=None, **kwargs): def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
super().__init__(samples_path, debug, batch_size) super().__init__(samples_path, debug, batch_size)
self.sample_process_options = sample_process_options self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx self.add_sample_idx = add_sample_idx
self.add_pitch = add_pitch # self.add_pitch_yaw_roll = add_pitch_yaw_roll
self.add_yaw = add_yaw
if sort_by_yaw_target_samples_path is not None: if sort_by_yaw_target_samples_path is not None:
self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@ -143,12 +142,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.add_sample_idx: if self.add_sample_idx:
batches += [ [] ] batches += [ [] ]
i_sample_idx = len(batches)-1 i_sample_idx = len(batches)-1
if self.add_pitch:
batches += [ [] ]
i_pitch = len(batches)-1
if self.add_yaw:
batches += [ [] ]
i_yaw = len(batches)-1
for i in range(len(x)): for i in range(len(x)):
batches[i].append ( x[i] ) batches[i].append ( x[i] )
@ -156,14 +149,5 @@ class SampleGeneratorFace(SampleGeneratorBase):
if self.add_sample_idx: if self.add_sample_idx:
batches[i_sample_idx].append (idx) batches[i_sample_idx].append (idx)
if self.add_pitch or self.add_yaw:
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw (sample.landmarks)
if self.add_pitch:
batches[i_pitch].append ([pitch])
if self.add_yaw:
batches[i_yaw].append ([yaw])
break break
yield [ np.array(batch) for batch in batches] yield [ np.array(batch) for batch in batches]

View file

@ -35,9 +35,9 @@ class SampleLoader:
if datas[sample_type] is None: if datas[sample_type] is None:
datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] ) datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
# elif sample_type == SampleType.FACE_TEMPORAL_SORTED: elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
# if datas[sample_type] is None: if datas[sample_type] is None:
# datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) ) datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
elif sample_type == SampleType.FACE_YAW_SORTED: elif sample_type == SampleType.FACE_YAW_SORTED:
if datas[sample_type] is None: if datas[sample_type] is None:
@ -69,15 +69,12 @@ class SampleLoader:
print ("%s is not a dfl image file required for training" % (s_filename_path.name) ) print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
continue continue
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
sample_list.append( s.copy_and_set(sample_type=SampleType.FACE, sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
face_type=FaceType.fromString (dflimg.get_face_type()), face_type=FaceType.fromString (dflimg.get_face_type()),
shape=dflimg.get_shape(), shape=dflimg.get_shape(),
landmarks=dflimg.get_landmarks(), landmarks=dflimg.get_landmarks(),
ie_polys=dflimg.get_ie_polys(), ie_polys=dflimg.get_ie_polys(),
pitch=pitch, pitch_yaw_roll=dflimg.get_pitch_yaw_roll(),
yaw=yaw,
source_filename=dflimg.get_source_filename(), source_filename=dflimg.get_source_filename(),
fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) ) fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
except: except:
@ -85,12 +82,12 @@ class SampleLoader:
return sample_list return sample_list
# @staticmethod @staticmethod
# def upgradeToFaceTemporalSortedSamples( samples ): def upgradeToFaceTemporalSortedSamples( samples ):
# new_s = [ (s, s.source_filename) for s in samples] new_s = [ (s, s.source_filename) for s in samples]
# new_s = sorted(new_s, key=operator.itemgetter(1)) new_s = sorted(new_s, key=operator.itemgetter(1))
# return [ s[0] for s in new_s] return [ s[0] for s in new_s]
@staticmethod @staticmethod
def upgradeToFaceYawSortedSamples( samples ): def upgradeToFaceYawSortedSamples( samples ):

View file

@ -13,9 +13,10 @@ class SampleProcessor(object):
WARPED_TRANSFORMED = 0x00000004, WARPED_TRANSFORMED = 0x00000004,
TRANSFORMED = 0x00000008, TRANSFORMED = 0x00000008,
LANDMARKS_ARRAY = 0x00000010, #currently unused LANDMARKS_ARRAY = 0x00000010, #currently unused
PITCH_YAW_ROLL = 0x00000020,
RANDOM_CLOSE = 0x00000020, #currently unused RANDOM_CLOSE = 0x00000040, #currently unused
MORPH_TO_RANDOM_CLOSE = 0x00000040, #currently unused MORPH_TO_RANDOM_CLOSE = 0x00000080, #currently unused
FACE_TYPE_HALF = 0x00000100, FACE_TYPE_HALF = 0x00000100,
FACE_TYPE_FULL = 0x00000200, FACE_TYPE_FULL = 0x00000200,
@ -77,7 +78,7 @@ class SampleProcessor(object):
outputs = [] outputs = []
for sample_type in output_sample_types: for sample_type in output_sample_types:
f = sample_type[0] f = sample_type[0]
size = sample_type[1] size = 0 if len (sample_type) < 2 else sample_type[1]
random_sub_size = 0 if len (sample_type) < 3 else min( sample_type[2] , size) random_sub_size = 0 if len (sample_type) < 3 else min( sample_type[2] , size)
if f & SPTF.SOURCE != 0: if f & SPTF.SOURCE != 0:
@ -90,6 +91,8 @@ class SampleProcessor(object):
img_type = 3 img_type = 3
elif f & SPTF.LANDMARKS_ARRAY != 0: elif f & SPTF.LANDMARKS_ARRAY != 0:
img_type = 4 img_type = 4
elif f & SPTF.PITCH_YAW_ROLL != 0:
img_type = 5
else: else:
raise ValueError ('expected SampleTypeFlags type') raise ValueError ('expected SampleTypeFlags type')
@ -121,6 +124,16 @@ class SampleProcessor(object):
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 ) l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
l = np.clip(l, 0.0, 1.0) l = np.clip(l, 0.0, 1.0)
img = l img = l
elif img_type == 5:
pitch_yaw_roll = sample.pitch_yaw_roll
if pitch_yaw_roll is not None:
pitch, yaw, roll = pitch_yaw_roll
else:
pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll (sample.landmarks)
if params['flip']:
yaw = -yaw
img = (pitch, yaw, roll)
else: else:
if images[img_type][face_mask_type] is None: if images[img_type][face_mask_type] is None:
if img_type >= 10 and img_type <= 19: #RANDOM_CLOSE if img_type >= 10 and img_type <= 19: #RANDOM_CLOSE

View file

@ -4,4 +4,5 @@ from .SampleLoader import SampleLoader
from .SampleProcessor import SampleProcessor from .SampleProcessor import SampleProcessor
from .SampleGeneratorBase import SampleGeneratorBase from .SampleGeneratorBase import SampleGeneratorBase
from .SampleGeneratorFace import SampleGeneratorFace from .SampleGeneratorFace import SampleGeneratorFace
from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal
from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal

View file

@ -167,7 +167,9 @@ class DFLJPG(object):
source_rect=None, source_rect=None,
source_landmarks=None, source_landmarks=None,
image_to_face_mat=None, image_to_face_mat=None,
fanseg_mask=None, **kwargs fanseg_mask=None,
pitch_yaw_roll=None,
**kwargs
): ):
if fanseg_mask is not None: if fanseg_mask is not None:
@ -191,6 +193,7 @@ class DFLJPG(object):
'source_landmarks': source_landmarks, 'source_landmarks': source_landmarks,
'image_to_face_mat': image_to_face_mat, 'image_to_face_mat': image_to_face_mat,
'fanseg_mask' : fanseg_mask, 'fanseg_mask' : fanseg_mask,
'pitch_yaw_roll' : pitch_yaw_roll
}) })
try: try:
@ -206,7 +209,9 @@ class DFLJPG(object):
source_rect=None, source_rect=None,
source_landmarks=None, source_landmarks=None,
image_to_face_mat=None, image_to_face_mat=None,
fanseg_mask=None, **kwargs fanseg_mask=None,
pitch_yaw_roll=None,
**kwargs
): ):
if face_type is None: face_type = self.get_face_type() if face_type is None: face_type = self.get_face_type()
if landmarks is None: landmarks = self.get_landmarks() if landmarks is None: landmarks = self.get_landmarks()
@ -216,6 +221,7 @@ class DFLJPG(object):
if source_landmarks is None: source_landmarks = self.get_source_landmarks() if source_landmarks is None: source_landmarks = self.get_source_landmarks()
if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat() if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask() if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
DFLJPG.embed_data (filename, face_type=face_type, DFLJPG.embed_data (filename, face_type=face_type,
landmarks=landmarks, landmarks=landmarks,
ie_polys=ie_polys, ie_polys=ie_polys,
@ -223,7 +229,8 @@ class DFLJPG(object):
source_rect=source_rect, source_rect=source_rect,
source_landmarks=source_landmarks, source_landmarks=source_landmarks,
image_to_face_mat=image_to_face_mat, image_to_face_mat=image_to_face_mat,
fanseg_mask=fanseg_mask) fanseg_mask=fanseg_mask,
pitch_yaw_roll=pitch_yaw_roll)
def remove_fanseg_mask(self): def remove_fanseg_mask(self):
self.dfl_dict['fanseg_mask'] = None self.dfl_dict['fanseg_mask'] = None
@ -291,3 +298,6 @@ class DFLJPG(object):
if fanseg_mask is not None: if fanseg_mask is not None:
return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis] return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
return None return None
def get_pitch_yaw_roll(self):
    """Return the value stored under 'pitch_yaw_roll' in dfl_dict, or None if absent."""
    stored_pose = self.dfl_dict.get ('pitch_yaw_roll', None)
    return stored_pose

View file

@ -283,7 +283,9 @@ class DFLPNG(object):
source_rect=None, source_rect=None,
source_landmarks=None, source_landmarks=None,
image_to_face_mat=None, image_to_face_mat=None,
fanseg_mask=None, **kwargs fanseg_mask=None,
pitch_yaw_roll=None,
**kwargs
): ):
if fanseg_mask is not None: if fanseg_mask is not None:
@ -307,6 +309,7 @@ class DFLPNG(object):
'source_landmarks': source_landmarks, 'source_landmarks': source_landmarks,
'image_to_face_mat':image_to_face_mat, 'image_to_face_mat':image_to_face_mat,
'fanseg_mask' : fanseg_mask, 'fanseg_mask' : fanseg_mask,
'pitch_yaw_roll' : pitch_yaw_roll
}) })
try: try:
@ -322,7 +325,9 @@ class DFLPNG(object):
source_rect=None, source_rect=None,
source_landmarks=None, source_landmarks=None,
image_to_face_mat=None, image_to_face_mat=None,
fanseg_mask=None, **kwargs fanseg_mask=None,
pitch_yaw_roll=None,
**kwargs
): ):
if face_type is None: face_type = self.get_face_type() if face_type is None: face_type = self.get_face_type()
if landmarks is None: landmarks = self.get_landmarks() if landmarks is None: landmarks = self.get_landmarks()
@ -332,6 +337,7 @@ class DFLPNG(object):
if source_landmarks is None: source_landmarks = self.get_source_landmarks() if source_landmarks is None: source_landmarks = self.get_source_landmarks()
if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat() if image_to_face_mat is None: image_to_face_mat = self.get_image_to_face_mat()
if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask() if fanseg_mask is None: fanseg_mask = self.get_fanseg_mask()
if pitch_yaw_roll is None: pitch_yaw_roll = self.get_pitch_yaw_roll()
DFLPNG.embed_data (filename, face_type=face_type, DFLPNG.embed_data (filename, face_type=face_type,
landmarks=landmarks, landmarks=landmarks,
ie_polys=ie_polys, ie_polys=ie_polys,
@ -339,7 +345,8 @@ class DFLPNG(object):
source_rect=source_rect, source_rect=source_rect,
source_landmarks=source_landmarks, source_landmarks=source_landmarks,
image_to_face_mat=image_to_face_mat, image_to_face_mat=image_to_face_mat,
fanseg_mask=fanseg_mask) fanseg_mask=fanseg_mask,
pitch_yaw_roll=pitch_yaw_roll)
def remove_fanseg_mask(self): def remove_fanseg_mask(self):
self.dfl_dict['fanseg_mask'] = None self.dfl_dict['fanseg_mask'] = None
@ -397,5 +404,7 @@ class DFLPNG(object):
if fanseg_mask is not None: if fanseg_mask is not None:
return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis] return np.clip ( np.array (fanseg_mask) / 255.0, 0.0, 1.0 )[...,np.newaxis]
return None return None
def get_pitch_yaw_roll(self):
    """Return the value stored under 'pitch_yaw_roll' in dfl_dict, or None if absent."""
    stored_pose = self.dfl_dict.get ('pitch_yaw_roll', None)
    return stored_pose
def __str__(self): def __str__(self):
return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__) return "<PNG length={length} chunks={}>".format(len(self.chunks), **self.__dict__)