Happy lint

This commit is contained in:
fdegier 2024-02-07 14:45:48 +01:00
commit eb62cffdcb

View file

@@ -7,7 +7,12 @@ from transformers import GPTJForCausalLM, GPTJConfig
from transformers import CodeGenTokenizer, CodeGenForCausalLM # noqa: F401
from transformers import CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST
CODEGEN_2_LIST = ["Salesforce/codegen2-1B", "Salesforce/codegen2-3_7B", "Salesforce/codegen2-7B", "Salesforce/codegen2-16B"]
CODEGEN_2_LIST = [
"Salesforce/codegen2-1B",
"Salesforce/codegen2-3_7B",
"Salesforce/codegen2-7B",
"Salesforce/codegen2-16B"
]
convertable_models = CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST + CODEGEN_2_LIST
parser = argparse.ArgumentParser('Convert SalesForce CodeGen model to GPT-J')
@@ -19,7 +24,9 @@ parser.add_argument('output_dir', help='where to store the converted model')
args = parser.parse_args()
print('Loading CodeGen model')
cg_model = CodeGenForCausalLM.from_pretrained(args.code_model, torch_dtype="auto", trust_remote_code=bool(args.code_model in CODEGEN_2_LIST))
cg_model = CodeGenForCausalLM.from_pretrained(
args.code_model, torch_dtype="auto", trust_remote_code=bool(args.code_model in CODEGEN_2_LIST)
)
cg_config = cg_model.config
# Create empty GPTJ model