feat: Return a 400 if prompt exceeds max tokens

Fred de Gier 2022-10-20 14:56:39 +02:00
parent 31d2349dbb
commit e2486698e0
4 changed files with 56 additions and 8 deletions


@@ -21,6 +21,9 @@ class CodeGenProxy:
         # Max number of tokens the model can handle
         self.MAX_MODEL_LEN = 2048
 
+    class TokensExceedsMaximum(Exception):
+        pass
+
     @staticmethod
     def prepare_tensor(name: str, tensor_input):
         t = client_util.InferInput(
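
The exception is raised inside CodeGenProxy; the HTTP 400 promised by the commit title is presumably produced in the proxy's web layer, among the other changed files not shown here. A minimal sketch of that mapping, assuming a FastAPI app — the route path and the names `app` and `codegen` are illustrative assumptions, not the commit's exact code:

# Sketch only: FastAPI wiring is assumed; `codegen` is a CodeGenProxy
# instance and the route path is illustrative.
from fastapi import FastAPI, HTTPException, Request

app = FastAPI()
codegen = CodeGenProxy()  # assumed construction; real config omitted

@app.post("/v1/engines/codegen/completions")
async def completions(request: Request):
    data = await request.json()
    try:
        # Assumes the proxy is invoked with the parsed request body.
        return codegen(data=data)
    except CodeGenProxy.TokensExceedsMaximum as exc:
        # Translate the token-budget violation into a client error.
        raise HTTPException(status_code=400, detail=str(exc))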
@@ -78,8 +81,14 @@ class CodeGenProxy:
         prompt_len = input_start_ids.shape[1]
         input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np.uint32)
         max_tokens = data.get('max_tokens', 16)
-        if max_tokens + input_len[0][0] > self.MAX_MODEL_LEN:
-            raise ValueError("Max tokens + prompt length exceeds maximum model length")
+        prompt_tokens: int = input_len[0][0]
+        requested_tokens = max_tokens + prompt_tokens
+        if requested_tokens > self.MAX_MODEL_LEN:
+            raise self.TokensExceedsMaximum(
+                f"This model's maximum context length is {self.MAX_MODEL_LEN}, however you requested "
+                f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion). "
+                f"Please reduce your prompt or completion length."
+            )
         output_len = np.ones_like(input_len).astype(np.uint32) * max_tokens
         num_logprobs = data.get('logprobs', -1)
         if num_logprobs is None:
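
The guard itself is plain arithmetic over the tokenized prompt. A self-contained illustration of the same budget check, with a hypothetical 2048-token context window and made-up token counts:

# Standalone illustration of the token-budget check; numbers are hypothetical.
MAX_MODEL_LEN = 2048

def check_budget(prompt_tokens: int, max_tokens: int) -> None:
    requested_tokens = prompt_tokens + max_tokens
    if requested_tokens > MAX_MODEL_LEN:
        raise ValueError(
            f"Requested {requested_tokens} tokens ({prompt_tokens} prompt; "
            f"{max_tokens} completion) but the maximum context length is {MAX_MODEL_LEN}."
        )

check_budget(prompt_tokens=1900, max_tokens=100)  # fine: 2000 <= 2048
check_budget(prompt_tokens=1900, max_tokens=200)  # raises: 2100 > 2048

One subtlety in the diff above: despite the `prompt_tokens: int` annotation, `input_len[0][0]` is a NumPy uint32 scalar, so `requested_tokens` is a NumPy integer as well; the comparison against the Python int `self.MAX_MODEL_LEN` still behaves as expected.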