diff --git a/xlib/onnxruntime/InferenceSession.py b/xlib/onnxruntime/InferenceSession.py index 498050c..e3eed9b 100644 --- a/xlib/onnxruntime/InferenceSession.py +++ b/xlib/onnxruntime/InferenceSession.py @@ -1,9 +1,10 @@ +import onnx import onnxruntime as rt - +from io import BytesIO from .device import ORTDeviceInfo -def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo): +def InferenceSession_with_device(onnx_model_or_path, device_info : ORTDeviceInfo): """ Construct onnxruntime.InferenceSession with this Device. @@ -11,6 +12,11 @@ def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo): can raise Exception """ + + if isinstance(onnx_model_or_path, onnx.ModelProto): + b = BytesIO() + onnx.save(onnx_model_or_path, b) + onnx_model_or_path = b.getvalue() prs = rt.get_available_providers() @@ -28,5 +34,5 @@ def InferenceSession_with_device(onnx_modelpath, device_info : ORTDeviceInfo): #sess_options.enable_mem_pattern = False #for DmlExecutionProvider sess_options.log_severity_level = 4 sess_options.log_verbosity_level = -1 - sess = rt.InferenceSession(onnx_modelpath, providers=providers, sess_options=sess_options) + sess = rt.InferenceSession(onnx_model_or_path, providers=providers, sess_options=sess_options) return sess