mirror of https://github.com/fauxpilot/fauxpilot.git
Fix setup issues and add test script

Signed-off-by: Parth Thakkar <thakkarparth007@gmail.com>

parent 2a91018792
commit c6be12979e

8 changed files with 262 additions and 41 deletions
@@ -25,15 +25,16 @@ class TritonPythonModel:
         def get_bool(x):
             return model_config["parameters"][x]["string_value"].lower() in ["1", "true"]
 
-        is_half = get_bool("use_half")
+        is_half = get_bool("use_half") and torch.cuda.is_available()
         # This will make inference marginally slower, but will allow bigger models to fit in GPU
-        int8 = get_bool("use_int8")
-        auto_device_map = get_bool("use_auto_device_map")
+        int8 = get_bool("use_int8") and torch.cuda.is_available()
+        auto_device_map = get_bool("use_auto_device_map") and torch.cuda.is_available()
+
+        print("Cuda available?", torch.cuda.is_available())
+        print(f"is_half: {is_half}, int8: {int8}, auto_device_map: {auto_device_map}")
 
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            torch_dtype=torch.float16 if is_half else "auto",
+            torch_dtype=torch.float16 if is_half else ("auto" if torch.cuda.is_available() else torch.float32),
             load_in_8bit=int8,
             device_map="auto" if auto_device_map else None,
             low_cpu_mem_usage=True,
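
For context, below is a minimal standalone sketch of the loading logic this hunk produces, runnable outside Triton. The model_config dict and the model name "Salesforce/codegen-350M-mono" are hypothetical stand-ins (the real values come from the model's Triton config.pbtxt); the point of the change is that every GPU-only flag is gated on torch.cuda.is_available(), so the same configuration degrades gracefully on a CPU-only host.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical stand-in for Triton's model_config["parameters"];
# FauxPilot reads these from the model's config.pbtxt.
model_config = {
    "parameters": {
        "use_half": {"string_value": "1"},
        "use_int8": {"string_value": "0"},
        "use_auto_device_map": {"string_value": "1"},
    }
}

def get_bool(x):
    # Triton parameters arrive as strings, so parse "1"/"true" as True.
    return model_config["parameters"][x]["string_value"].lower() in ["1", "true"]

# Gate each GPU-only option on CUDA availability (the fix in this commit).
cuda = torch.cuda.is_available()
is_half = get_bool("use_half") and cuda
int8 = get_bool("use_int8") and cuda  # int8 loading additionally needs the bitsandbytes package
auto_device_map = get_bool("use_auto_device_map") and cuda

model = AutoModelForCausalLM.from_pretrained(
    "Salesforce/codegen-350M-mono",  # illustrative model name, not from the diff
    # fp16 on GPU when requested; explicit fp32 on CPU instead of "auto"
    torch_dtype=torch.float16 if is_half else ("auto" if cuda else torch.float32),
    load_in_8bit=int8,
    # device_map="auto" and low_cpu_mem_usage require the accelerate package
    device_map="auto" if auto_device_map else None,
    low_cpu_mem_usage=True,
)

Without the CUDA gating, a config that enables use_half or use_int8 would fail at startup on machines without a GPU; with it, those flags simply become no-ops on CPU.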