This commit is contained in:
iperov 2021-09-02 19:34:53 +04:00
parent 9366b011b9
commit f890d5321d

View file

@ -1,9 +1,10 @@
import onnx
import onnxruntime as rt
from io import BytesIO
from .device import ORTDeviceInfo
def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo):
def InferenceSession_with_device(onnx_model_or_path, device_info : ORTDeviceInfo):
"""
Construct onnxruntime.InferenceSession with this Device.
@ -11,6 +12,11 @@ def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo):
can raise Exception
"""
if isinstance(onnx_model_or_path, onnx.ModelProto):
b = BytesIO()
onnx.save(onnx_model_or_path, b)
onnx_model_or_path = b.getvalue()
prs = rt.get_available_providers()
@ -28,5 +34,5 @@ def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo):
#sess_options.enable_mem_pattern = False #for DmlExecutionProvider
sess_options.log_severity_level = 4
sess_options.log_verbosity_level = -1
sess = rt.InferenceSession(onnx_modelpath, providers=providers, sess_options=sess_options)
sess = rt.InferenceSession(onnx_model_or_path, providers=providers, sess_options=sess_options)
return sess