Commit
Fix max_position_embeddings default value for llama2 to 4096 huggingface#28241 (huggingface#28754)

* Changed max_position_embeddings default value from 2048 to 4096

* force push

* Fixed formatting issues. Fixed missing argument in write_model.

* Reverted to the default value 2048 in the Llama config. Added comments for the llama_version argument.

* Fixed issue with default value of max_position_embeddings in docstring

* Updated help message for llama versions

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2 people authored and jon-tow committed Feb 12, 2024
1 parent e023dc5 commit 36247bd
Showing 1 changed file with 22 additions and 2 deletions: src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -80,7 +80,9 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(
+    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     if base > 10000.0:
         max_position_embeddings = 16384
     else:
-        max_position_embeddings = 2048
+        # Depending on the Llama version, the default max_position_embeddings has different values.
+        if llama_version == 1:
+            max_position_embeddings = 2048
+        elif llama_version == 2:
+            max_position_embeddings = 4096
+        else:
+            raise NotImplementedError(
+                f"Version {llama_version} of llama is not supported yet. "
+                "Current supported versions of llama are [1, 2]."
+            )
 
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
     if tokenizer_path is not None:
@@ -301,6 +312,14 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
+    parser.add_argument(
+        "--llama_version",
+        choices=[1, 2],
+        default=1,
+        type=int,
+        help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
@@ -310,6 +329,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            llama_version=args.llama_version,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)
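For context, a minimal usage sketch of the converter after this change, calling write_model directly with the new llama_version argument so that a Llama 2 checkpoint gets the 4096-token default for max_position_embeddings. The checkpoint paths and model size below are hypothetical placeholders, not values taken from this commit.

# Hypothetical example: convert a Llama 2 checkpoint with the new llama_version argument.
# Paths and model size are placeholders; input_base_path should contain params.json
# (or a model_size subfolder with it), as the script checks above.
from transformers.models.llama.convert_llama_weights_to_hf import write_model

write_model(
    model_path="llama-2-7b-hf",                   # output directory for the converted HF model
    input_base_path="Llama-2-7b",                 # directory with the original Meta checkpoint
    model_size="7B",
    tokenizer_path="Llama-2-7b/tokenizer.model",
    safe_serialization=True,
    llama_version=2,                              # selects the 4096 default instead of 2048
)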
