added DirectX12-compatible cards support through onnxruntime-directml

This commit is contained in:
iperov 2021-09-09 17:15:30 +04:00
parent 071bf80681
commit 6d504d5969
5 changed files with 163 additions and 171 deletions

View file

@ -12,27 +12,24 @@ def InferenceSession_with_device(onnx_model_or_path, device_info : ORTDeviceInfo
can raise Exception
"""
if isinstance(onnx_model_or_path, onnx.ModelProto):
b = BytesIO()
onnx.save(onnx_model_or_path, b)
onnx_model_or_path = b.getvalue()
prs = rt.get_available_providers()
device_ep = device_info.get_execution_provider()
if device_ep not in rt.get_available_providers():
raise Exception(f'{device_ep} is not avaiable in onnxruntime')
if device_info.is_cpu():
if 'CPUExecutionProvider' not in prs:
raise Exception('CPUExecutionProvider is not avaiable in onnxruntime')
providers = ['CPUExecutionProvider']
else:
if 'CUDAExecutionProvider' not in prs:
raise Exception('CUDAExecutionProvider is not avaiable in onnxruntime')
providers = [ ('CUDAExecutionProvider', {'device_id': device_info.get_index() }) ]
#providers = [ ('DmlExecutionProvider', {'device_id': 1 }) ]
ep_flags = {}
if device_ep in ['CUDAExecutionProvider','DmlExecutionProvider']:
ep_flags['device_id'] = device_info.get_index()
sess_options = rt.SessionOptions()
#sess_options.enable_mem_pattern = False #for DmlExecutionProvider
sess_options.log_severity_level = 4
sess_options.log_verbosity_level = -1
sess = rt.InferenceSession(onnx_model_or_path, providers=providers, sess_options=sess_options)
if device_ep == 'DmlExecutionProvider':
sess_options.enable_mem_pattern = False
sess = rt.InferenceSession(onnx_model_or_path, providers=[ (device_ep, ep_flags) ], sess_options=sess_options)
return sess