Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix max_position_embeddings default value for llama2 to 4096 #28241 #28754

Merged
merged 6 commits into from
Feb 9, 2024
Merged
24 changes: 22 additions & 2 deletions src/transformers/models/llama/convert_llama_weights_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def write_json(text, path):
json.dump(text, f)


def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
def write_model(
model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
):
# for backward compatibility, before you needed the repo to be called `my_repo/model_size`
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
input_base_path = os.path.join(input_base_path, model_size)
Expand All @@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
if base > 10000.0:
max_position_embeddings = 16384
else:
max_position_embeddings = 2048
# Depending on the Llama version, the default max_position_embeddings has different values.
if llama_version == 1:
max_position_embeddings = 2048
elif llama_version == 2:
max_position_embeddings = 4096
else:
raise NotImplementedError(
f"Version {llama_version} of llama is not supported yet. "
"Current supported versions of llama are [1, 2]."
)

tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
if tokenizer_path is not None:
Expand Down Expand Up @@ -301,6 +312,14 @@ def main():
help="Location to write HF model and tokenizer",
)
parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
# Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
parser.add_argument(
"--llama_version",
choices=[1, 2],
default=1,
type=int,
help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
)
args = parser.parse_args()
spm_path = os.path.join(args.input_dir, "tokenizer.model")
if args.model_size != "tokenizer_only":
Expand All @@ -310,6 +329,7 @@ def main():
model_size=args.model_size,
safe_serialization=args.safe_serialization,
tokenizer_path=spm_path,
llama_version=args.llama_version,
)
else:
write_tokenizer(args.output_dir, spm_path)
Expand Down
Loading