Fix max_position_embeddings default value for llama2 to 4096 #28241 #28754

Merged
merged 6 commits on Feb 9, 2024
24 changes: 22 additions & 2 deletions src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -80,7 +80,9 @@ def write_json(text, path):
         json.dump(text, f)


-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(
+    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     if base > 10000.0:
         max_position_embeddings = 16384
     else:
-        max_position_embeddings = 2048
+        # Depending on the Llama version, the default max_position_embeddings has different values.
+        if llama_version == 1:
+            max_position_embeddings = 2048
+        elif llama_version == 2:
+            max_position_embeddings = 4096
+        else:
+            raise NotImplementedError(
+                f"Version {llama_version} of llama is not supported yet. "
+                "Current supported versions of llama are [1, 2]."
+            )

     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
     if tokenizer_path is not None:
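A quick way to sanity-check the new default after converting a Llama 2 checkpoint is to load the config the script writes out (a minimal sketch; the output path is a placeholder for wherever the converted model was saved):

from transformers import LlamaConfig

# Read the config written by the conversion script (placeholder path).
config = LlamaConfig.from_pretrained("/path/to/output/hf-model")
print(config.max_position_embeddings)  # expected: 4096 when converted with --llama_version 2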
@@ -301,6 +312,14 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
+    parser.add_argument(
+        "--llama_version",
+        choices=[1, 2],
+        default=1,
+        type=int,
+        help="Version of the Llama model to use. Current supported versions are Llama1 and Llama2.",
Collaborator (review comment):
let's explain a bit why this is needed!

Contributor Author (review comment):
Sorry I will change back the default from 4096 to 2048 (forgot I had changed this in the first commit) for backwards compatibility. I will also add a comment for the llama_version argument.

karl-hajjar marked this conversation as resolved.
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
@@ -310,6 +329,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            llama_version=args.llama_version,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)
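With this change, converting Llama 2 weights would look something like the following (a sketch; the input and output paths are placeholders):

python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama-2 \
    --model_size 7B \
    --output_dir /path/to/output/hf-model \
    --llama_version 2

Omitting --llama_version (or passing 1) keeps the previous default of 2048, preserving backwards compatibility for Llama 1 conversions.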