mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-08-14 02:28:06 -07:00
Pep8 formatting
This commit is contained in:
parent
01f1cbb629
commit
4f936c3049
2 changed files with 18 additions and 13 deletions
|
@ -39,4 +39,4 @@ config = template.substitute(
|
||||||
use_auto_device_map=args.use_auto_device_map,
|
use_auto_device_map=args.use_auto_device_map,
|
||||||
)
|
)
|
||||||
with open(model_dir_path/'../config.pbtxt', 'w') as f:
|
with open(model_dir_path/'../config.pbtxt', 'w') as f:
|
||||||
f.write(config)
|
f.write(config)
|
||||||
|
|
|
@ -1,28 +1,33 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
import torch
|
||||||
from transformers import AutoTokenizer
|
|
||||||
import triton_python_backend_utils as pb_utils
|
import triton_python_backend_utils as pb_utils
|
||||||
from torch.utils.dlpack import to_dlpack, from_dlpack
|
from torch.utils.dlpack import to_dlpack, from_dlpack
|
||||||
import torch
|
from transformers import AutoModelForCausalLM
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def pb2torch(request, name):
|
def pb2torch(request, name):
|
||||||
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
tensor = pb_utils.get_input_tensor_by_name(request, name)
|
||||||
return from_dlpack(tensor.to_dlpack())
|
return from_dlpack(tensor.to_dlpack())
|
||||||
|
|
||||||
|
|
||||||
def torch2pb(name, tensor):
|
def torch2pb(name, tensor):
|
||||||
return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
|
return pb_utils.Tensor.from_dlpack(name, to_dlpack(tensor))
|
||||||
|
|
||||||
|
|
||||||
class TritonPythonModel:
|
class TritonPythonModel:
|
||||||
def initialize(self, args):
|
def initialize(self, args):
|
||||||
self.model_config = model_config = json.loads(args["model_config"])
|
self.model_config = model_config = json.loads(args["model_config"])
|
||||||
org_name = model_config["parameters"].get("org_name", {"string_value": "Salesforce"})["string_value"]
|
org_name = model_config["parameters"].get("org_name", {"string_value": "Salesforce"})["string_value"]
|
||||||
model_name = org_name + "/" + model_config["parameters"]["model_name"]["string_value"]
|
model_name = org_name + "/" + model_config["parameters"]["model_name"]["string_value"]
|
||||||
|
|
||||||
get_bool = lambda x: model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
|
def get_bool(x):
|
||||||
|
return model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
|
||||||
|
|
||||||
is_half = get_bool("use_half")
|
is_half = get_bool("use_half")
|
||||||
int8 = get_bool("use_int8") # this will make inference marginally slower, but will allow bigger models to fit in GPU
|
# This will make inference marginally slower, but will allow bigger models to fit in GPU
|
||||||
|
int8 = get_bool("use_int8")
|
||||||
auto_device_map = get_bool("use_auto_device_map")
|
auto_device_map = get_bool("use_auto_device_map")
|
||||||
|
|
||||||
print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
|
print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
|
||||||
|
@ -37,8 +42,8 @@ class TritonPythonModel:
|
||||||
print(f"Model {model_name} Loaded. Footprint: {self.model.get_memory_footprint()}")
|
print(f"Model {model_name} Loaded. Footprint: {self.model.get_memory_footprint()}")
|
||||||
|
|
||||||
# set max_batch_size
|
# set max_batch_size
|
||||||
self.max_batch_size = 0 # model_config["max_batch_size"]
|
self.max_batch_size = 0 # model_config["max_batch_size"]
|
||||||
|
|
||||||
def execute(self, requests):
|
def execute(self, requests):
|
||||||
# TODO: don't just loop over requests. batch them up
|
# TODO: don't just loop over requests. batch them up
|
||||||
|
|
||||||
|
@ -55,7 +60,7 @@ class TritonPythonModel:
|
||||||
attention_mask = torch.zeros(input_ids_torch.shape, dtype=torch.long)
|
attention_mask = torch.zeros(input_ids_torch.shape, dtype=torch.long)
|
||||||
for i, l in enumerate(input_lengths_torch):
|
for i, l in enumerate(input_lengths_torch):
|
||||||
attention_mask[i, :l] = 1
|
attention_mask[i, :l] = 1
|
||||||
|
|
||||||
# Output length
|
# Output length
|
||||||
max_new_tokens = request_output_len_torch[0][0]
|
max_new_tokens = request_output_len_torch[0][0]
|
||||||
|
|
||||||
|
@ -71,9 +76,9 @@ class TritonPythonModel:
|
||||||
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, num_return_sequences=n_samples,
|
max_new_tokens=max_new_tokens, do_sample=True, top_k=top_k, top_p=top_p, num_return_sequences=n_samples,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
# assert len(output_ids.shape) == 2, "huggingface format is batch x seq_len"
|
|
||||||
# assert output_ids.shape[0] == input_ids_torch.shape[0], "expecting batch size to match input"
|
# client wants batch x beam_width x seq_len and we don't support beam_width yet
|
||||||
output_ids = output_ids.unsqueeze(1) # client wants batch x beam_width x seq_len and we don't support beam_width yet
|
output_ids = output_ids.unsqueeze(1)
|
||||||
|
|
||||||
# create output tensors
|
# create output tensors
|
||||||
out_tensor_pb = torch2pb("output_ids", output_ids)
|
out_tensor_pb = torch2pb("output_ids", output_ids)
|
||||||
|
@ -88,4 +93,4 @@ class TritonPythonModel:
|
||||||
response = pb_utils.InferenceResponse([out_tensor_pb, sequence_length_pb])
|
response = pb_utils.InferenceResponse([out_tensor_pb, sequence_length_pb])
|
||||||
responses.append(response)
|
responses.append(response)
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue