Mirror of https://github.com/fauxpilot/fauxpilot.git (synced 2025-07-16 10:03:25 -07:00)
feat: Return a 400 if prompt exceeds max tokens

parent 31d2349dbb
commit e2486698e0

4 changed files with 56 additions and 8 deletions
copilot_proxy/app.py
@@ -1,12 +1,16 @@
+import logging
 import os
 
 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Response, HTTPException
 from sse_starlette.sse import EventSourceResponse
 
+from config.log_config import uvicorn_logger
 from models import OpenAIinput
 from utils.codegen import CodeGenProxy
 
+logging.config.dictConfig(uvicorn_logger)
+
 codegen = CodeGenProxy(
     host=os.environ.get("TRITON_HOST", "triton"),
     port=os.environ.get("TRITON_PORT", 8001),
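For context, a minimal sketch of how this module would be served; the "app:app" module path, bind address, and port are assumptions, not taken from the diff. Because logging.config.dictConfig(uvicorn_logger) runs at import time, the access-log formatter from config/log_config.py is already installed before uvicorn starts handling requests.

# Sketch only: "app:app", the bind address, and port 5000 are illustrative values,
# not taken from the diff. Since dictConfig() executed at import time, the
# "uvicorn.access" logger already uses the formatter defined in log_config.py.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=5000)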
@@ -22,21 +26,28 @@ app = FastAPI(
 )
 
 
-@app.post("/v1/engines/codegen/completions", status_code=200)
-@app.post("/v1/completions", status_code=200)
+@app.post("/v1/engines/codegen/completions")
+@app.post("/v1/completions")
 async def completions(data: OpenAIinput):
     data = data.dict()
-    print(data)
+    try:
+        content = codegen(data=data)
+    except codegen.TokensExceedsMaximum as E:
+        raise HTTPException(
+            status_code=400,
+            detail=str(E)
+        )
+
     if data.get("stream") is not None:
         return EventSourceResponse(
-            content=codegen(data=data),
+            content=content,
             status_code=200,
             media_type="text/event-stream"
         )
     else:
         return Response(
             status_code=200,
-            content=codegen(data=data),
+            content=content,
             media_type="application/json"
         )
 
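From the client side, the new behavior looks like the sketch below; the proxy URL, prompt, and max_tokens are assumed values. When prompt tokens plus max_tokens exceed the model limit, the HTTPException raised in the handler is serialized by FastAPI as a JSON body of the form {"detail": "..."} with HTTP status 400.

# Sketch only: URL and payload values are illustrative, not taken from the diff.
import requests

resp = requests.post(
    "http://localhost:5000/v1/completions",
    json={"prompt": "def foo():\n    pass\n" * 500, "max_tokens": 100},  # deliberately oversized prompt
)
print(resp.status_code)  # expected: 400 once prompt tokens + max_tokens exceed the model limit
print(resp.json())       # expected shape: {"detail": "This model's maximum context length is 2048, ..."}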
copilot_proxy/config/__init__.py (new file, 0 additions)

copilot_proxy/config/log_config.py (new file, 27 additions)
@@ -0,0 +1,27 @@
+# The uvicorn_logger is used to add timestamps
+
+uvicorn_logger = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "access": {
+            "()": "uvicorn.logging.AccessFormatter",
+            "fmt": '%(levelprefix)s %(asctime)s :: %(client_addr)s - "%(request_line)s" %(status_code)s',
+            "use_colors": True
+        },
+    },
+    "handlers": {
+        "access": {
+            "formatter": "access",
+            "class": "logging.StreamHandler",
+            "stream": "ext://sys.stdout",
+        },
+    },
+    "loggers": {
+        "uvicorn.access": {
+            "handlers": ["access"],
+            # "level": "INFO",
+            "propagate": False
+        },
+    },
+}
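This dict only reconfigures the uvicorn.access logger (adding an asctime timestamp to each access line) and leaves all other loggers untouched via "disable_existing_loggers": False. As an alternative wiring, the same dict could be handed to uvicorn directly; the sketch below shows that option, which is not what the commit does, and the entry point and port are assumptions.

# Alternative wiring sketch (not what the commit does): uvicorn.run() accepts a
# dictConfig-style dict via its log_config parameter, applying the same access
# formatter without the explicit logging.config.dictConfig() call in the app module.
import uvicorn
from config.log_config import uvicorn_logger  # module added by this commit

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=5000, log_config=uvicorn_logger)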
copilot_proxy/utils/codegen.py
@@ -21,6 +21,9 @@ class CodeGenProxy:
         # Max number of tokens the model can handle
         self.MAX_MODEL_LEN = 2048
 
+    class TokensExceedsMaximum(Exception):
+        pass
+
     @staticmethod
     def prepare_tensor(name: str, tensor_input):
         t = client_util.InferInput(
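Nesting the exception inside CodeGenProxy lets callers that only hold a proxy instance catch it as codegen.TokensExceedsMaximum, exactly as the handler above does, without importing anything extra from utils.codegen. Below is a self-contained illustration of that pattern; FakeProxy and its numbers are stand-ins, not repository code.

# Illustration of the nested-exception pattern; FakeProxy is a hypothetical stand-in.
class FakeProxy:
    MAX_MODEL_LEN = 2048

    class TokensExceedsMaximum(Exception):
        pass

    def __call__(self, requested_tokens: int) -> str:
        if requested_tokens > self.MAX_MODEL_LEN:
            raise self.TokensExceedsMaximum(f"requested {requested_tokens} > {self.MAX_MODEL_LEN}")
        return "ok"

proxy = FakeProxy()
try:
    proxy(4096)
except proxy.TokensExceedsMaximum as exc:  # same access pattern app.py uses: codegen.TokensExceedsMaximum
    print(exc)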
@@ -78,8 +81,15 @@ class CodeGenProxy:
         prompt_len = input_start_ids.shape[1]
         input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np.uint32)
         max_tokens = data.get('max_tokens', 16)
-        if max_tokens + input_len[0][0] > self.MAX_MODEL_LEN:
-            raise ValueError("Max tokens + prompt length exceeds maximum model length")
+        prompt_tokens: int = input_len[0][0]
+        requested_tokens = max_tokens + prompt_tokens
+        if requested_tokens > self.MAX_MODEL_LEN:
+            print(1)
+            raise self.TokensExceedsMaximum(
+                f"This model's maximum context length is {self.MAX_MODEL_LEN}, however you requested "
+                f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion). "
+                f"Please reduce your prompt; or completion length."
+            )
         output_len = np.ones_like(input_len).astype(np.uint32) * max_tokens
         num_logprobs = data.get('logprobs', -1)
         if num_logprobs is None:
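A worked restatement of the guard: with MAX_MODEL_LEN = 2048, a prompt that tokenizes to 2000 tokens passes with the default max_tokens of 16 (2016 <= 2048) but is rejected when max_tokens is 100 (2100 > 2048), before any request reaches Triton. The helper below is a standalone sketch with illustrative token counts, not code from the repository.

# Standalone restatement of the new guard; token counts are illustrative.
MAX_MODEL_LEN = 2048

def check(prompt_tokens: int, max_tokens: int) -> None:
    requested_tokens = max_tokens + prompt_tokens
    if requested_tokens > MAX_MODEL_LEN:
        raise ValueError(
            f"This model's maximum context length is {MAX_MODEL_LEN}, however you requested "
            f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion)."
        )

check(2000, 16)   # 2016 <= 2048: accepted
check(2000, 100)  # 2100 > 2048: raises, which app.py turns into an HTTP 400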