mirror of https://github.com/fauxpilot/fauxpilot.git
Support newer upstream Triton
The main change is that the model config format has changed. To deal with this, we have a new script in the converter that will upgrade a model to the new version. Aside from that, we also no longer need to maintain our own fork of Triton, since they have fixed the bug with GPT-J models. This should make it a lot easier to stay synced with upstream (although we still have to build our own container, since there doesn't seem to be a prebuilt Triton+FT container hosted by NVIDIA).

Newer Triton should let us use some nice features:
- Support for more models, like GPT-NeoX
- Streaming token support (this still needs to be implemented in the proxy, though)
- Dynamic batching

Still TODO:
- Proxy support for streaming tokens
- Add stuff to setup.sh and launch.sh to detect if a model upgrade is needed and do it automatically (one possible check is sketched just below).
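For that last TODO item, the upgrade script added below writes a .version marker file next to the model's config.pbtxt, which setup.sh/launch.sh could use for the detection. A minimal sketch, assuming that marker convention; needs_upgrade and TARGET_VERSION are illustrative names, not part of this commit:

#!/usr/bin/env python
# Hypothetical helper (not in this commit): report whether a converted model
# still uses the old config format, based on the .version marker written by
# converter/upgrade_model_config.py.
import os
import sys

TARGET_VERSION = (1, 3)  # config format version the upgrade script writes

def needs_upgrade(model_dir: str) -> bool:
    version_path = os.path.join(model_dir, 'fastertransformer', '.version')
    if not os.path.exists(version_path):
        # No marker: the model predates the upgrade script, so it is still on
        # the old (v1.1) config format.
        return True
    with open(version_path) as f:
        version = tuple(int(x) for x in f.read().strip().split('.'))
    return version < TARGET_VERSION

if __name__ == '__main__':
    print('upgrade needed' if needs_upgrade(sys.argv[1]) else 'up to date')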
parent 4441e5e16b
commit 02f7887f17
6 changed files with 3581 additions and 11 deletions
@@ -1,4 +1,4 @@
-FROM moyix/triton_with_ft:22.09
+FROM moyix/triton_with_ft:23.01

# Install dependencies: torch
RUN python3 -m pip install --disable-pip-version-check -U torch --extra-index-url https://download.pytorch.org/whl/cu116
@@ -9,7 +9,7 @@ from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST

parser = argparse.ArgumentParser('Convert SalesForce CodeGen model to GPT-J')
parser.add_argument('--code_model',
-                    choices=CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST, default='Salesforce/codegen-350M-multi',
+                    default='Salesforce/codegen-350M-multi',
                    help='which SalesForce model to convert'
)
parser.add_argument('output_dir', help='where to store the converted model')
@@ -60,7 +60,7 @@ print('Converting...')
with torch.no_grad():
    cg_model.eval()
    gptj_model.eval()

    for name, param in cg_model.named_parameters():
        # print(f'Converting {name}')
        # Handle the qkv weights separately because we need to split them
@@ -86,4 +86,4 @@ with torch.no_grad():

print('Conversion complete.')
print(f"Saving model to {args.output_dir}...")
gptj_model.save_pretrained(args.output_dir)
converter/model_config_pb2.py (new file, 3376 additions)
File diff suppressed because one or more lines are too long
converter/upgrade_model_config.py (new file, 199 additions)
@@ -0,0 +1,199 @@
#!/usr/bin/env python

import argparse
from distutils.version import StrictVersion
from google.protobuf import text_format
from model_config_pb2 import ModelConfig, TYPE_UINT64, TYPE_UINT32, TYPE_FP32
from configparser import ConfigParser
import sys, os

# Upgrade model config file from v1.1 to v1.3

parser = argparse.ArgumentParser('Upgrade model config file from v1.1 to v1.3')
parser.add_argument('model_dir', help='Path to the input model')
args = parser.parse_args()

# Make this an absolute path
model_dir = os.path.realpath(args.model_dir)

# Path to the protobuf-text config file
config_path = os.path.join(model_dir, 'fastertransformer', 'config.pbtxt')

# Check for existing backup files and bail if so
old_version = '1.1'
new_version = '1.3'
# Check for a .version file. If it exists, we've already upgraded this model.
# This will also let us use the .version file for future upgrades.
version_path = os.path.join(model_dir, 'fastertransformer', '.version')
if os.path.exists(version_path):
    with open(version_path, 'r') as f:
        old_version = f.read().strip()
    if StrictVersion(old_version) >= StrictVersion(new_version):
        print(f'INFO: model already upgraded to version {old_version}; nothing to do',
              file=sys.stderr)
        sys.exit(0)

backup_ext = f'.bk_{new_version}'
if os.path.exists(config_path+backup_ext):
    print(f'INFO: backup {config_path+backup_ext} already exists; did you already run this script?',
          file=sys.stderr)
    sys.exit(1)

# Read the old config
with open(config_path, 'r') as f:
    config = ModelConfig()
    text_format.Parse(f.read(), config)

# Only support GPT-J for now; we don't have any other model types
if config.parameters['model_type'].string_value != 'GPT-J':
    print(f'ERROR: only GPT-J models are supported for now', file=sys.stderr)
    sys.exit(1)

# Build up the new config
new_config = ModelConfig()
new_config.name = config.name
new_config.backend = config.backend
new_config.default_model_filename = config.default_model_filename
new_config.max_batch_size = config.max_batch_size

# New: model_transaction_policy controls whether to stream tokens. Default
# to false to preserve current behavior.
new_config.model_transaction_policy.decoupled = False

# Inputs
common_inputs = set([
    'input_ids', 'start_id', 'end_id', 'input_lengths', 'request_output_len',
    'runtime_top_k', 'runtime_top_p', 'beam_search_diversity_rate', 'temperature',
    'len_penalty', 'repetition_penalty', 'random_seed', 'is_return_log_probs',
    'beam_width', 'bad_words_list', 'stop_words_list'
])
for input in config.input:
    if input.name not in common_inputs: continue
    new_input = new_config.input.add()
    new_input.CopyFrom(input)
    # Random seed dtype changed from int32 to uint64
    if input.name == 'random_seed':
        new_input.data_type = TYPE_UINT64

# New inputs
# {
#   name: "prompt_learning_task_name_ids"
#   data_type: TYPE_UINT32
#   dims: [ 1 ]
#   reshape: { shape: [ ] }
#   optional: true
# }
new_input = new_config.input.add()
new_input.name = 'prompt_learning_task_name_ids'
new_input.data_type = TYPE_UINT32
new_input.dims.extend([1])
new_input.reshape.shape.extend([])
new_input.optional = True
# {
#   name: "top_p_decay"
#   data_type: TYPE_FP32
#   dims: [ 1 ]
#   reshape: { shape: [ ] }
#   optional: true
# }
new_input = new_config.input.add()
new_input.name = 'top_p_decay'
new_input.data_type = TYPE_FP32
new_input.dims.extend([1])
new_input.reshape.shape.extend([])
new_input.optional = True
# {
#   name: "top_p_min"
#   data_type: TYPE_FP32
#   dims: [ 1 ]
#   reshape: { shape: [ ] }
#   optional: true
# }
new_input = new_config.input.add()
new_input.name = 'top_p_min'
new_input.data_type = TYPE_FP32
new_input.dims.extend([1])
new_input.reshape.shape.extend([])
new_input.optional = True
# {
#   name: "top_p_reset_ids"
#   data_type: TYPE_UINT32
#   dims: [ 1 ]
#   reshape: { shape: [ ] }
#   optional: true
# }
new_input = new_config.input.add()
new_input.name = 'top_p_reset_ids'
new_input.data_type = TYPE_UINT32
new_input.dims.extend([1])
new_input.reshape.shape.extend([])
new_input.optional = True

# Outputs: these are all unchanged
new_config.output.extend(config.output)
# Instance group also unchanged
new_config.instance_group.extend(config.instance_group)

common_parameters = set([
    'tensor_para_size', 'pipeline_para_size', 'model_type',
    'model_checkpoint_path', 'enable_custom_all_reduce',
])
for parameter in config.parameters:
    if parameter not in common_parameters: continue
    new_config.parameters[parameter].string_value = config.parameters[parameter].string_value

# New parameters
new_config.parameters['data_type'].string_value = (
    'fp32' if config.parameters['is_half'].string_value == '0' else 'fp16'
)

# These parameters moved to config.ini in the weights directory
config_ini_params = {
    'model_name': 'model_name',
    'head_num': 'head_num',
    'size_per_head': 'size_per_head',
    'inter_size': 'inter_size',
    'decoder_layers': 'num_layer',
    'rotary_embedding': 'rotary_embedding',
    'vocab_size': 'vocab_size',
    'start_id': 'start_id',
    'end_id': 'end_id',
}
config_ini = ConfigParser()
config_ini.add_section('gptj')
for param in config_ini_params:
    config_ini['gptj'][config_ini_params[param]] = config.parameters[param].string_value
config_ini['gptj']['weight_data_type'] = 'fp32'

weights_dir = config.parameters['model_checkpoint_path'].string_value
# The weights dir in the config file may be remapped, e.g.
# /fastdata/mymodels/codegen-6B-mono-1gpu/fastertransformer/1/1-gpu
# -> /model/fastertransformer/1/1-gpu
# Undo this remapping so we can find the config.ini file
# Find the 'fastertransformer' component of the path
orig_index = config_path.split(os.path.sep).index('fastertransformer')
# Find the 'fastertransformer' component of the weights dir
weights_index = weights_dir.split(os.path.sep).index('fastertransformer')
real_weights_dir = os.path.sep.join(
    config_path.split(os.path.sep)[:orig_index] +
    weights_dir.split(os.path.sep)[weights_index:]
)
config_ini_path = os.path.join(real_weights_dir, 'config.ini')

# Make backup copies of config.ini and config.pbtxt
os.rename(config_path, config_path + backup_ext)
if os.path.exists(config_ini_path):
    os.rename(config_ini_path, config_ini_path + backup_ext)

# Write out the new config files
with open(config_path, 'w') as f:
    f.write(text_format.MessageToString(new_config))
with open(config_ini_path, 'w') as f:
    config_ini.write(f)

# Write out the new version
with open(version_path, 'w') as f:
    print(new_version, file=f)

print(f'INFO: Successfully upgraded {model_dir} from {old_version} to {new_version}',
      file=sys.stderr)
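Not part of the commit, but one way to sanity-check a model after running the upgrade script above is to parse the rewritten config.pbtxt back with the same protobuf bindings the script itself uses. A minimal sketch, assuming it is run from the converter directory so that the generated model_config_pb2 module is importable:

# Hypothetical post-upgrade check, not in this commit.
import sys
from google.protobuf import text_format
from model_config_pb2 import ModelConfig

config = ModelConfig()
with open(sys.argv[1]) as f:  # path to the upgraded config.pbtxt
    text_format.Parse(f.read(), config)

# The upgrade adds the streaming knob, several new optional inputs, and the
# data_type parameter derived from the old is_half flag.
print('decoupled:', config.model_transaction_policy.decoupled)
print('inputs:', [i.name for i in config.input])
print('data_type:', config.parameters['data_type'].string_value)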
@@ -2,5 +2,5 @@ fastapi==0.82.0
numpy==1.23.2
sse-starlette==1.1.6
tokenizers==0.12.1
-tritonclient[all]==2.25.0
+tritonclient[all]==2.29.0
uvicorn==0.18.3
@@ -8,10 +8,6 @@ import tritonclient.grpc as client_util
from tokenizers import Tokenizer
from tritonclient.utils import np_to_triton_dtype, InferenceServerException

-np.finfo(np.dtype("float32"))
-np.finfo(np.dtype("float64"))
-
-
class CodeGenProxy:
    def __init__(self, host: str = 'triton', port: int = 8001, verbose: bool = False):
        self.tokenizer = Tokenizer.from_file('/python-docker/cgtok/tokenizer.json')
@@ -88,7 +84,6 @@ class CodeGenProxy:
        prompt_tokens: int = input_len[0][0]
        requested_tokens = max_tokens + prompt_tokens
        if requested_tokens > self.MAX_MODEL_LEN:
-            print(1)
            raise self.TokensExceedsMaximum(
                f"This model's maximum context length is {self.MAX_MODEL_LEN}, however you requested "
                f"{requested_tokens} tokens ({prompt_tokens} in your prompt; {max_tokens} for the completion). "
@@ -112,7 +107,7 @@ class CodeGenProxy:
        runtime_top_k = top_k * np.ones([input_start_ids.shape[0], 1]).astype(np_type)
        runtime_top_p = top_p * np.ones([input_start_ids.shape[0], 1]).astype(np.float32)
        beam_search_diversity_rate = 0.0 * np.ones([input_start_ids.shape[0], 1]).astype(np.float32)
-        random_seed = np.random.randint(0, 2 ** 31 - 1, (input_start_ids.shape[0], 1), dtype=np.int32)
+        random_seed = np.random.randint(0, 2 ** 31 - 1, (input_start_ids.shape[0], 1), dtype=np.uint64)
        temperature = temperature * np.ones([input_start_ids.shape[0], 1]).astype(np.float32)
        len_penalty = 1.0 * np.ones([input_start_ids.shape[0], 1]).astype(np.float32)
        repetition_penalty = frequency_penalty * np.ones([input_start_ids.shape[0], 1]).astype(np.float32)
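One detail in the hunk above is the random_seed dtype moving from int32 to uint64, which matches the TYPE_UINT64 input the upgraded config declares. A standalone illustration (not the proxy's code) of the seed tensor and the Triton dtype it maps to:

# Standalone illustration, not proxy code: build a per-request seed tensor the
# way the proxy now does, and check the Triton wire dtype it maps to.
import numpy as np
from tritonclient.utils import np_to_triton_dtype

batch_size = 1
random_seed = np.random.randint(0, 2 ** 31 - 1, (batch_size, 1), dtype=np.uint64)
print(random_seed.shape, np_to_triton_dtype(random_seed.dtype))  # (1, 1) UINT64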