feat: Return a 400 if prompt exceeds max tokens

Fred de Gier 2022-10-20 14:56:39 +02:00
commit e2486698e0
4 changed files with 56 additions and 8 deletions


@@ -1,12 +1,16 @@
 import logging
 import os
 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Response, HTTPException
 from sse_starlette.sse import EventSourceResponse
 from config.log_config import uvicorn_logger
 from models import OpenAIinput
 from utils.codegen import CodeGenProxy
 logging.config.dictConfig(uvicorn_logger)
 codegen = CodeGenProxy(
     host=os.environ.get("TRITON_HOST", "triton"),
     port=os.environ.get("TRITON_PORT", 8001),
@@ -22,21 +26,28 @@ app = FastAPI(
 )
-@app.post("/v1/engines/codegen/completions", status_code=200)
-@app.post("/v1/completions", status_code=200)
+@app.post("/v1/engines/codegen/completions")
+@app.post("/v1/completions")
 async def completions(data: OpenAIinput):
     data = data.dict()
     print(data)
+    try:
+        content = codegen(data=data)
+    except codegen.TokensExceedsMaximum as E:
+        raise HTTPException(
+            status_code=400,
+            detail=str(E)
+        )
     if data.get("stream") is not None:
         return EventSourceResponse(
-            content=codegen(data=data),
+            content=content,
             status_code=200,
             media_type="text/event-stream"
         )
     else:
         return Response(
             status_code=200,
-            content=codegen(data=data),
+            content=content,
             media_type="application/json"
         )
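
The handler above relies on CodeGenProxy raising a TokensExceedsMaximum exception when the prompt plus the requested completion length would not fit in the model's context window; the corresponding change to utils/codegen.py is not part of the file shown here. Below is a minimal sketch of how such a check might look, assuming a nested exception class, an illustrative 2048-token limit, and a stand-in token count (none of these details are taken from this commit):

class CodeGenProxy:
    # Hypothetical sketch only: exposing the exception on the proxy lets app.py
    # catch it as codegen.TokensExceedsMaximum and map it to an HTTP 400.
    class TokensExceedsMaximum(Exception):
        pass

    MAX_MODEL_LEN = 2048  # assumed context window; illustrative only

    def __call__(self, data: dict) -> str:
        # Assumption: whitespace splitting stands in for real tokenization.
        prompt_tokens = len(str(data.get("prompt", "")).split())
        requested = int(data.get("max_tokens", 16))
        if prompt_tokens + requested > self.MAX_MODEL_LEN:
            raise self.TokensExceedsMaximum(
                f"A maximum of {self.MAX_MODEL_LEN} tokens is supported, but "
                f"{prompt_tokens} prompt tokens plus {requested} completion "
                f"tokens were requested."
            )
        # A real implementation would forward the request to Triton here.
        return "{}"

With a check along these lines, an oversized request surfaces to the client as a 400 response whose JSON body carries the exception message in the "detail" field (FastAPI's default serialization of HTTPException), instead of an unhandled server error.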