mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-06 04:52:14 -07:00
38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
import onnx
|
|
import onnxruntime as rt
|
|
from io import BytesIO
|
|
from .device import ORTDeviceInfo
|
|
|
|
|
|
def InferenceSession_with_device(onnx_model_or_path, device_info : ORTDeviceInfo):
|
|
"""
|
|
Construct onnxruntime.InferenceSession with this Device.
|
|
|
|
device_info ORTDeviceInfo
|
|
|
|
can raise Exception
|
|
"""
|
|
|
|
if isinstance(onnx_model_or_path, onnx.ModelProto):
|
|
b = BytesIO()
|
|
onnx.save(onnx_model_or_path, b)
|
|
onnx_model_or_path = b.getvalue()
|
|
|
|
prs = rt.get_available_providers()
|
|
|
|
if device_info.is_cpu():
|
|
if 'CPUExecutionProvider' not in prs:
|
|
raise Exception('CPUExecutionProvider is not avaiable in onnxruntime')
|
|
providers = ['CPUExecutionProvider']
|
|
else:
|
|
if 'CUDAExecutionProvider' not in prs:
|
|
raise Exception('CUDAExecutionProvider is not avaiable in onnxruntime')
|
|
providers = [ ('CUDAExecutionProvider', {'device_id': device_info.get_index() }) ]
|
|
#providers = [ ('DmlExecutionProvider', {'device_id': 1 }) ]
|
|
|
|
sess_options = rt.SessionOptions()
|
|
#sess_options.enable_mem_pattern = False #for DmlExecutionProvider
|
|
sess_options.log_severity_level = 4
|
|
sess_options.log_verbosity_level = -1
|
|
sess = rt.InferenceSession(onnx_model_or_path, providers=providers, sess_options=sess_options)
|
|
return sess
|