Fix segfault by avoiding dlpack for triton<->torch tensor conversion

I'd been using dlpack to copy Triton tensors to Torch tensors, because it was
advertised to perform zero-copy transfers. It turns out that only worked on my
laptop and segfaulted on other machines; I don't know why. For now, I'm copying
the tensors via triton <-> numpy <-> torch instead, which works on the VM where
the earlier code was segfaulting.
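
For reference, the two conversion paths look roughly like this. This is just a
sketch: pb_tensor stands in for a tensor obtained via
pb_utils.get_input_tensor_by_name, and "output_ids" is a placeholder name, not
one taken from this code.

    import torch
    import triton_python_backend_utils as pb_utils
    from torch.utils.dlpack import to_dlpack, from_dlpack

    # dlpack path: zero copy, but segfaults on some machines
    torch_tensor = from_dlpack(pb_tensor.to_dlpack())  # pb_tensor: placeholder input tensor
    pb_out = pb_utils.Tensor.from_dlpack("output_ids", to_dlpack(torch_tensor))

    # numpy path: an extra copy each way, but hasn't segfaulted anywhere so far
    # (.numpy() assumes the torch tensor lives on CPU)
    torch_tensor = torch.from_numpy(pb_tensor.as_numpy())
    pb_out = pb_utils.Tensor("output_ids", torch_tensor.numpy())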

Signed-off-by: Parth <thakkarparth007@gmail.com>
Parth committed 2022-11-19 18:32:50 +00:00
parent f0a12b5e8e
commit 7ea388fe19
4 changed files with 13 additions and 9 deletions


@@ -2,18 +2,23 @@ import json
 
 import torch
 import triton_python_backend_utils as pb_utils
-from torch.utils.dlpack import to_dlpack, from_dlpack
+# Using dlpack causes segfaults on some machines, so not using it for now
+# But it supports zero copy transfer from triton tensors to torch tensors,
+# so worth investigating further
+# from torch.utils.dlpack import to_dlpack, from_dlpack
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 
 
 def pb2torch(request, name):
     tensor = pb_utils.get_input_tensor_by_name(request, name)
-    return from_dlpack(tensor.to_dlpack())
+    return torch.from_numpy(tensor.as_numpy())
+    # return from_dlpack(tensor.to_dlpack())
 
 
 def torch2pb(name, tensor):
-    return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
+    return pb_utils.Tensor(name, tensor.numpy())
+    # return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
 
 
 class TritonPythonModel: