[BugFix] Set the right max_sequence_length for both Llama-1 and Llama-2 families (mlc-ai#1032)

* fix

* reflect feedback

---------

Co-authored-by: Sunghyun <sunggg@umich.com>
sunggg and Sunghyun authored Oct 13, 2023
1 parent bfaa5b9 commit ca8c11b
Showing 1 changed file with 29 additions and 13 deletions.
mlc_llm/relax_model/llama.py (42 changes: 29 additions & 13 deletions)
@@ -817,26 +817,42 @@ def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
 def get_model(args, hf_config):
     model_name = args.model
     dtype = args.quantization.model_dtype
-    max_seq_len = args.max_seq_len
     sep_embed = args.sep_embed
 
     position_embedding_base = 10000
     max_position_embeddings = 2048
     if "rope_theta" in hf_config:
         position_embedding_base = hf_config["rope_theta"]
     if "max_position_embeddings" in hf_config:
         max_position_embeddings = hf_config["max_position_embeddings"]
 
-    config = LlamaConfig(
-        **hf_config,
-        dtype=dtype,
-        position_embedding_base=position_embedding_base,
-        combine_matmul=True,
-        num_shards=args.num_shards,
-        build_model_only=args.build_model_only,
-    )
-    if max_seq_len != -1:
-        config.max_sequence_length = max_seq_len
+    # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards,
+    # while Llama-1 variants use `max_sequence_length`.
+    # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`.
+    # If none of them is defined, throw an error.
+    if "max_sequence_length" in hf_config:
+        config = LlamaConfig(
+            **hf_config,
+            dtype=dtype,
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
+    elif "max_position_embeddings" in hf_config:
+        config = LlamaConfig(
+            **hf_config,
+            dtype=dtype,
+            max_sequence_length=hf_config["max_position_embeddings"],
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
+    else:
+        raise Exception("The model config should contain information about maximum sequence length.")
+
+    # If there is a user-provided maximum sequence length, override hf config.
+    if args.max_seq_len != -1:
+        config.max_sequence_length = args.max_seq_len
 
     param_manager = ParamManager()
     bb = relax.BlockBuilder()
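For readers skimming the diff, below is a minimal standalone sketch of the new selection logic. It is not part of the commit; resolve_max_sequence_length and the two config dicts are hypothetical, introduced only to show which branch each Llama family takes and how a user-provided maximum sequence length overrides the hf config.

# Minimal illustration (not part of the commit) of the selection logic above.
# The config dicts are hypothetical stand-ins for the relevant hf model card keys.
llama1_hf_config = {"max_sequence_length": 2048}       # Llama-1 style key
llama2_hf_config = {"max_position_embeddings": 4096}   # Llama-2 style key


def resolve_max_sequence_length(hf_config: dict, user_max_seq_len: int = -1) -> int:
    # Prefer `max_sequence_length`, fall back to `max_position_embeddings`,
    # and raise if neither key is present, mirroring the if/elif/else in the diff.
    if "max_sequence_length" in hf_config:
        max_len = hf_config["max_sequence_length"]
    elif "max_position_embeddings" in hf_config:
        max_len = hf_config["max_position_embeddings"]
    else:
        raise Exception("The model config should contain information about maximum sequence length.")
    # A user-provided value (args.max_seq_len != -1 in the diff) overrides the hf config.
    if user_max_seq_len != -1:
        max_len = user_max_seq_len
    return max_len


print(resolve_max_sequence_length(llama1_hf_config))        # 2048: Llama-1 path
print(resolve_max_sequence_length(llama2_hf_config))        # 4096: Llama-2 path
print(resolve_max_sequence_length(llama2_hf_config, 512))   # 512: user override wins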
