mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-07-11 07:36:49 -07:00
feat: Return a 400 if prompt exceeds max tokens
This commit is contained in:
parent
31d2349dbb
commit
e2486698e0
4 changed files with 56 additions and 8 deletions
|
@ -21,6 +21,9 @@ class CodeGenProxy:
|
|||
# Max number of tokens the model can handle
|
||||
self.MAX_MODEL_LEN = 2048
|
||||
|
||||
class TokensExceedsMaximum(Exception):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def prepare_tensor(name: str, tensor_input):
|
||||
t = client_util.InferInput(
|
||||
|
@ -78,8 +81,15 @@ class CodeGenProxy:
|
|||
prompt_len = input_start_ids.shape[1]
|
||||
input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np.uint32)
|
||||
max_tokens = data.get('max_tokens', 16)
|
||||
if max_tokens + input_len[0][0] > self.MAX_MODEL_LEN:
|
||||
raise ValueError("Max tokens + prompt length exceeds maximum model length")
|
||||
prompt_tokens: int = input_len[0][0]
|
||||
requested_tokens = max_tokens + prompt_tokens
|
||||
if requested_tokens > self.MAX_MODEL_LEN:
|
||||
print(1)
|
||||
raise self.TokensExceedsMaximum(
|
||||
f"This model's maximum context length is {self.MAX_MODEL_LEN}, however you requested "
|
||||
f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion). "
|
||||
f"Please reduce your prompt; or completion length."
|
||||
)
|
||||
output_len = np.ones_like(input_len).astype(np.uint32) * max_tokens
|
||||
num_logprobs = data.get('logprobs', -1)
|
||||
if num_logprobs is None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue