mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-08-21 13:54:29 -07:00
Happy lint
This commit is contained in:
parent
acc6746114
commit
eb62cffdcb
1 changed file with 9 additions and 2 deletions
|
@ -7,7 +7,12 @@ from transformers import GPTJForCausalLM, GPTJConfig
|
|||
from transformers import CodeGenTokenizer, CodeGenForCausalLM # noqa: F401
|
||||
from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
CODEGEN_2_LIST = ["Salesforce/codegen2-1B", "Salesforce/codegen2-3_7B", "Salesforce/codegen2-7B", "Salesforce/codegen2-16B"]
|
||||
CODEGEN_2_LIST = [
|
||||
"Salesforce/codegen2-1B",
|
||||
"Salesforce/codegen2-3_7B",
|
||||
"Salesforce/codegen2-7B",
|
||||
"Salesforce/codegen2-16B"
|
||||
]
|
||||
convertable_models = CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST + CODEGEN_2_LIST
|
||||
|
||||
parser = argparse.ArgumentParser('Convert SalesForce CodeGen model to GPT-J')
|
||||
|
@ -19,7 +24,9 @@ parser.add_argument('output_dir', help='where to store the converted model')
|
|||
args = parser.parse_args()
|
||||
|
||||
print('Loading CodeGen model')
|
||||
cg_model = CodeGenForCausalLM.from_pretrained(args.code_model, torch_dtype="auto", trust_remote_code=bool(args.code_model in CODEGEN_2_LIST))
|
||||
cg_model = CodeGenForCausalLM.from_pretrained(
|
||||
args.code_model, torch_dtype="auto", trust_remote_code=bool(args.code_model in CODEGEN_2_LIST)
|
||||
)
|
||||
cg_config = cg_model.config
|
||||
|
||||
# Create empty GPTJ model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue