mirror of
https://github.com/fauxpilot/fauxpilot.git
synced 2025-08-21 13:54:29 -07:00
Happy lint
This commit is contained in:
parent
acc6746114
commit
eb62cffdcb
1 changed files with 9 additions and 2 deletions
|
@ -7,7 +7,12 @@ from transformers import GPTJForCausalLM, GPTJConfig
|
||||||
from transformers import CodeGenTokenizer, CodeGenForCausalLM # noqa: F401
|
from transformers import CodeGenTokenizer, CodeGenForCausalLM # noqa: F401
|
||||||
from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
|
||||||
CODEGEN_2_LIST = ["Salesforce/codegen2-1B", "Salesforce/codegen2-3_7B", "Salesforce/codegen2-7B", "Salesforce/codegen2-16B"]
|
CODEGEN_2_LIST = [
|
||||||
|
"Salesforce/codegen2-1B",
|
||||||
|
"Salesforce/codegen2-3_7B",
|
||||||
|
"Salesforce/codegen2-7B",
|
||||||
|
"Salesforce/codegen2-16B"
|
||||||
|
]
|
||||||
convertable_models = CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST + CODEGEN_2_LIST
|
convertable_models = CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST + CODEGEN_2_LIST
|
||||||
|
|
||||||
parser = argparse.ArgumentParser('Convert SalesForce CodeGen model to GPT-J')
|
parser = argparse.ArgumentParser('Convert SalesForce CodeGen model to GPT-J')
|
||||||
|
@ -19,7 +24,9 @@ parser.add_argument('output_dir', help='where to store the converted model')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print('Loading CodeGen model')
|
print('Loading CodeGen model')
|
||||||
cg_model = CodeGenForCausalLM.from_pretrained(args.code_model, torch_dtype="auto", trust_remote_code=bool(args.code_model in CODEGEN_2_LIST))
|
cg_model = CodeGenForCausalLM.from_pretrained(
|
||||||
|
args.code_model, torch_dtype="auto", trust_remote_code=bool(args.code_model in CODEGEN_2_LIST)
|
||||||
|
)
|
||||||
cg_config = cg_model.config
|
cg_config = cg_model.config
|
||||||
|
|
||||||
# Create empty GPTJ model
|
# Create empty GPTJ model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue