mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-08-14 02:28:06 -07:00
feat: Return a 400 if prompt exceeds max tokens
This commit is contained in:
parent
31d2349dbb
commit
e2486698e0
4 changed files with 56 additions and 8 deletions
|
@ -1,12 +1,16 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from config.log_config import uvicorn_logger
|
||||
from models import OpenAIinput
|
||||
from utils.codegen import CodeGenProxy
|
||||
|
||||
# Apply the `uvicorn_logger` dict (imported from config.log_config) as the
# process-wide logging configuration before the proxy and app are created.
# NOTE(review): only `import logging` is visible above; `logging.config` is a
# submodule, so confirm it is imported (directly or transitively) by this point.
logging.config.dictConfig(uvicorn_logger)
|
||||
|
||||
codegen = CodeGenProxy(
|
||||
host=os.environ.get("TRITON_HOST", "triton"),
|
||||
port=os.environ.get("TRITON_PORT", 8001),
|
||||
|
@ -22,21 +26,28 @@ app = FastAPI(
|
|||
)
|
||||
|
||||
|
||||
@app.post("/v1/engines/codegen/completions")
@app.post("/v1/completions")
async def completions(data: OpenAIinput):
    """Handle a completion request on either completions route.

    The validated request model is converted to a plain dict and passed to
    the module-level ``codegen`` proxy exactly once; the result is then
    wrapped in the response type that matches the request's ``stream`` flag.

    Args:
        data: Request body parsed into an ``OpenAIinput`` model.

    Returns:
        An ``EventSourceResponse`` (``text/event-stream``) when the request
        asks for streaming, otherwise a ``Response`` with media type
        ``application/json``. Both use status code 200.

    Raises:
        HTTPException: With status 400 and the error text as ``detail`` when
            ``codegen`` raises ``TokensExceedsMaximum`` (the prompt exceeds
            the maximum number of tokens).
    """
    payload = data.dict()
    print(payload)

    # Call the proxy before choosing a response type so that a rejected
    # prompt becomes a clean 400 instead of failing while the response is
    # being built.
    try:
        content = codegen(data=payload)
    except codegen.TokensExceedsMaximum as E:
        raise HTTPException(status_code=400, detail=str(E)) from E

    # `stream` is a boolean flag: an explicit `false` must take the JSON
    # path, so test truthiness rather than `is not None`.
    # NOTE(review): assumes `codegen` itself treats a falsy `stream` as
    # non-streaming -- confirm against CodeGenProxy.
    if payload.get("stream"):
        return EventSourceResponse(
            content=content,
            status_code=200,
            media_type="text/event-stream",
        )

    return Response(
        status_code=200,
        content=content,
        media_type="application/json",
    )
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue