mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-08-19 13:01:53 -07:00
Pep8 formatting
This commit is contained in:
parent
01f1cbb629
commit
4f936c3049
2 changed files with 18 additions and 13 deletions
|
@ -1,28 +1,33 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
import torch
|
||||||
from transformers import AutoTokenizer
|
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
from torch.utils.dlpack import to_dlpack, from_dlpack
|
from torch.utils.dlpack import to_dlpack, from_dlpack
|
||||||
import torch
|
from transformers import AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def pb2torch(request, name):
|
def pb2torch(request, name):
|
||||||
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
||||||
return from_dlpack(tensor.to_dlpack())
|
return from_dlpack(tensor.to_dlpack())
|
||||||
|
|
||||||
|
|
||||||
def torch2pb(name, tensor):
|
def torch2pb(name, tensor):
|
||||||
return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
|
return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
|
||||||
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
def initialize(self, args):
|
def initialize(self, args):
|
||||||
self.model_config = model_config = json.loads(args["model_config"])
|
self.model_config = model_config = json.loads(args["model_config"])
|
||||||
org_name = model_config["parameters"].get("org_name", {"string_value": "Salesforce"})["string_value"]
|
org_name = model_config["parameters"].get("org_name", {"string_value": "Salesforce"})["string_value"]
|
||||||
model_name = org_name + "/" + model_config["parameters"]["model_name"]["string_value"]
|
model_name = org_name + "/" + model_config["parameters"]["model_name"]["string_value"]
|
||||||
|
|
||||||
get_bool = lambda x: model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
|
def get_bool(x):
|
||||||
|
return model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
|
||||||
|
|
||||||
is_half = get_bool("use_half")
|
is_half = get_bool("use_half")
|
||||||
int8 = get_bool("use_int8") # this will make inference marginally slower, but will allow bigger models to fit in GPU
|
# This will make inference marginally slower, but will allow bigger models to fit in GPU
|
||||||
|
int8 = get_bool("use_int8")
|
||||||
auto_device_map = get_bool("use_auto_device_map")
|
auto_device_map = get_bool("use_auto_device_map")
|
||||||
|
|
||||||
print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
|
print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
|
||||||
|
@ -71,9 +76,9 @@ class TritonPythonModel:
|
||||||
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, num_return_sequences=n_samples,
|
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, num_return_sequences=n_samples,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
# assert len(output_ids.shape) == 2, "huggingface format is batch x seq_len"
|
|
||||||
# assert output_ids.shape[0] == input_ids_torch.shape[0], "expecting batch size to match input"
|
# client wants batch x beam_width x seq_len and we don't support beam_width yet
|
||||||
output_ids = output_ids.unsqueeze(1) # client wants batch x beam_width x seq_len and we don't support beam_width yet
|
output_ids = output_ids.unsqueeze(1)
|
||||||
|
|
||||||
# create output tensors
|
# create output tensors
|
||||||
out_tensor_pb = torch2pb("output_ids", output_ids)
|
out_tensor_pb = torch2pb("output_ids", output_ids)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue