S3FD and 2DFAN-4 were improperly ported from pytorch. now fixed.

This commit is contained in:
Colombo 2019-09-20 17:16:37 +04:00
parent e928ee0d30
commit d9d10f91c2
5 changed files with 98 additions and 14 deletions

Binary file not shown.

View file

@ -1,27 +1,33 @@
import traceback
import numpy as np
import os
import cv2
import traceback
from pathlib import Path
from facelib import FaceType
from facelib import LandmarksProcessor
class LandmarksExtractor(object):
def __init__ (self, keras):
self.keras = keras
K = self.keras.backend
import cv2
import numpy as np
from facelib import FaceType, LandmarksProcessor
from nnlib import nnlib
"""
ported from https://github.com/1adrianb/face-alignment
"""
class FANExtractor(object):
def __init__ (self):
pass
def __enter__(self):
keras_model_path = Path(__file__).parent / "2DFAN-4.h5"
if not keras_model_path.exists():
return None
self.keras_model = self.keras.models.load_model (str(keras_model_path))
exec( nnlib.import_all(), locals(), globals() )
self.model = FANExtractor.BuildModel()
self.model.load_weights(str(keras_model_path))
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
del self.keras_model
del self.model
return False #pass exception between __enter__ and __exit__ to outter level
def extract (self, input_image, rects, second_pass_extractor=None, is_bgr=True):
@ -43,7 +49,7 @@ class LandmarksExtractor(object):
image = self.crop(input_image, center, scale).astype(np.float32)
image = np.expand_dims(image, 0)
predicted = self.keras_model.predict (image).transpose (0,3,1,2)
predicted = self.model.predict (image / 255.0).transpose (0,3,1,2)
pts_img = self.get_pts_from_predict ( predicted[-1], center, scale)
landmarks.append (pts_img)
@ -118,3 +124,81 @@ class LandmarksExtractor(object):
c += 0.5
return np.array( [ self.transform (c[i], center, scale, a.shape[2]) for i in range(a.shape[0]) ] )
@staticmethod
def BuildModel():
def ConvBlock(out_planes, input):
in_planes = K.int_shape(input)[-1]
x = input
x = BatchNormalization(momentum=0.1, epsilon=1e-05)(x)
x = ReLU() (x)
x = out1 = Conv2D( int(out_planes/2), kernel_size=3, strides=1, padding='valid', use_bias = False) (ZeroPadding2D(1)(x))
x = BatchNormalization(momentum=0.1, epsilon=1e-05)(x)
x = ReLU() (x)
x = out2 = Conv2D( int(out_planes/4), kernel_size=3, strides=1, padding='valid', use_bias = False) (ZeroPadding2D(1)(x))
x = BatchNormalization(momentum=0.1, epsilon=1e-05)(x)
x = ReLU() (x)
x = out3 = Conv2D( int(out_planes/4), kernel_size=3, strides=1, padding='valid', use_bias = False) (ZeroPadding2D(1)(x))
x = Concatenate()([out1, out2, out3])
if in_planes != out_planes:
downsample = BatchNormalization(momentum=0.1, epsilon=1e-05)(input)
downsample = ReLU() (downsample)
downsample = Conv2D( out_planes, kernel_size=1, strides=1, padding='valid', use_bias = False) (downsample)
x = Add ()([x, downsample])
else:
x = Add ()([x, input])
return x
def HourGlass (depth, input):
up1 = ConvBlock(256, input)
low1 = AveragePooling2D (pool_size=2, strides=2, padding='valid' )(input)
low1 = ConvBlock (256, low1)
if depth > 1:
low2 = HourGlass (depth-1, low1)
else:
low2 = ConvBlock(256, low1)
low3 = ConvBlock(256, low2)
up2 = UpSampling2D(size=2) (low3)
return Add() ( [up1, up2] )
FAN_Input = Input ( (256, 256, 3) )
x = FAN_Input
x = Conv2D (64, kernel_size=7, strides=2, padding='valid')(ZeroPadding2D(3)(x))
x = BatchNormalization(momentum=0.1, epsilon=1e-05)(x)
x = ReLU()(x)
x = ConvBlock (128, x)
x = AveragePooling2D (pool_size=2, strides=2, padding='valid') (x)
x = ConvBlock (128, x)
x = ConvBlock (256, x)
outputs = []
previous = x
for i in range(4):
ll = HourGlass (4, previous)
ll = ConvBlock (256, ll)
ll = Conv2D(256, kernel_size=1, strides=1, padding='valid') (ll)
ll = BatchNormalization(momentum=0.1, epsilon=1e-05)(ll)
ll = ReLU() (ll)
tmp_out = Conv2D(68, kernel_size=1, strides=1, padding='valid') (ll)
outputs.append(tmp_out)
if i < 4 - 1:
ll = Conv2D(256, kernel_size=1, strides=1, padding='valid') (ll)
previous = Add() ( [previous, ll, KL.Conv2D(256, kernel_size=1, strides=1, padding='valid') (tmp_out) ] )
return Model(FAN_Input, outputs[-1] )

Binary file not shown.

View file

@ -2,6 +2,6 @@ from .FaceType import FaceType
from .DLIBExtractor import DLIBExtractor
from .MTCExtractor import MTCExtractor
from .S3FDExtractor import S3FDExtractor
from .LandmarksExtractor import LandmarksExtractor
from .FANExtractor import FANExtractor
from .FANSegmentator import FANSegmentator
from .PoseEstimator import PoseEstimator

View file

@ -85,7 +85,7 @@ class ExtractSubprocessor(Subprocessor):
elif self.type == 'landmarks':
nnlib.import_all (device_config)
self.e = facelib.LandmarksExtractor(nnlib.keras)
self.e = facelib.FANExtractor()
self.e.__enter__()
if self.device_vram >= 2:
self.second_pass_e = facelib.S3FDExtractor()