Fix setup issues and add test script

Signed-off-by: Parth Thakkar <thakkarparth007@gmail.com>
This commit is contained in:
Parth Thakkar 2022-10-21 13:23:10 -05:00
parent 2a91018792
commit c6be12979e
8 changed files with 262 additions and 41 deletions

View file

@ -25,15 +25,16 @@ class TritonPythonModel:
def get_bool(x):
return model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
is_half = get_bool("use_half")
is_half = get_bool("use_half") and torch.cuda.is_available()
# This will make inference marginally slower, but will allow bigger models to fit in GPU
int8 = get_bool("use_int8")
auto_device_map = get_bool("use_auto_device_map")
int8 = get_bool("use_int8") and torch.cuda.is_available()
auto_device_map = get_bool("use_auto_device_map") and torch.cuda.is_available()
print("Cuda available?", torch.cuda.is_available())
print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if is_half else "auto",
torch_dtype=torch.float16 if is_half else ("auto" if torch.cuda.is_available() else torch.float32),
load_in_8bit=int8,
device_map="auto" if auto_device_map else None,
low_cpu_mem_usage=True,