diff --git a/src/transformers/models/llama/convert_llama_weights_to_hf.py b/src/transformers/models/llama/convert_llama_weights_to_hf.py
index d2fc3a79aff1b4..f9bca1204a22ec 100644
--- a/src/transformers/models/llama/convert_llama_weights_to_hf.py
+++ b/src/transformers/models/llama/convert_llama_weights_to_hf.py
@@ -80,7 +80,9 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(
+    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
@@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, sa
     if base > 10000.0:
         max_position_embeddings = 16384
     else:
-        max_position_embeddings = 2048
+        # Depending on the Llama version, the default max_position_embeddings has different values.
+        if llama_version == 1:
+            max_position_embeddings = 2048
+        elif llama_version == 2:
+            max_position_embeddings = 4096
+        else:
+            raise NotImplementedError(
+                f"Version {llama_version} of llama is not supported yet. "
+                "Current supported versions of llama are [1, 2]."
+            )
 
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
     if tokenizer_path is not None:
@@ -301,6 +312,14 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
+    parser.add_argument(
+        "--llama_version",
+        choices=[1, 2],
+        default=1,
+        type=int,
+        help="Version of the Llama model to convert. Currently supports Llama1 and Llama2. Controls the context size",
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
@@ -310,6 +329,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            llama_version=args.llama_version,
        )
     else:
         write_tokenizer(args.output_dir, spm_path)
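
For reference, a minimal sketch of how the updated converter might be driven directly for a Llama 2 checkpoint. The function name and keyword arguments come from the patch above; every path below is a hypothetical placeholder, and only `llama_version=2` exercises the new behavior (selecting the Llama 2 default of `max_position_embeddings = 4096`):

    from transformers.models.llama.convert_llama_weights_to_hf import write_model

    # All paths here are placeholders for illustration only.
    write_model(
        model_path="llama-2-7b-hf",             # where the HF-format model is written
        input_base_path="/path/to/llama-2-7b",  # directory containing params.json and shards
        model_size="7B",
        tokenizer_path="/path/to/tokenizer.model",
        safe_serialization=True,
        llama_version=2,  # new argument added by this patch; defaults to 1
    )

The same effect should be achievable through the CLI entry point via the new `--llama_version 2` flag, which `main()` forwards to `write_model` as shown in the last hunk.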