mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-07-16 10:03:25 -07:00
feat: Return a 400 if prompt exceeds max tokens
This commit is contained in:
parent
31d2349dbb
commit
e2486698e0
4 changed files with 56 additions and 8 deletions
|
@ -1,12 +1,16 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from config.log_config import uvicorn_logger
|
||||
from models import OpenAIinput
|
||||
from utils.codegen import CodeGenProxy
|
||||
|
||||
logging.config.dictConfig(uvicorn_logger)
|
||||
|
||||
codegen = CodeGenProxy(
|
||||
host=os.environ.get("TRITON_HOST", "triton"),
|
||||
port=os.environ.get("TRITON_PORT", 8001),
|
||||
|
@ -22,21 +26,28 @@ app = FastAPI(
|
|||
)
|
||||
|
||||
|
||||
@app.post("/v1/engines/codegen/completions")
@app.post("/v1/completions")
async def completions(data: OpenAIinput):
    """Handle an OpenAI-style completion request.

    Runs the request through the Triton-backed CodeGen proxy and returns
    either a server-sent-event stream or a plain JSON response, depending
    on the request's ``stream`` field.

    Args:
        data: Parsed OpenAI-compatible completion request body.

    Raises:
        HTTPException: 400 when the prompt plus ``max_tokens`` exceeds the
            model's maximum context length (``TokensExceedsMaximum``).
    """
    data = data.dict()
    try:
        content = codegen(data=data)
    except codegen.TokensExceedsMaximum as exc:
        # Surface the over-length request as a client error (400) and keep
        # the original exception chained as the cause for debugging.
        raise HTTPException(
            status_code=400,
            detail=str(exc),
        ) from exc

    # NOTE(review): any non-None "stream" value — including False — selects
    # the SSE path here; confirm OpenAIinput's default for "stream" makes
    # this the intended behavior.
    if data.get("stream") is not None:
        return EventSourceResponse(
            content=content,
            status_code=200,
            media_type="text/event-stream",
        )
    return Response(
        status_code=200,
        content=content,
        media_type="application/json",
    )
|
||||
|
||||
|
|
0
copilot_proxy/config/__init__.py
Normal file
0
copilot_proxy/config/__init__.py
Normal file
27
copilot_proxy/config/log_config.py
Normal file
27
copilot_proxy/config/log_config.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Logging configuration consumed by uvicorn (via logging.config.dictConfig)
# so that access-log lines carry timestamps.

# Formatter: uvicorn's AccessFormatter with an asctime field added.
_access_formatter = {
    "()": "uvicorn.logging.AccessFormatter",
    "fmt": '%(levelprefix)s %(asctime)s :: %(client_addr)s - "%(request_line)s" %(status_code)s',
    "use_colors": True,
}

# Handler: plain stream handler writing to stdout.
_access_handler = {
    "formatter": "access",
    "class": "logging.StreamHandler",
    "stream": "ext://sys.stdout",
}

uvicorn_logger = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"access": _access_formatter},
    "handlers": {"access": _access_handler},
    "loggers": {
        "uvicorn.access": {
            "handlers": ["access"],
            # Don't forward records to the root logger — they would be
            # printed twice.
            "propagate": False,
        },
    },
}
|
|
@ -21,6 +21,9 @@ class CodeGenProxy:
|
|||
# Max number of tokens the model can handle
|
||||
self.MAX_MODEL_LEN = 2048
|
||||
|
||||
class TokensExceedsMaximum(Exception):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def prepare_tensor(name: str, tensor_input):
|
||||
t = client_util.InferInput(
|
||||
|
@ -78,8 +81,15 @@ class CodeGenProxy:
|
|||
prompt_len = input_start_ids.shape[1]
|
||||
input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np.uint32)
|
||||
max_tokens = data.get('max_tokens', 16)
|
||||
if max_tokens + input_len[0][0] > self.MAX_MODEL_LEN:
|
||||
raise ValueError("Max tokens + prompt length exceeds maximum model length")
|
||||
prompt_tokens: int = input_len[0][0]
|
||||
requested_tokens = max_tokens + prompt_tokens
|
||||
if requested_tokens > self.MAX_MODEL_LEN:
|
||||
print(1)
|
||||
raise self.TokensExceedsMaximum(
|
||||
f"This model's maximum context length is {self.MAX_MODEL_LEN}, however you requested "
|
||||
f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion). "
|
||||
f"Please reduce your prompt; or completion length."
|
||||
)
|
||||
output_len = np.ones_like(input_len).astype(np.uint32) * max_tokens
|
||||
num_logprobs = data.get('logprobs', -1)
|
||||
if num_logprobs is None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue