Adds doc comments, explains the threshold parameter, cleaned up and refactored the class

This commit is contained in:
jh 2019-08-23 13:41:12 -07:00
commit b400adadf5

View file

@ -3,13 +3,18 @@ from pathlib import Path
import cv2
from nnlib import nnlib
class S3FDExtractor(object):
"""
S3FD: Single Shot Scale-invariant Face Detector
https://arxiv.org/pdf/1708.05237.pdf
"""
def __init__(self):
exec(nnlib.import_all(), locals(), globals())
model_path = Path(__file__).parent / "S3FD.h5"
if not model_path.exists():
return None
raise Exception(f'Could not find S3DF model at path {model_path}')
self.model = nnlib.keras.models.load_model(str(model_path))
@ -19,7 +24,15 @@ class S3FDExtractor(object):
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
return False # pass exception between __enter__ and __exit__ to outter level
def extract (self, input_image, is_bgr=True):
def extract(self, input_image, is_bgr=True, nms_thresh=0.3):
"""
Extracts the bounding boxes for all faces found in image
:param input_image: The image to look for faces in
:param is_bgr: Is this image in OpenCV's BGR color mode, if not, assume RGB color mode
:param nms_thresh: The NMS (non-maximum suppression) threshold. Of all bounding boxes found, only return
bounding boxes with an overlap ratio less then threshold
:return:
"""
if is_bgr:
input_image = input_image[:, :, ::-1]
@ -32,12 +45,13 @@ class S3FDExtractor(object):
scale_to = max(64, scale_to)
input_scale = d / scale_to
input_image = cv2.resize (input_image, ( int(w/input_scale), int(h/input_scale) ), interpolation=cv2.INTER_LINEAR)
input_image = cv2.resize(input_image, (int(w / input_scale), int(h / input_scale)),
interpolation=cv2.INTER_LINEAR)
olist = self.model.predict(np.expand_dims(input_image, 0))
detected_faces = []
for ltrb in self.refine (olist):
for ltrb in self._refine(olist, nms_thresh):
l, t, r, b = [x * input_scale for x in ltrb]
bt = b - t
if min(r - l, bt) < 40: # filtering faces < 40pix by any side
@ -47,7 +61,7 @@ class S3FDExtractor(object):
return detected_faces
def refine(self, olist):
def _refine(self, olist, thresh):
bboxlist = []
for i, ((ocls,), (oreg,)) in enumerate(zip(olist[::2], olist[1::2])):
stride = 2 ** (i + 2) # 4,8,16,32,64,128
@ -69,12 +83,11 @@ class S3FDExtractor(object):
bboxlist = np.array(bboxlist)
if len(bboxlist) == 0:
bboxlist = np.zeros((1, 5))
#Originally 0.3 thresh
bboxlist = bboxlist[self.refine_nms(bboxlist, 0.8), :]
bboxlist = bboxlist[self._refine_nms(bboxlist, thresh), :]
bboxlist = [x[:-1].astype(np.int) for x in bboxlist if x[-1] >= 0.5]
return bboxlist
def refine_nms(self, dets, thresh):
def _refine_nms(self, dets, nms_thresh):
keep = list()
if len(dets) == 0:
return keep
@ -93,6 +106,6 @@ class S3FDExtractor(object):
width, height = np.maximum(0.0, xx_2 - xx_1 + 1), np.maximum(0.0, yy_2 - yy_1 + 1)
ovr = width * height / (areas[i] + areas[order[1:]] - width * height)
inds = np.where(ovr <= thresh)[0]
inds = np.where(ovr <= nms_thresh)[0]
order = order[inds + 1]
return keep