removing trailing spaces

This commit is contained in:
iperov 2019-03-19 23:53:27 +04:00
parent fa4e579b95
commit a3df04999c
61 changed files with 2110 additions and 2103 deletions

View file

@ -3,35 +3,35 @@ from pathlib import Path
import cv2
from nnlib import nnlib
class S3FDExtractor(object):
class S3FDExtractor(object):
def __init__(self):
exec( nnlib.import_all(), locals(), globals() )
model_path = Path(__file__).parent / "S3FD.h5"
if not model_path.exists():
return None
self.model = nnlib.keras.models.load_model ( str(model_path) )
self.model = nnlib.keras.models.load_model ( str(model_path) )
def __enter__(self):
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
return False #pass exception between __enter__ and __exit__ to outter level
def extract_from_bgr (self, input_image):
input_image = input_image[:,:,::-1].copy()
(h, w, ch) = input_image.shape
d = max(w, h)
scale_to = 640 if d >= 1280 else d / 2
scale_to = max(64, scale_to)
input_scale = d / scale_to
input_image = cv2.resize (input_image, ( int(w/input_scale), int(h/input_scale) ), interpolation=cv2.INTER_LINEAR)
olist = self.model.predict( np.expand_dims(input_image,0) )
detected_faces = []
for ltrb in self.refine (olist):
l,t,r,b = [ x*input_scale for x in ltrb]
@ -42,7 +42,7 @@ class S3FDExtractor(object):
detected_faces.append ( [int(x) for x in (l,t,r,b) ] )
return detected_faces
def refine(self, olist):
bboxlist = []
for i, ((ocls,), (oreg,)) in enumerate ( zip ( olist[::2], olist[1::2] ) ):
@ -51,7 +51,7 @@ class S3FDExtractor(object):
s_m4 = stride * 4
for hindex, windex in zip(*np.where(ocls > 0.05)):
score = ocls[hindex, windex]
score = ocls[hindex, windex]
loc = oreg[hindex, windex, :]
priors = np.array([windex * stride + s_d2, hindex * stride + s_d2, s_m4, s_m4])
priors_2p = priors[2:]
@ -61,15 +61,15 @@ class S3FDExtractor(object):
box[2:] += box[:2]
bboxlist.append([*box, score])
bboxlist = np.array(bboxlist)
if len(bboxlist) == 0:
bboxlist = np.zeros((1, 5))
bboxlist = bboxlist[self.refine_nms(bboxlist, 0.3), :]
bboxlist = [ x[:-1].astype(np.int) for x in bboxlist if x[-1] >= 0.5]
return bboxlist
def refine_nms(self, dets, thresh):
keep = list()
if len(dets) == 0:
@ -91,4 +91,4 @@ class S3FDExtractor(object):
inds = np.where(ovr <= thresh)[0]
order = order[inds + 1]
return keep
return keep