Fix max_position_embeddings default value for llama2 to 4096 #28241 #28754
Changes from 4 commits. Commits in the commit selector: 0dfd39e, 21a08a1, a3fc357, b5a9788, 4792f49, 51adf95.
Diff of `src/transformers/models/llama/convert_llama_weights_to_hf.py`:

```diff
@@ -80,7 +80,9 @@ def write_json(text, path):
         json.dump(text, f)
 
 
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
+def write_model(
+    model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, llama_version=1
+):
     # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
     if not os.path.isfile(os.path.join(input_base_path, "params.json")):
         input_base_path = os.path.join(input_base_path, model_size)
```
```diff
@@ -102,7 +104,16 @@ def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
     if base > 10000.0:
         max_position_embeddings = 16384
     else:
-        max_position_embeddings = 2048
+        # Depending on the Llama version, the default max_position_embeddings has different values.
+        if llama_version == 1:
+            max_position_embeddings = 2048
+        elif llama_version == 2:
+            max_position_embeddings = 4096
+        else:
+            raise NotImplementedError(
+                f"Version {llama_version} of llama is not supported yet. "
+                "Current supported versions of llama are [1, 2]."
+            )
 
     tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
     if tokenizer_path is not None:
```
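Taken out of the surrounding diff, the added selection logic is easy to check in isolation. The sketch below is a minimal, runnable restatement of it; the helper name `default_max_position_embeddings` is hypothetical (in the script this logic sits inline in `write_model`, where `base` is the RoPE theta read from the checkpoint's `params.json`):

```python
# Minimal sketch of the version-dependent default added above.
# `default_max_position_embeddings` is a hypothetical helper; in the script
# this logic is inline in `write_model` and `base` is the RoPE theta.
def default_max_position_embeddings(base: float, llama_version: int = 1) -> int:
    if base > 10000.0:
        # A non-default RoPE theta (e.g. CodeLlama's 1e6) signals a long-context checkpoint.
        return 16384
    # Depending on the Llama version, the default max_position_embeddings differs.
    if llama_version == 1:
        return 2048
    if llama_version == 2:
        return 4096
    raise NotImplementedError(
        f"Version {llama_version} of llama is not supported yet. "
        "Current supported versions of llama are [1, 2]."
    )


# The Llama 2 default requested in #28241:
assert default_max_position_embeddings(10000.0, llama_version=2) == 4096
```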
```diff
@@ -301,6 +312,14 @@ def main():
         help="Location to write HF model and tokenizer",
     )
     parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
+    # Different Llama versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used.
+    parser.add_argument(
+        "--llama_version",
+        choices=[1, 2],
+        default=1,
+        type=int,
+        help="Version of the Llama model to use. Current supported versions are Llama1 and Llama2.",
+    )
     args = parser.parse_args()
     spm_path = os.path.join(args.input_dir, "tokenizer.model")
     if args.model_size != "tokenizer_only":
```

Review thread on `--llama_version` (marked as resolved by karl-hajjar):

> **Reviewer:** let's explain a bit why this is needed!

> **karl-hajjar:** Sorry, I will change back the default from 4096 to 2048 (forgot I had changed this in the first commit) for backwards compatibility. I will also add a comment for the …
```diff
@@ -310,6 +329,7 @@ def main():
             model_size=args.model_size,
             safe_serialization=args.safe_serialization,
             tokenizer_path=spm_path,
+            llama_version=args.llama_version,
         )
     else:
         write_tokenizer(args.output_dir, spm_path)
```
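For completeness, a sketch of the equivalent programmatic call after this change. The import path and the directory names below are assumptions for illustration, not part of the PR:

```python
# Hypothetical programmatic use of the updated entry point; the directory
# paths and the import path are placeholders, not part of the PR.
import os

from convert_llama_weights_to_hf import write_model  # assumes the script is importable

input_dir = "/path/to/llama-2-7b"      # raw Meta checkpoint containing params.json
output_dir = "/path/to/hf-llama-2-7b"  # destination for the HF-format model

write_model(
    model_path=output_dir,
    input_base_path=input_dir,
    model_size="7B",
    safe_serialization=True,
    tokenizer_path=os.path.join(input_dir, "tokenizer.model"),
    llama_version=2,  # new argument: selects the 4096 default for Llama 2
)
```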
Further review comments on this diff:

> this IMO should still be addressed

> +1 This shouldn't be changed here.