code release

iperov 2021-07-23 17:34:49 +04:00
commit a902f11f74
354 changed files with 826,570 additions and 1 deletion

modelhub/DFLive/DFMModel.py  (new file, 224 lines)
@@ -0,0 +1,224 @@
from pathlib import Path
from typing import Iterator, List, Tuple, Union
import numpy as np
from xlib import onnxruntime as lib_ort
from xlib import path as lib_path
from xlib.image import ImageProcessor
from xlib.net import ThreadFileDownloader
from xlib.onnxruntime.device import ORTDeviceInfo
class DFMModelInfo:
def __init__(self, name : str, model_path : Path, url : str = None):
self._name = name
self._model_path = model_path
self._url = url
def get_name(self) -> str: return self._name
def get_model_path(self) -> Path: return self._model_path
def get_url(self) -> Union[str, None]: return self._url
    def __eq__(self, other):
        if isinstance(other, DFMModelInfo):
            return self._name == other._name
        return False
def __hash__(self):
return self._name.__hash__()
def __str__(self):
return self._name
def get_available_models_info(models_path : Path) -> List[DFMModelInfo]:
# predefined list of celebs with urls
dfm_models = [
#DFMModelInfo(name='Tom Cruise', model_path=models_path / f'Tom Cruise.dfm', url=rf'https://github.com/iperov/DeepFaceLive/releases/download/test/TOM_CRISE.onnx'),#TODO https://github.com/iperov/DeepFaceLive/releases/download/dfm/TOM_CRUISE.dfm'),
#DFMModelInfo(name='Vladimir Putin', model_path=models_path / f'Vladimir Putin.dfm', url=rf'https://github.com/iperov/DeepFaceLive/releases/download/dfm/VLADIMIR_PUTIN.dfm'),
]
# scan additional models in directory
dfm_model_paths = [ celeb.get_model_path() for celeb in dfm_models]
for dfm_path in lib_path.get_files_paths(models_path, extensions=['.dfm']):
if dfm_path not in dfm_model_paths:
dfm_models.append( DFMModelInfo(dfm_path.stem, model_path=dfm_path, ) )
return dfm_models
def get_available_devices() -> List[ORTDeviceInfo]:
"""
"""
return lib_ort.get_available_devices_info()
class DFMModel:
def __init__(self, model_path : Path, device : ORTDeviceInfo = None):
if device is None:
device = lib_ort.get_cpu_device()
self._model_path = model_path
sess = self._sess = lib_ort.InferenceSession_with_device(str(model_path), device)
inputs = sess.get_inputs()
        if len(inputs) == 0:
            raise Exception(f'Invalid model {model_path}')
        if 'in_face' not in inputs[0].name:
            raise Exception(f'Invalid model {model_path}')

        self._input_height, self._input_width = inputs[0].shape[1:3]
        self._model_type = 1
        if len(inputs) == 2:
            if 'morph_value' not in inputs[1].name:
                raise Exception(f'Invalid model {model_path}')
            self._model_type = 2
        elif len(inputs) > 2:
            raise Exception(f'Invalid model {model_path}')
def get_model_path(self) -> Path: return self._model_path
def get_input_res(self) -> Tuple[int, int]:
return self._input_width, self._input_height
def has_morph_value(self) -> bool:
return self._model_type == 2
def convert(self, img, morph_factor=0.75):
"""
img np.ndarray HW,HWC,NHWC uint8,float32
morph_factor float used if model supports it
returns
img NHW3 same dtype as img
celeb_mask NHW1 same dtype as img
face_mask NHW1 same dtype as img
"""
ip = ImageProcessor(img)
N,H,W,C = ip.get_dims()
dtype = ip.get_dtype()
img = ip.resize( (self._input_width,self._input_height) ).ch(3).to_ufloat32().get_image('NHWC')
if self._model_type == 1:
out_face_mask, out_celeb, out_celeb_mask = self._sess.run(None, {'in_face:0': img})
elif self._model_type == 2:
out_face_mask, out_celeb, out_celeb_mask = self._sess.run(None, {'in_face:0': img, 'morph_value:0':np.float32([morph_factor]) })
out_celeb = ImageProcessor(out_celeb).resize((W,H)).ch(3).to_dtype(dtype).get_image('NHWC')
out_celeb_mask = ImageProcessor(out_celeb_mask).resize((W,H)).ch(1).to_dtype(dtype).get_image('NHWC')
out_face_mask = ImageProcessor(out_face_mask).resize((W,H)).ch(1).to_dtype(dtype).get_image('NHWC')
return out_celeb, out_celeb_mask, out_face_mask
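
# a minimal usage sketch (the model file and input frame are hypothetical):
#
#   model = DFMModel( Path('my_model.dfm') )        # CPU device by default
#   celeb, celeb_mask, face_mask = model.convert(frame, morph_factor=0.75)
#   # celeb is NHW3, both masks are NHW1, dtype matches the input frame
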
class DFMModelInitializer:
"""
class to initialize DFMModel from DFMModelInfo
use .process_events() to iterate initialization process with events
"""
class Events:
prev_status_initializing : bool = False
prev_status_downloading : bool = False
prev_status_initialized : bool = False
prev_status_error : bool = False
new_status_initializing : bool = False
new_status_downloading : bool = False
new_status_initialized : bool = False
new_status_error : bool = False
download_progress : float = None
error : str = None
dfm_model : DFMModel = None
def __init__(self, dfm_model_info : DFMModelInfo, device : ORTDeviceInfo = None ): self._gen = self._generator(dfm_model_info, device)
def process_events(self) -> 'DFMModelInitializer.Events': return next(self._gen)
def _generator(self, dfm_model_info : DFMModelInfo, device : ORTDeviceInfo = None) -> Iterator['DFMModelInitializer.Events']:
"""
Creates a generator object to initialize DFM model from provided parameters
"""
        INITIALIZING, DOWNLOADING, INITIALIZED, ERROR = range(4)
downloader : ThreadFileDownloader = None
status = None
while True:
events = DFMModelInitializer.Events()
new_status = status
if status is None:
new_status = INITIALIZING
elif status == INITIALIZING:
model_path = dfm_model_info.get_model_path()
if not model_path.exists():
url = dfm_model_info.get_url()
if url is None:
new_status = ERROR
events.error = 'Model file is not found and URL is not defined.'
else:
downloader = ThreadFileDownloader(url=url, savepath=model_path)
new_status = DOWNLOADING
else:
error = None
try:
dfm_model = DFMModel(model_path, device)
except Exception as e:
error = str(e)
if error is None:
new_status = INITIALIZED
events.dfm_model = dfm_model
else:
new_status = ERROR
events.error = error
elif status == DOWNLOADING:
error = downloader.get_error()
if error is None:
progress = downloader.get_progress()
if progress == 100.0:
new_status = INITIALIZING
else:
events.download_progress = progress
else:
new_status = ERROR
events.error = error
if new_status != status:
events.prev_status_initializing = status == INITIALIZING
events.prev_status_downloading = status == DOWNLOADING
events.prev_status_initialized = status == INITIALIZED
events.prev_status_error = status == ERROR
events.new_status_initializing = new_status == INITIALIZING
events.new_status_downloading = new_status == DOWNLOADING
events.new_status_initialized = new_status == INITIALIZED
events.new_status_error = new_status == ERROR
status = new_status
yield events
def DFMModel_from_path(model_path : Path, device : ORTDeviceInfo = None) -> DFMModel:
    """
    instantiates DFMModel from a model file path
    """
    return DFMModel(model_path=model_path, device=device)
def DFMModel_from_info(dfm_model_info : DFMModelInfo, device : ORTDeviceInfo = None) -> DFMModelInitializer:
"""
instantiates DFMModelInitializer
"""
return DFMModelInitializer(dfm_model_info=dfm_model_info, device=device)
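
A minimal driver sketch for the initializer API above (the models directory, the polling interval, and the presence of at least one .dfm file are assumptions):

import time
from pathlib import Path
from modelhub.DFLive import DFMModel_from_info, get_available_models_info

models = get_available_models_info( Path('models') )     # predefined entries plus scanned .dfm files
initializer = DFMModel_from_info(models[0])              # CPU device by default
while True:
    events = initializer.process_events()
    if events.download_progress is not None:
        print(f'downloading: {events.download_progress:.1f}%')
    if events.new_status_error:
        raise Exception(events.error)
    if events.new_status_initialized:
        dfm_model = events.dfm_model
        break
    time.sleep(0.05)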

modelhub/DFLive/__init__.py  (new file, 3 lines)
@@ -0,0 +1,3 @@
from .DFMModel import (DFMModel_from_info, DFMModel_from_path,
DFMModelInfo, get_available_devices,
get_available_models_info)

modelhub/DFLive/__trash.py  (new file, 115 lines)
@@ -0,0 +1,115 @@
from enum import IntEnum
from pathlib import Path
from typing import List, Union
import cv2
import numpy as np
from xlib.image import ImageProcessor
from xlib.net import ThreadFileDownloader
from xlib import tf
class DFMModelInfo:
"""
Model for face swapping.
"""
def __init__(self, celeb_type, device):
celeb_name = str(celeb_type).split('.')[-1]
model_path = self._model_path = Path(__file__).parent / 'CELEB_MODEL' / f'{celeb_name}.onnx'
self._dl_error = None
self._dl = None
if not model_path.exists():
self._dl = ThreadFileDownloader(rf'https://github.com/iperov/DeepFaceLive/releases/download/CELEB_MODEL/{celeb_name}.onnx', savepath=model_path)
self._device = device
self._sess = None
@staticmethod
def get_available_devices(celeb_type):
return tf.get_available_devices_info()
def get_download_error(self) -> Union[str,None]:
"""
returns download error or None if no error.
"""
if self._dl_error is not None:
return 'DFMModelInfo download error: ' + self._dl_error
return None
def update_download_progress(self) -> float:
"""
        returns [0.0..100.0] where 100.0 means the model is ready to convert.
"""
if self._dl_error is not None:
progress = 0.0
elif self._dl is None:
progress = 100.0
else:
err = self._dl_error = self._dl.get_error()
if err is not None:
progress = 0.0
else:
progress = self._dl.get_progress()
if progress == 100.0 and self._sess is None:
self._dl = None
            model_path = self._model_path
self._sess = tf.TFInferenceSession(model_path,
in_tensor_names=['in_face:0', 'morph_value:0'],
out_tensor_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
device_info=self._device)
return progress
def convert(self, img, morph_factor=0.75):
"""
img np.ndarray HW,HWC,NHWC uint8,float32
returns
img NHW3 same dtype as img
celeb_mask NHW1 same dtype as img
face_mask NHW1 same dtype as img
convert the face
if the model is not ready, it will return black images
"""
ip = ImageProcessor(img)
N,H,W,C = ip.get_dims()
dtype = ip.get_dtype()
progress = self.update_download_progress()
if progress == 100.0:
img = ip.resize( (224,224) ).ch(3).to_ufloat32().get_image('NHWC')
#out_face_mask, out_celeb, out_celeb_mask = self._sess.run(None, {self._input_name: img})
out_face_mask, out_celeb, out_celeb_mask = self._sess.run([img, [morph_factor] ])
out_celeb = ImageProcessor(out_celeb).resize((W,H)).ch(3).to_dtype(dtype).get_image('NHWC')
out_celeb_mask = ImageProcessor(out_celeb_mask).resize((W,H)).ch(1).to_dtype(dtype).get_image('NHWC')
out_face_mask = ImageProcessor(out_face_mask).resize((W,H)).ch(1).to_dtype(dtype).get_image('NHWC')
else:
out_celeb = np.zeros( (N,H,W,3), dtype=dtype )
out_celeb_mask = ImageProcessor(np.full( (N,H,W,1), 255, dtype=np.uint8 )).to_dtype(dtype).get_image('NHWC')
out_face_mask = ImageProcessor(np.full( (N,H,W,1), 255, dtype=np.uint8 )).to_dtype(dtype).get_image('NHWC')
return out_celeb, out_celeb_mask, out_face_mask
class CelebType(IntEnum):
TOM_CRUISE = 0
VLADIMIR_PUTIN = 1
CelebTypeNames = ['Tom Cruise',
'Vladimir Putin']

modelhub/onnx/CenterFace/CenterFace.onnx  (binary file not shown)

modelhub/onnx/CenterFace/CenterFace.py  (new file, 111 lines)
@@ -0,0 +1,111 @@
from pathlib import Path
from typing import List
import numpy as np
from xlib import math as lib_math
from xlib.image import ImageProcessor
from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
get_available_devices_info)
class CenterFace:
"""
CenterFace face detection model.
arguments
device_info ORTDeviceInfo
use CenterFace.get_available_devices()
                to determine a list of available devices accepted by model
raises
Exception
"""
@staticmethod
def get_available_devices() -> List[ORTDeviceInfo]:
        # the CenterFace ONNX model does not work correctly on CPU,
        # but on GPU it is much faster than the PyTorch version
return get_available_devices_info(include_cpu=False)
def __init__(self, device_info : ORTDeviceInfo ):
if device_info not in CenterFace.get_available_devices():
raise Exception(f'device_info {device_info} is not in available devices for CenterFace')
path = Path(__file__).parent / 'CenterFace.onnx'
self._sess = sess = InferenceSession_with_device(str(path), device_info)
self._input_name = sess.get_inputs()[0].name
def extract(self, img, threshold : float = 0.5, fixed_window=0, min_face_size=40):
"""
arguments
img np.ndarray ndim 2,3,4
fixed_window(0) int size
0 mean don't use
fit image in fixed window
downscale if bigger than window
pad if smaller than window
increases performance, but decreases accuracy
returns a list of [l,t,r,b] for every batch dimension of img
"""
ip = ImageProcessor(img)
N,H,W,_ = ip.get_dims()
if fixed_window != 0:
fixed_window = max(64, max(1, fixed_window // 32) * 32 )
img_scale = ip.fit_in(fixed_window, fixed_window, pad_to_target=True, allow_upscale=False)
else:
ip.pad_to_next_divisor(64, 64)
img_scale = 1.0
img = ip.ch(3).swap_ch().to_uint8().as_float32().get_image('NCHW')
heatmaps, scales, offsets = self._sess.run(None, {self._input_name: img})
faces_per_batch = []
for heatmap, offset, scale in zip(heatmaps, offsets, scales):
faces = []
for face in self.refine(heatmap, offset, scale, H, W, threshold):
l,t,r,b,c = face
if img_scale != 1.0:
l,t,r,b = l/img_scale, t/img_scale, r/img_scale, b/img_scale
bt = b-t
if min(r-l,bt) < min_face_size:
continue
b += bt*0.1
faces.append( (l,t,r,b) )
faces_per_batch.append(faces)
return faces_per_batch
def refine(self, heatmap, offset, scale, h, w, threshold):
heatmap = heatmap[0]
scale0, scale1 = scale[0, :, :], scale[1, :, :]
offset0, offset1 = offset[0, :, :], offset[1, :, :]
c0, c1 = np.where(heatmap > threshold)
bboxlist = []
if len(c0) > 0:
for i in range(len(c0)):
s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4
o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]]
s = heatmap[c0[i], c1[i]]
x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(0, (c0[i] + o0 + 0.5) * 4 - s0 / 2)
x1, y1 = min(x1, w), min(y1, h)
bboxlist.append([x1, y1, min(x1 + s1, w), min(y1 + s0, h), s])
bboxlist = np.array(bboxlist, dtype=np.float32)
bboxlist = bboxlist[ lib_math.nms(bboxlist[:,0], bboxlist[:,1], bboxlist[:,2], bboxlist[:,3], bboxlist[:,4], 0.3), : ]
bboxlist = [x for x in bboxlist if x[-1] >= 0.5]
return bboxlist
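
A minimal usage sketch (the modelhub.onnx import path follows the package __init__ below; 'photo.jpg' is an assumed input). Since get_available_devices() excludes CPU, this requires a GPU-enabled onnxruntime build:

import cv2
from modelhub.onnx import CenterFace

cf = CenterFace( CenterFace.get_available_devices()[0] )
img = cv2.imread('photo.jpg')                                   # HWC uint8 BGR
faces_per_batch = cf.extract(img, threshold=0.5, fixed_window=640)
for l,t,r,b in faces_per_batch[0]:
    cv2.rectangle(img, (int(l),int(t)), (int(r),int(b)), (0,255,0), 2)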

modelhub/onnx/FaceMesh/FaceMesh.onnx  (binary file not shown)

modelhub/onnx/FaceMesh/FaceMesh.py  (new file, 58 lines)
@@ -0,0 +1,58 @@
from pathlib import Path
from typing import List
from xlib.image import ImageProcessor
from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
get_available_devices_info)
class FaceMesh:
"""
Google FaceMesh detection model.
arguments
device_info ORTDeviceInfo
use FaceMesh.get_available_devices()
                to determine a list of available devices accepted by model
raises
Exception
"""
@staticmethod
def get_available_devices() -> List[ORTDeviceInfo]:
return get_available_devices_info()
def __init__(self, device_info : ORTDeviceInfo):
if device_info not in FaceMesh.get_available_devices():
raise Exception(f'device_info {device_info} is not in available devices for FaceMesh')
path = Path(__file__).parent / 'FaceMesh.onnx'
self._sess = sess = InferenceSession_with_device(str(path), device_info)
self._input_name = sess.get_inputs()[0].name
self._input_width = 192
self._input_height = 192
def extract(self, img):
"""
arguments
img np.ndarray HW,HWC,NHWC uint8/float32
returns (N,468,3)
"""
ip = ImageProcessor(img)
N,H,W,_ = ip.get_dims()
h_scale = H / self._input_height
w_scale = W / self._input_width
feed_img = ip.resize( (self._input_width, self._input_height) ).to_ufloat32().ch(3).get_image('NHWC')
lmrks = self._sess.run(None, {self._input_name: feed_img})[0]
lmrks = lmrks.reshape( (N,468,3))
lmrks *= (w_scale, h_scale, 1)
return lmrks
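
A minimal usage sketch for FaceMesh (the import path and input image are assumptions). Because of the w_scale/h_scale multiply above, landmark x,y come back in source-image coordinates:

import cv2
from modelhub.onnx import FaceMesh

mesh = FaceMesh( FaceMesh.get_available_devices()[0] )
face_img = cv2.imread('aligned_face.jpg')       # a rough face crop
lmrks = mesh.extract(face_img)                  # (N,468,3); x,y in source coords, z unscaled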

modelhub/onnx/S3FD/S3FD.onnx  (binary file not shown)

modelhub/onnx/S3FD/S3FD.py  (new file, 89 lines)
@@ -0,0 +1,89 @@
from pathlib import Path
from typing import List
import numpy as np
from xlib import math as lib_math
from xlib.image import ImageProcessor
from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
get_available_devices_info)
class S3FD:
@staticmethod
def get_available_devices() -> List[ORTDeviceInfo]:
return get_available_devices_info()
def __init__(self, device_info : ORTDeviceInfo ):
if device_info not in S3FD.get_available_devices():
raise Exception(f'device_info {device_info} is not in available devices for S3FD')
path = Path(__file__).parent / 'S3FD.onnx'
self._sess = sess = InferenceSession_with_device(str(path), device_info)
self._input_name = sess.get_inputs()[0].name
def extract(self, img : np.ndarray, threshold=0.95, fixed_window=0, min_face_size=40):
"""
        img  np.ndarray  HW,HWC,NHWC  [0..255]

        returns a list of [l,t,r,b] for every batch dimension of img
"""
ip = ImageProcessor(img)
if fixed_window != 0:
fixed_window = max(64, max(1, fixed_window // 32) * 32 )
img_scale = ip.fit_in(fixed_window, fixed_window, pad_to_target=True, allow_upscale=False)
else:
ip.pad_to_next_divisor(64, 64)
img_scale = 1.0
img = ip.ch(3).to_uint8().as_float32().apply( lambda img: img - [104,117,123]).get_image('NCHW')
batches_bbox = self._sess.run(None, {self._input_name: img})
faces_per_batch = []
for batch in range(img.shape[0]):
bbox = self.refine( [ x[batch] for x in batches_bbox ], threshold )
faces = []
for l,t,r,b,c in bbox:
if img_scale != 1.0:
l,t,r,b = l/img_scale, t/img_scale, r/img_scale, b/img_scale
bt = b-t
if min(r-l,bt) < min_face_size:
continue
b += bt*0.1
faces.append ( (l,t,r,b) )
faces_per_batch.append(faces)
return faces_per_batch
def refine(self, olist, threshold):
bboxlist = []
variances = [0.1, 0.2]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
stride = 2**(i + 2) # 4,8,16,32,64,128
for hindex, windex in [*zip(*np.where(ocls[1, :, :] > threshold))]:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[1, hindex, windex]
loc = np.ascontiguousarray(oreg[:, hindex, windex]).reshape((1, 4))
                # decode SSD-style regression against the anchor prior:
                # center = prior_center + loc_xy * variance0 * prior_size
                # size   = prior_size * exp(loc_wh * variance1)
                priors = np.array([[axc, ayc, stride * 4, stride * 4]])
                bbox = np.concatenate((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
                                       priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), 1)
                # convert center-size boxes to corner coordinates [x1,y1,x2,y2]
                bbox[:, :2] -= bbox[:, 2:] / 2
                bbox[:, 2:] += bbox[:, :2]
x1, y1, x2, y2 = bbox[0]
bboxlist.append([x1, y1, x2, y2, score])
if len(bboxlist) != 0:
bboxlist = np.array(bboxlist)
bboxlist = bboxlist[ lib_math.nms(bboxlist[:,0], bboxlist[:,1], bboxlist[:,2], bboxlist[:,3], bboxlist[:,4], 0.3), : ]
bboxlist = [x for x in bboxlist if x[-1] >= 0.5]
return bboxlist

modelhub/onnx/YoloV5Face/YoloV5Face.onnx  (binary file not shown)

modelhub/onnx/YoloV5Face/YoloV5Face.py  (new file, 145 lines)
@@ -0,0 +1,145 @@
from pathlib import Path
from typing import List
import numpy as np
from xlib import math as lib_math
from xlib.image import ImageProcessor
from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
get_available_devices_info)
class YoloV5Face:
"""
YoloV5Face face detection model.
arguments
device_info ORTDeviceInfo
use YoloV5Face.get_available_devices()
                to determine a list of available devices accepted by model
raises
Exception
"""
@staticmethod
def get_available_devices() -> List[ORTDeviceInfo]:
return get_available_devices_info()
def __init__(self, device_info : ORTDeviceInfo ):
if device_info not in YoloV5Face.get_available_devices():
raise Exception(f'device_info {device_info} is not in available devices for YoloV5Face')
path = Path(__file__).parent / 'YoloV5Face.onnx'
self._sess = sess = InferenceSession_with_device(str(path), device_info)
self._input_name = sess.get_inputs()[0].name
def extract(self, img, threshold : float = 0.3, fixed_window=0, min_face_size=8, augment=False):
"""
arguments
img np.ndarray ndim 2,3,4
fixed_window(0) int size
0 mean don't use
fit image in fixed window
downscale if bigger than window
pad if smaller than window
increases performance, but decreases accuracy
min_face_size(8)
augment(False) bool augment image to increase accuracy
decreases performance
returns a list of [l,t,r,b] for every batch dimension of img
"""
ip = ImageProcessor(img)
_,H,W,_ = ip.get_dims()
if H > 2048 or W > 2048:
fixed_window = 2048
if fixed_window != 0:
fixed_window = max(32, max(1, fixed_window // 32) * 32 )
img_scale = ip.fit_in(fixed_window, fixed_window, pad_to_target=True, allow_upscale=False)
else:
ip.pad_to_next_divisor(64, 64)
img_scale = 1.0
ip.ch(3).to_ufloat32()
_,H,W,_ = ip.get_dims()
preds = self._get_preds(ip.get_image('NCHW'))
if augment:
rl_preds = self._get_preds( ip.flip_horizontal().get_image('NCHW') )
rl_preds[:,:,0] = W-rl_preds[:,:,0]
preds = np.concatenate([preds, rl_preds],1)
faces_per_batch = []
for pred in preds:
pred = pred[pred[...,4] >= threshold]
x,y,w,h,score = pred.T
l, t, r, b = x-w/2, y-h/2, x+w/2, y+h/2
keep = lib_math.nms(l,t,r,b, score, 0.5)
l, t, r, b = l[keep], t[keep], r[keep], b[keep]
faces = []
for l,t,r,b in np.stack([l, t, r, b], -1):
if img_scale != 1.0:
l,t,r,b = l/img_scale, t/img_scale, r/img_scale, b/img_scale
if min(r-l,b-t) < min_face_size:
continue
faces.append( (l,t,r,b) )
faces_per_batch.append(faces)
return faces_per_batch
def _get_preds(self, img):
N,C,H,W = img.shape
preds = self._sess.run(None, {self._input_name: img})
# YoloV5Face returns 3x [N,C*16,H,W].
# C = [cx,cy,w,h,thres, 5*x,y of landmarks, cls_id ]
# Transpose and cut first 5 channels.
pred0, pred1, pred2 = [pred.reshape( (N,C,16,pred.shape[-2], pred.shape[-1]) ).transpose(0,1,3,4,2)[...,0:5] for pred in preds]
pred0 = YoloV5Face.process_pred(pred0, W, H, anchor=[ [4,5],[8,10],[13,16] ] ).reshape( (N, -1, 5) )
pred1 = YoloV5Face.process_pred(pred1, W, H, anchor=[ [23,29],[43,55],[73,105] ] ).reshape( (N, -1, 5) )
pred2 = YoloV5Face.process_pred(pred2, W, H, anchor=[ [146,217],[231,300],[335,433] ] ).reshape( (N, -1, 5) )
return np.concatenate( [pred0, pred1, pred2], 1 )[...,:5]
@staticmethod
def process_pred(pred, img_w, img_h, anchor):
pred_h = pred.shape[-3]
pred_w = pred.shape[-2]
anchor = np.float32(anchor)[None,:,None,None,:]
        _xv, _yv = np.meshgrid(np.arange(pred_w), np.arange(pred_h))
grid = np.stack((_xv, _yv), 2).reshape((1, 1, pred_h, pred_w, 2)).astype(np.float32)
stride = (img_w // pred_w, img_h // pred_h)
        pred[..., 0:5] = YoloV5Face._np_sigmoid(pred[..., 0:5])
        # decode: xy is the offset inside the grid cell, scaled by stride;
        # wh is scaled relative to the anchor box (YoloV5 scheme)
        pred[..., 0:2] = (pred[..., 0:2]*2 - 0.5 + grid) * stride
        pred[..., 2:4] = (pred[..., 2:4]*2)**2 * anchor
return pred
@staticmethod
def _np_sigmoid(x : np.ndarray):
"""
        sigmoid with a guard against overflow in np.exp
"""
x = -x
c = x > np.log( np.finfo(x.dtype).max )
x[c] = 0.0
result = 1 / (1+np.exp(x))
result[c] = 0.0
return result
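
A usage sketch mirroring the detectors above (the import path and input image are assumptions); augment=True runs a second, horizontally flipped pass and merges predictions before NMS:

import cv2
from modelhub.onnx import YoloV5Face

yolo = YoloV5Face( YoloV5Face.get_available_devices()[0] )
img = cv2.imread('group_photo.jpg')
faces_per_batch = yolo.extract(img, threshold=0.3, fixed_window=480, augment=True)
for l,t,r,b in faces_per_batch[0]:
    print(f'face: {l:.0f},{t:.0f} .. {r:.0f},{b:.0f}')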

modelhub/onnx/__init__.py  (new file, 4 lines)
@@ -0,0 +1,4 @@
from .CenterFace.CenterFace import CenterFace
from .FaceMesh.FaceMesh import FaceMesh
from .S3FD.S3FD import S3FD
from .YoloV5Face.YoloV5Face import YoloV5Face

modelhub/torch/CenterFace/CenterFace.pth  (binary file not shown)

modelhub/torch/CenterFace/CenterFace.py  (new file, 410 lines)
@@ -0,0 +1,410 @@
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def CenterFace_to_onnx(onnx_filepath):
"""Convert Pytorch CenterFace model to ONNX"""
pth_file = Path(__file__).parent / 'CenterFace.pth'
if not pth_file.exists():
raise Exception(f'{pth_file} does not exist.')
net = CenterFaceNet()
net.load_state_dict( torch.load(pth_file) )
torch.onnx.export(net,
torch.from_numpy( np.zeros( (1,3,640,640), dtype=np.float32)),
str(onnx_filepath),
verbose=True,
training=torch.onnx.TrainingMode.TRAINING,
opset_version=12,
do_constant_folding=False,
input_names=['in'],
output_names=['heatmap','scale','offset'],
dynamic_axes={'in' : {0:'batch_size',2:'height',3:'width'},
'heatmap' : {2:'height',3:'width'},
'scale' : {2:'height',3:'width'},
'offset' : {2:'height',3:'width'},
},
)
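
# a quick sanity check of the exported graph (a sketch: the output names follow
# the export call above; the 160x160 shapes assume the 640x640 dummy input and
# the network's overall stride of 4)
def _check_centerface_onnx(onnx_filepath):
    import onnxruntime as ort
    sess = ort.InferenceSession( str(onnx_filepath) )
    heatmap, scale, offset = sess.run(None, {'in': np.zeros( (1,3,640,640), dtype=np.float32 )})
    assert heatmap.shape == (1, 1, 160, 160)    # single-channel face-center heatmap
    assert scale.shape   == (1, 2, 160, 160)    # per-location log sizes
    assert offset.shape  == (1, 2, 160, 160)    # per-location sub-cell offsets
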
# class BatchNorm2D(nn.Module):
# def __init__(self, num_features, momentum=0.1, eps=1e-5):
# super().__init__()
# self.num_features = num_features
# self.momentum = momentum
# self.eps = 1e-5
# self.weight = nn.Parameter(torch.Tensor(num_features))
# self.bias = nn.Parameter(torch.Tensor(num_features))
# self.register_buffer('running_mean', torch.zeros(num_features))
# self.register_buffer('running_var', torch.ones(num_features))
# self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
# def forward(self, input : torch.Tensor):
# input_mean = input.mean([0,2,3], keepdim=True)
# v = input-input_mean
# var = (v*v).mean([0,2,3], keepdim=True)
# return self.weight.view([1, self.num_features, 1, 1]) * v / (var + self.eps).sqrt() \
# + self.bias.view([1, self.num_features, 1, 1])
class CenterFaceNet(nn.Module):
def __init__(self):
super().__init__()
self.conv_363 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
self.bn_364 = nn.BatchNorm2d(32)
self.dconv_366 = nn.Conv2d(32, 32, 3, padding=1, groups=32, bias=False)
self.bn_367 = nn.BatchNorm2d(32)
self.conv_369 = nn.Conv2d(32, 16, 1, padding=0, bias=False)
self.bn_370 = nn.BatchNorm2d(16)
self.conv_371 = nn.Conv2d(16, 96, 1, padding=0, bias=False)
self.bn_372 = nn.BatchNorm2d(96)
self.dconv_374 = nn.Conv2d(96, 96, 3, stride=2, padding=1, groups=96, bias=False)
self.bn_375 = nn.BatchNorm2d(96)
self.conv_377 = nn.Conv2d(96, 24, 1, padding=0, bias=False)
self.bn_378 = nn.BatchNorm2d(24)
self.conv_379 = nn.Conv2d(24, 144, 1, padding=0, bias=False)
self.bn_380 = nn.BatchNorm2d(144)
self.dconv_382 = nn.Conv2d(144, 144, 3, padding=1, groups=144, bias=False)
self.bn_383 = nn.BatchNorm2d(144)
self.conv_385 = nn.Conv2d(144, 24, 1, padding=0, bias=False)
self.bn_386 = nn.BatchNorm2d(24)
self.conv_388 = nn.Conv2d(24, 144, 1, padding=0, bias=False)
self.bn_389 = nn.BatchNorm2d(144)
self.dconv_391 = nn.Conv2d(144, 144, 3, stride=2, padding=1, groups=144, bias=False)
self.bn_392 = nn.BatchNorm2d(144)
self.conv_394 = nn.Conv2d(144, 32, 1, padding=0, bias=False)
self.bn_395 = nn.BatchNorm2d(32)
self.conv_396 = nn.Conv2d(32, 192, 1, padding=0, bias=False)
self.bn_397 = nn.BatchNorm2d(192)
self.dconv_399 = nn.Conv2d(192, 192, 3, padding=1, groups=192, bias=False)
self.bn_400 = nn.BatchNorm2d(192)
self.conv_402 = nn.Conv2d(192, 32, 1, padding=0, bias=False)
self.bn_403 = nn.BatchNorm2d(32)
self.conv_405 = nn.Conv2d(32, 192, 1, padding=0, bias=False)
self.bn_406 = nn.BatchNorm2d(192)
self.dconv_408 = nn.Conv2d(192, 192, 3, padding=1, groups=192, bias=False)
self.bn_409 = nn.BatchNorm2d(192)
self.conv_411 = nn.Conv2d(192, 32, 1, padding=0, bias=False)
self.bn_412 = nn.BatchNorm2d(32)
self.conv_414 = nn.Conv2d(32, 192, 1, padding=0, bias=False)
self.bn_415 = nn.BatchNorm2d(192)
self.dconv_417 = nn.Conv2d(192, 192, 3, stride=2, padding=1, groups=192, bias=False)
self.bn_418 = nn.BatchNorm2d(192)
self.conv_420 = nn.Conv2d(192, 64, 1, padding=0, bias=False)
self.bn_421 = nn.BatchNorm2d(64)
self.conv_422 = nn.Conv2d(64, 384, 1, padding=0, bias=False)
self.bn_423 = nn.BatchNorm2d(384)
self.dconv_425 = nn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False)
self.bn_426 = nn.BatchNorm2d(384)
self.conv_428 = nn.Conv2d(384, 64, 1, padding=0, bias=False)
self.bn_429 = nn.BatchNorm2d(64)
self.conv_431 = nn.Conv2d(64, 384, 1, padding=0, bias=False)
self.bn_432 = nn.BatchNorm2d(384)
self.dconv_434 = nn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False)
self.bn_435 = nn.BatchNorm2d(384)
self.conv_437 = nn.Conv2d(384, 64, 1, padding=0, bias=False)
self.bn_438 = nn.BatchNorm2d(64)
self.conv_440 = nn.Conv2d(64, 384, 1, padding=0, bias=False)
self.bn_441 = nn.BatchNorm2d(384)
self.dconv_443 = nn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False)
self.bn_444 = nn.BatchNorm2d(384)
self.conv_446 = nn.Conv2d(384, 64, 1, padding=0, bias=False)
self.bn_447 = nn.BatchNorm2d(64)
self.conv_449 = nn.Conv2d(64, 384, 1, padding=0, bias=False)
self.bn_450 = nn.BatchNorm2d(384)
self.dconv_452 = nn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False)
self.bn_453 = nn.BatchNorm2d(384)
self.conv_455 = nn.Conv2d(384, 96, 1, padding=0, bias=False)
self.bn_456 = nn.BatchNorm2d(96)
self.conv_457 = nn.Conv2d(96, 576, 1, padding=0, bias=False)
self.bn_458 = nn.BatchNorm2d(576)
self.dconv_460 = nn.Conv2d(576, 576, 3, padding=1, groups=576, bias=False)
self.bn_461 = nn.BatchNorm2d(576)
self.conv_463 = nn.Conv2d(576, 96, 1, padding=0, bias=False)
self.bn_464 = nn.BatchNorm2d(96)
self.conv_466 = nn.Conv2d(96, 576, 1, padding=0, bias=False)
self.bn_467 = nn.BatchNorm2d(576)
self.dconv_469 = nn.Conv2d(576, 576, 3, padding=1, groups=576, bias=False)
self.bn_470 = nn.BatchNorm2d(576)
self.conv_472 = nn.Conv2d(576, 96, 1, padding=0, bias=False)
self.bn_473 = nn.BatchNorm2d(96)
self.conv_475 = nn.Conv2d(96, 576, 1, padding=0, bias=False)
self.bn_476 = nn.BatchNorm2d(576)
self.dconv_478 = nn.Conv2d(576, 576, 3, stride=2, padding=1, groups=576, bias=False)
self.bn_479 = nn.BatchNorm2d(576)
self.conv_481 = nn.Conv2d(576, 160, 1, padding=0, bias=False)
self.bn_482 = nn.BatchNorm2d(160)
self.conv_483 = nn.Conv2d(160, 960, 1, padding=0, bias=False)
self.bn_484 = nn.BatchNorm2d(960)
self.dconv_486 = nn.Conv2d(960, 960, 3, padding=1, groups=960, bias=False)
self.bn_487 = nn.BatchNorm2d(960)
self.conv_489 = nn.Conv2d(960, 160, 1, padding=0, bias=False)
self.bn_490 = nn.BatchNorm2d(160)
self.conv_492 = nn.Conv2d(160, 960, 1, padding=0, bias=False)
self.bn_493 = nn.BatchNorm2d(960)
self.dconv_495 = nn.Conv2d(960, 960, 3, padding=1, groups=960, bias=False)
self.bn_496 = nn.BatchNorm2d(960)
self.conv_498 = nn.Conv2d(960, 160, 1, padding=0, bias=False)
self.bn_499 = nn.BatchNorm2d(160)
self.conv_501 = nn.Conv2d(160, 960, 1, padding=0, bias=False)
self.bn_502 = nn.BatchNorm2d(960)
self.dconv_504 = nn.Conv2d(960, 960, 3, padding=1, groups=960, bias=False)
self.bn_505 = nn.BatchNorm2d(960)
self.conv_507 = nn.Conv2d(960, 320, 1, padding=0, bias=False)
self.bn_508 = nn.BatchNorm2d(320)
self.conv_509 = nn.Conv2d(320, 24, 1, padding=0, bias=False)
self.bn_510 = nn.BatchNorm2d(24)
self.conv_512 = nn.ConvTranspose2d(24, 24, 2, stride=2, padding=0, bias=False)
self.bn_513 = nn.BatchNorm2d(24)
self.conv_515 = nn.Conv2d(96, 24, 1, padding=0, bias=False)
self.bn_516 = nn.BatchNorm2d(24)
self.conv_519 = nn.ConvTranspose2d(24,24, 2, stride=2, padding=0, bias=False)
self.bn_520 = nn.BatchNorm2d(24)
self.conv_522 = nn.Conv2d(32, 24, 1, padding=0, bias=False)
self.bn_523 = nn.BatchNorm2d(24)
self.conv_526 = nn.ConvTranspose2d(24,24, 2, stride=2, padding=0, bias=False)
self.bn_527 = nn.BatchNorm2d(24)
self.conv_529 = nn.Conv2d(24, 24, 1, padding=0, bias=False)
self.bn_530 = nn.BatchNorm2d(24)
self.conv_533 = nn.Conv2d(24, 24, 3, padding=1, bias=False)
self.bn_534 = nn.BatchNorm2d(24)
self.conv_536 = nn.Conv2d(24, 1, 1)
self.conv_538 = nn.Conv2d(24, 2, 1)
self.conv_539 = nn.Conv2d(24, 2, 1)
self.conv_540 = nn.Conv2d(24, 10, 1)
def forward(self, x):
x = self.conv_363(x)
x = self.bn_364(x)
x = F.relu(x)
x = self.dconv_366(x)
x = self.bn_367(x)
x = F.relu(x)
x = self.conv_369(x)
x = self.bn_370(x)
x = self.conv_371(x)
x = self.bn_372(x)
x = F.relu(x)
x = self.dconv_374(x)
x = self.bn_375(x)
x = F.relu(x)
x = self.conv_377(x)
x = x378 = self.bn_378(x)
x = self.conv_379(x)
x = self.bn_380(x)
x = F.relu(x)
x = self.dconv_382(x)
x = self.bn_383(x)
x = F.relu(x)
x = self.conv_385(x)
x = self.bn_386(x)
x = x387 = x + x378
x = self.conv_388(x)
x = self.bn_389(x)
x = F.relu(x)
x = self.dconv_391(x)
x = self.bn_392(x)
x = F.relu(x)
x = self.conv_394(x)
x = x395 = self.bn_395(x)
x = self.conv_396(x)
x = self.bn_397(x)
x = F.relu(x)
x = self.dconv_399(x)
x = self.bn_400(x)
x = F.relu(x)
x = self.conv_402(x)
x = self.bn_403(x)
x = x404 = x + x395
x = self.conv_405(x)
x = self.bn_406(x)
x = F.relu(x)
x = self.dconv_408(x)
x = self.bn_409(x)
x = F.relu(x)
x = self.conv_411(x)
x = self.bn_412(x)
x = x413 = x + x404
x = self.conv_414(x)
x = self.bn_415(x)
x = F.relu(x)
x = self.dconv_417(x)
x = self.bn_418(x)
x = F.relu(x)
x = self.conv_420(x)
x = x421 = self.bn_421(x)
x = self.conv_422(x)
x = self.bn_423(x)
x = F.relu(x)
x = self.dconv_425(x)
x = self.bn_426(x)
x = F.relu(x)
x = self.conv_428(x)
x = self.bn_429(x)
x = x430 = x + x421
x = self.conv_431(x)
x = self.bn_432(x)
x = F.relu(x)
x = self.dconv_434(x)
x = self.bn_435(x)
x = F.relu(x)
x = self.conv_437(x)
x = self.bn_438(x)
x = x439 = x + x430
x = self.conv_440(x)
x = self.bn_441(x)
x = F.relu(x)
x = self.dconv_443(x)
x = self.bn_444(x)
x = F.relu(x)
x = self.conv_446(x)
x = self.bn_447(x)
x = x + x439
x = self.conv_449(x)
x = self.bn_450(x)
x = F.relu(x)
x = self.dconv_452(x)
x = self.bn_453(x)
x = F.relu(x)
x = self.conv_455(x)
x = x456 = self.bn_456(x)
x = self.conv_457(x)
x = self.bn_458(x)
x = F.relu(x)
x = self.dconv_460(x)
x = self.bn_461(x)
x = F.relu(x)
x = self.conv_463(x)
x = self.bn_464(x)
x = x465 = x + x456
x = self.conv_466(x)
x = self.bn_467(x)
x = F.relu(x)
x = self.dconv_469(x)
x = self.bn_470(x)
x = F.relu(x)
x = self.conv_472(x)
x = self.bn_473(x)
x = x474 = x + x465
x = self.conv_475(x)
x = self.bn_476(x)
x = F.relu(x)
x = self.dconv_478(x)
x = self.bn_479(x)
x = F.relu(x)
x = self.conv_481(x)
x = x482 = self.bn_482(x)
x = self.conv_483(x)
x = self.bn_484(x)
x = F.relu(x)
x = self.dconv_486(x)
x = self.bn_487(x)
x = F.relu(x)
x = self.conv_489(x)
x = self.bn_490(x)
x = x491 = x + x482
x = self.conv_492(x)
x = self.bn_493(x)
x = F.relu(x)
x = self.dconv_495(x)
x = self.bn_496(x)
x = F.relu(x)
x = self.conv_498(x)
x = self.bn_499(x)
x = x + x491
x = self.conv_501(x)
x = self.bn_502(x)
x = F.relu(x)
x = self.dconv_504(x)
x = self.bn_505(x)
x = F.relu(x)
x = self.conv_507(x)
x = self.bn_508(x)
x = self.conv_509(x)
x = self.bn_510(x)
x = F.relu(x)
x = self.conv_512(x)
x = self.bn_513(x)
x = x514 = F.relu(x)
x = self.conv_515(x474)
x = self.bn_516(x)
x = F.relu(x)
x = x + x514
x = self.conv_519(x)
x = self.bn_520(x)
x = x521 = F.relu(x)
x = self.conv_522(x413)
x = self.bn_523(x)
x = F.relu(x)
x = x + x521
x = self.conv_526(x)
x = self.bn_527(x)
x = x528 = F.relu(x)
x = self.conv_529(x387)
x = self.bn_530(x)
x = F.relu(x)
x = x + x528
x = self.conv_533(x)
x = self.bn_534(x)
x = F.relu(x)
heatmap = torch.sigmoid( self.conv_536(x) )
scale = self.conv_538(x)
offset = self.conv_539(x)
return heatmap, scale, offset

modelhub/torch/S3FD/S3FD.py  (new file, 261 lines)
@@ -0,0 +1,261 @@
import operator
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from xlib import math as lib_math
from xlib.image import ImageProcessor
from xlib.torch import TorchDeviceInfo, get_cpu_device
class S3FD:
def __init__(self, device_info : TorchDeviceInfo = None ):
if device_info is None:
device_info = get_cpu_device()
self.device_info = device_info
net = self.net = S3FDNet()
net.load_state_dict( torch.load(str(Path(__file__).parent / 's3fd.pth')) )
net.eval()
if not device_info.is_cpu():
net.cuda(device_info.get_index())
def extract(self, img : np.ndarray, fixed_window, min_face_size=40):
"""
"""
ip = ImageProcessor(img)
if fixed_window != 0:
fixed_window = max(64, max(1, fixed_window // 32) * 32 )
img_scale = ip.fit_in(fixed_window, fixed_window, pad_to_target=True, allow_upscale=False)
else:
ip.pad_to_next_divisor(64, 64)
img_scale = 1.0
img = ip.ch(3).as_float32().apply( lambda img: img - [104,117,123]).get_image('NCHW')
tensor = torch.from_numpy(img)
if not self.device_info.is_cpu():
tensor = tensor.cuda(self.device_info.get_index())
batches_bbox = [x.data.cpu().numpy() for x in self.net(tensor)]
faces_per_batch = []
for batch in range(img.shape[0]):
bbox = self.refine( [ x[batch] for x in batches_bbox ] )
faces = []
for l,t,r,b,c in bbox:
if img_scale != 1.0:
l,t,r,b = l/img_scale, t/img_scale, r/img_scale, b/img_scale
bt = b-t
if min(r-l,bt) < min_face_size:
continue
b += bt*0.1
faces.append ( (l,t,r,b) )
#sort by largest area first
faces = [ [(l,t,r,b), (r-l)*(b-t) ] for (l,t,r,b) in faces ]
faces = sorted(faces, key=operator.itemgetter(1), reverse=True )
faces = [ x[0] for x in faces]
faces_per_batch.append(faces)
return faces_per_batch
def refine(self, olist):
bboxlist = []
variances = [0.1, 0.2]
for i in range(len(olist) // 2):
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
stride = 2**(i + 2) # 4,8,16,32,64,128
for hindex, windex in [*zip(*np.where(ocls[1, :, :] > 0.05))]:
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
score = ocls[1, hindex, windex]
loc = np.ascontiguousarray(oreg[:, hindex, windex]).reshape((1, 4))
priors = np.array([[axc, ayc, stride * 4, stride * 4]])
bbox = np.concatenate((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), 1)
bbox[:, :2] -= bbox[:, 2:] / 2
bbox[:, 2:] += bbox[:, :2]
x1, y1, x2, y2 = bbox[0]
bboxlist.append([x1, y1, x2, y2, score])
if len(bboxlist) != 0:
bboxlist = np.array(bboxlist)
bboxlist = bboxlist[ lib_math.nms(bboxlist[:,0], bboxlist[:,1], bboxlist[:,2], bboxlist[:,3], bboxlist[:,4], 0.3), : ]
bboxlist = [x for x in bboxlist if x[-1] >= 0.5]
return bboxlist
@staticmethod
def save_as_onnx(onnx_filepath):
s3fd = S3FD()
torch.onnx.export(s3fd.net,
torch.from_numpy( np.zeros( (1,3,640,640), dtype=np.float32)),
str(onnx_filepath),
verbose=True,
training=torch.onnx.TrainingMode.EVAL,
opset_version=9,
do_constant_folding=True,
input_names=['in'],
output_names=['cls1', 'reg1', 'cls2', 'reg2', 'cls3', 'reg3', 'cls4', 'reg4', 'cls5', 'reg5', 'cls6', 'reg6'],
dynamic_axes={'in' : {0:'batch_size',2:'height',3:'width'},
'cls1' : {2:'height',3:'width'},
'reg1' : {2:'height',3:'width'},
'cls2' : {2:'height',3:'width'},
'reg2' : {2:'height',3:'width'},
'cls3' : {2:'height',3:'width'},
'reg3' : {2:'height',3:'width'},
'cls4' : {2:'height',3:'width'},
'reg4' : {2:'height',3:'width'},
'cls5' : {2:'height',3:'width'},
'reg5' : {2:'height',3:'width'},
'cls6' : {2:'height',3:'width'},
'reg6' : {2:'height',3:'width'},
},
)
class L2Norm(nn.Module):
def __init__(self, n_channels, scale=1.0):
super().__init__()
self.n_channels = n_channels
self.scale = scale
self.eps = 1e-10
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
        self.weight.data.fill_(self.scale)
def forward(self, x):
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
x = x / norm * self.weight.view(1, -1, 1, 1)
return x
class S3FDNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
self.conv3_3_norm = L2Norm(256, scale=10)
self.conv4_3_norm = L2Norm(512, scale=8)
self.conv5_3_norm = L2Norm(512, scale=5)
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
def forward(self, x):
h = F.relu(self.conv1_1(x))
h = F.relu(self.conv1_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv2_1(h))
h = F.relu(self.conv2_2(h))
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv3_1(h))
h = F.relu(self.conv3_2(h))
h = F.relu(self.conv3_3(h))
f3_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv4_1(h))
h = F.relu(self.conv4_2(h))
h = F.relu(self.conv4_3(h))
f4_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.conv5_1(h))
h = F.relu(self.conv5_2(h))
h = F.relu(self.conv5_3(h))
f5_3 = h
h = F.max_pool2d(h, 2, 2)
h = F.relu(self.fc6(h))
h = F.relu(self.fc7(h))
ffc7 = h
h = F.relu(self.conv6_1(h))
h = F.relu(self.conv6_2(h))
f6_2 = h
h = F.relu(self.conv7_1(h))
h = F.relu(self.conv7_2(h))
f7_2 = h
f3_3 = self.conv3_3_norm(f3_3)
f4_3 = self.conv4_3_norm(f4_3)
f5_3 = self.conv5_3_norm(f5_3)
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
cls4 = self.fc7_mbox_conf(ffc7)
reg4 = self.fc7_mbox_loc(ffc7)
cls5 = self.conv6_2_mbox_conf(f6_2)
reg5 = self.conv6_2_mbox_loc(f6_2)
cls6 = self.conv7_2_mbox_conf(f7_2)
reg6 = self.conv7_2_mbox_loc(f7_2)
        # max-out background label (S3FD): the first detection layer predicts
        # three background scores; keep only their max against the single face score
chunk = torch.chunk(cls1, 4, 1)
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
cls1 = torch.cat ([bmax,chunk[3]], dim=1)
cls1, cls2, cls3, cls4, cls5, cls6 = [ F.softmax(x, dim=1) for x in [cls1, cls2, cls3, cls4, cls5, cls6] ]
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
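
A usage sketch for the PyTorch S3FD wrapper (the input image is an assumption). Unlike the ONNX version above, fixed_window has no default here, and s3fd.pth must be present next to the module:

import cv2
from modelhub.torch import S3FD

s3fd = S3FD()                                            # CPU device by default
img = cv2.imread('photo.jpg')
faces_per_batch = s3fd.extract(img, fixed_window=640)    # faces come back sorted largest first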

modelhub/torch/S3FD/s3fd.pth  (binary file not shown)

modelhub/torch/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
from .CenterFace.CenterFace import CenterFace, CenterFace_to_onnx
from .S3FD.S3FD import S3FD