mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-16 10:03:42 -07:00
_
This commit is contained in:
parent
9366b011b9
commit
f890d5321d
1 changed files with 9 additions and 3 deletions
|
@ -1,9 +1,10 @@
|
||||||
|
import onnx
|
||||||
import onnxruntime as rt
|
import onnxruntime as rt
|
||||||
|
from io import BytesIO
|
||||||
from .device import ORTDeviceInfo
|
from .device import ORTDeviceInfo
|
||||||
|
|
||||||
|
|
||||||
def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo):
|
def InferenceSession_with_device(onnx_model_or_path, device_info : ORTDeviceInfo):
|
||||||
"""
|
"""
|
||||||
Construct onnxruntime.InferenceSession with this Device.
|
Construct onnxruntime.InferenceSession with this Device.
|
||||||
|
|
||||||
|
@ -11,6 +12,11 @@ def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo):
|
||||||
|
|
||||||
can raise Exception
|
can raise Exception
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(onnx_model_or_path, onnx.ModelProto):
|
||||||
|
b = BytesIO()
|
||||||
|
onnx.save(onnx_model_or_path, b)
|
||||||
|
onnx_model_or_path = b.getvalue()
|
||||||
|
|
||||||
prs = rt.get_available_providers()
|
prs = rt.get_available_providers()
|
||||||
|
|
||||||
|
@ -28,5 +34,5 @@ def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo):
|
||||||
#sess_options.enable_mem_pattern = False #for DmlExecutionProvider
|
#sess_options.enable_mem_pattern = False #for DmlExecutionProvider
|
||||||
sess_options.log_severity_level = 4
|
sess_options.log_severity_level = 4
|
||||||
sess_options.log_verbosity_level = -1
|
sess_options.log_verbosity_level = -1
|
||||||
sess = rt.InferenceSession(onnx_modelpath, providers=providers, sess_options=sess_options)
|
sess = rt.InferenceSession(onnx_model_or_path, providers=providers, sess_options=sess_options)
|
||||||
return sess
|
return sess
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue