mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-08-14 10:37:43 -07:00
Merge branch 'main' into python_backend
Signed-off-by: Parth Thakkar <thakkarparth007@gmail.com>
This commit is contained in:
commit
f0a12b5e8e
10 changed files with 145 additions and 13 deletions
|
@ -1,11 +1,17 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from config.log_config import uvicorn_logger
|
||||
from models import OpenAIinput
|
||||
from utils.codegen import CodeGenProxy
|
||||
from utils.errors import FauxPilotException
|
||||
|
||||
logging.config.dictConfig(uvicorn_logger)
|
||||
|
||||
codegen = CodeGenProxy(
|
||||
host=os.environ.get("TRITON_HOST", "triton"),
|
||||
|
@ -21,22 +27,37 @@ app = FastAPI(
|
|||
swagger_ui_parameters={"defaultModelsExpandDepth": -1}
|
||||
)
|
||||
|
||||
@app.exception_handler(FauxPilotException)
async def fauxpilot_handler(request: Request, exc: FauxPilotException):
    """Translate a FauxPilotException into an OpenAI-style JSON error response (HTTP 400)."""
    error_body = exc.json()
    return JSONResponse(status_code=400, content=error_body)
|
||||
|
||||
@app.post("/v1/engines/codegen/completions")
@app.post("/v1/completions")
async def completions(data: OpenAIinput):
    """Serve an OpenAI-compatible completions request via the codegen proxy.

    Returns an SSE stream when the request asked for streaming, otherwise a
    single JSON response. Raises FauxPilotException (mapped to an HTTP 400 by
    the app-level exception handler) when the prompt plus the requested
    completion length exceeds the model's context window.
    """
    data = data.dict()
    # NOTE(review): debug dump of the full request payload — consider logging instead.
    print(data)
    try:
        content = codegen(data=data)
    except codegen.TokensExceedsMaximum as E:
        # Re-raise the backend's length error in OpenAI's error envelope.
        raise FauxPilotException(
            message=str(E),
            type="invalid_request_error",
            param=None,
            code=None,
        )

    if data.get("stream") is not None:
        # NOTE(review): an explicit stream=False still takes this branch; a
        # truthiness check may be intended — confirm against client behavior.
        return EventSourceResponse(
            content=content,
            status_code=200,
            media_type="text/event-stream"
        )
    else:
        return Response(
            status_code=200,
            content=content,
            media_type="application/json"
        )
|
||||
|
||||
|
|
0
copilot_proxy/config/__init__.py
Normal file
0
copilot_proxy/config/__init__.py
Normal file
27
copilot_proxy/config/log_config.py
Normal file
27
copilot_proxy/config/log_config.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Logging configuration consumed by logging.config.dictConfig: adds a
# timestamp to uvicorn's access-log lines via uvicorn's AccessFormatter.

_access_formatter = {
    "()": "uvicorn.logging.AccessFormatter",
    "fmt": '%(levelprefix)s %(asctime)s :: %(client_addr)s - "%(request_line)s" %(status_code)s',
    "use_colors": True
}

_access_handler = {
    "formatter": "access",
    "class": "logging.StreamHandler",
    "stream": "ext://sys.stdout",
}

_access_logger = {
    "handlers": ["access"],
    # "level": "INFO",
    "propagate": False
}

uvicorn_logger = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"access": _access_formatter},
    "handlers": {"access": _access_handler},
    "loggers": {"uvicorn.access": _access_logger},
}
|
|
@ -21,6 +21,9 @@ class CodeGenProxy:
|
|||
# Max number of tokens the model can handle
|
||||
self.MAX_MODEL_LEN = 2048
|
||||
|
||||
class TokensExceedsMaximum(Exception):
    """Raised when prompt tokens plus requested completion tokens exceed MAX_MODEL_LEN."""
    pass
|
||||
|
||||
@staticmethod
|
||||
def prepare_tensor(name: str, tensor_input):
|
||||
t = client_util.InferInput(
|
||||
|
@ -82,8 +85,15 @@ class CodeGenProxy:
|
|||
prompt_len = input_start_ids.shape[1]
|
||||
input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np_type)
|
||||
max_tokens = data.get('max_tokens', 16)
|
||||
if max_tokens + input_len[0][0] > self.MAX_MODEL_LEN:
|
||||
raise ValueError("Max tokens + prompt length exceeds maximum model length")
|
||||
prompt_tokens: int = input_len[0][0]
|
||||
requested_tokens = max_tokens + prompt_tokens
|
||||
if requested_tokens > self.MAX_MODEL_LEN:
|
||||
print(1)
|
||||
raise self.TokensExceedsMaximum(
|
||||
f"This model's maximum context length is {self.MAX_MODEL_LEN}, however you requested "
|
||||
f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion). "
|
||||
f"Please reduce your prompt; or completion length."
|
||||
)
|
||||
output_len = np.ones_like(input_len).astype(np_type) * max_tokens
|
||||
num_logprobs = data.get('logprobs', -1)
|
||||
if num_logprobs is None:
|
||||
|
|
19
copilot_proxy/utils/errors.py
Normal file
19
copilot_proxy/utils/errors.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Explicit import instead of the original wildcard `from typing import *`
# (PEP 8: never use star imports in code).
from typing import Optional


class FauxPilotException(Exception):
    """Application-level error carrying the fields of an OpenAI-style error body.

    The FastAPI exception handler serializes instances via ``json()`` into the
    ``{"error": {...}}`` envelope that OpenAI clients expect.
    """

    def __init__(self, message: str, type: Optional[str] = None, param: Optional[str] = None,
                 code: Optional[int] = None):
        """Store the error fields; ``message`` also becomes the exception text.

        Args:
            message: Human-readable description of the error.
            type: OpenAI-style error type (e.g. "invalid_request_error").
            param: Name of the offending request parameter, if any.
            code: Numeric error code, if any.
        """
        super().__init__(message)
        self.message = message
        self.type = type
        self.param = param
        self.code = code

    def json(self):
        """Return the error as an OpenAI-style ``{"error": {...}}`` dict."""
        return {
            'error': {
                'message': self.message,
                'type': self.type,
                'param': self.param,
                'code': self.code
            }
        }
|
Loading…
Add table
Add a link
Reference in a new issue