[Llama] Add descriptive names for symbolic variables
Because the symbolic variables may appear in many places throughout
the resulting module, they should have more descriptive names, which
can be understood outside of their original contexts.
Lunderberg committed Jan 30, 2024 · 1 parent b812bb5 · commit 8f28486
Showing 3 changed files with 21 additions and 21 deletions.
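
To see why these name strings matter, here is a minimal sketch (not part of the commit; everything below is illustrative): the name passed to tvm.tir.SizeVar is the only identifier that survives into the printed module, so it is all a reader sees far from the Python code that created it.

import tvm
from tvm import tir

# The name string, not the Python variable name, is the symbolic
# variable's identity everywhere it appears in the lowered module.
terse = tir.SizeVar("m", "int64")                                 # opaque out of context
descriptive = tir.SizeVar("num_tokens_including_cache", "int64")  # self-documenting
print(terse, descriptive)
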
26 changes: 13 additions & 13 deletions mlc_llm/relax_model/llama.py
@@ -798,7 +798,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
-cache_len = te.var("cache_len", "int64")
+cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")

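A hedged sketch of the pattern in this hunk (head_dim and dtype below are illustrative, not taken from the diff): te.var creates a symbolic leading dimension, and its name follows the parameter into every signature where it appears.

from tvm import te

cache_len = te.var("cached_rotary_embedding_len", "int64")
head_dim = 128  # illustrative value
# Analogous to cos_cached above: a tensor whose first dimension is symbolic.
cos_cached = te.placeholder((cache_len, head_dim), dtype="float16", name="cos_cached")
print(cos_cached.shape)  # [cached_rotary_embedding_len, 128]
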
@@ -859,7 +859,7 @@ def create_embed_func(
) -> None:
func_name = "embed"

-seq_len = tvm.tir.SizeVar("m", "int64")
+seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
with bb.function(func_name):
model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
Expand All @@ -886,8 +886,8 @@ def create_prefill_func_for_single_seq(
func_name = "prefill_with_embed" if sep_embed else "prefill"

bsz = 1
-seq_len = tvm.tir.SizeVar("n", "int64")
-all_seq_len = tvm.tir.SizeVar("m", "int64")
+seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -932,8 +932,8 @@ def create_prefill_func_for_batching(
) -> None:
func_name = "prefill_with_embed"

-bsz = tir.SizeVar("nseq", "int64")
-total_seq_len = tvm.tir.SizeVar("m", "int64")
+bsz = tir.SizeVar("batch_size", "int64")
+total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -971,7 +971,7 @@ def create_decoding_func_for_single_seq(
func_name = "decode"

bsz = 1
-all_seq_len = tvm.tir.SizeVar("m", "int64")
+all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
@@ -1010,7 +1010,7 @@ def create_decoding_func_for_batching(
) -> None:
func_name = "decode_with_embed"

-bsz = tir.SizeVar("nseq", "int64")
+bsz = tir.SizeVar("batch_size", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1041,7 +1041,7 @@ def create_verification_func_for_batching(
) -> None:
func_name = "verify_with_embed"

-total_seq_len = tvm.tir.SizeVar("m", "int64")
+total_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1160,7 +1160,7 @@ def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConfig) -> None:

def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
with bb.function("softmax_with_temperature"):
-bsz = tvm.tir.SizeVar("nseq", "int64")
+bsz = tvm.tir.SizeVar("batch_size", "int64")
logits = nn.Placeholder(
(bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")),
dtype="float32",
@@ -1188,7 +1188,7 @@ def kv_cache_transpose_append(
var_v_data: T.handle,
var_position_map: T.handle,
):
-ntoken = T.SizeVar("ntoken", "int64")
+ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()

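As a self-contained, hedged illustration of how a T.SizeVar name surfaces in TVMScript (the kernel below is invented for the sketch; only the naming pattern comes from kv_cache_transpose_append above):

from tvm.script import tir as T

@T.prim_func
def copy_tokens(var_src: T.handle, var_dst: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    # This name, not the Python variable "ntoken", is what readers of the
    # printed module see in the buffer shapes and loop extents below.
    ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
    src = T.match_buffer(var_src, (ntoken, 128), "float16")
    dst = T.match_buffer(var_dst, (ntoken, 128), "float16")
    for i, j in T.grid(ntoken, 128):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            dst[vi, vj] = src[vi, vj]
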
@@ -1451,10 +1451,10 @@ def get_model(args, hf_config):
mod = bb.get()

tir_bound_map = dict()
-tir_bound_map["n"] = (
+tir_bound_map["num_tokens_excluding_cache"] = (
args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length
)
-tir_bound_map["m"] = config.max_sequence_length
+tir_bound_map["num_tokens_including_cache"] = config.max_sequence_length
tir_bound_map["vocab_size"] = args.max_vocab_size
if enable_batching:
tir_bound_map["batch_size"] = args.max_batch_size
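
This bound map is consumed by name. A hedged sketch follows, assuming the bounds are attached as the "tir_var_upper_bound" function attribute (the mechanism mlc_llm uses for memory planning); the numbers are illustrative. Each key is matched against a SizeVar's name string, so the keys must track the renames in this commit, e.g. "batch_size" rather than the old "nseq".

tir_bound_map = {
    "num_tokens_excluding_cache": 4096,   # illustrative bounds
    "num_tokens_including_cache": 4096,
    "vocab_size": 32000,
    "batch_size": 64,
}
# Sketch of the attachment step, assumed from get_model's surrounding code:
# for gv, func in mod.functions.items():
#     if isinstance(func, relax.Function):
#         mod[gv] = func.with_attr("tir_var_upper_bound", tir_bound_map)
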
12 changes: 6 additions & 6 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -285,7 +285,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
-cache_len = te.var("cache_len", "int64")
+cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
############ End ############
@@ -455,8 +455,8 @@ def create_evaluate_func(
"""Evaluate logits for the last token in each sequence. Same as prefill but without KV cache."""
func_name = "evaluate"

-num_token = tvm.tir.SizeVar("num_token", "int64")
-num_seq = tvm.tir.SizeVar("num_seq", "int64")
+num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+num_seq = tvm.tir.SizeVar("batch_size", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed)
@@ -504,8 +504,8 @@ def create_encoding_func(
"""
func_name = "prefill_with_embed" if sep_embed else "prefill"

-num_token = tvm.tir.SizeVar("num_token", "int64")
-num_seq = tvm.tir.SizeVar("num_seq", "int64")
+num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+num_seq = tvm.tir.SizeVar("batch_size", "int64")

num_inputs = 5

@@ -569,7 +569,7 @@ def create_decoding_func(
"""Batched decoding with vLLM paged KV cache."""
func_name = "decode"

-num_seq = tvm.tir.SizeVar("num_seq", "int64")
+num_seq = tvm.tir.SizeVar("batch_size", "int64")
max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64")

with bb.function(func_name):
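A hedged sketch of the BlockBuilder pattern these functions share (the body and shapes below are invented; only the naming comes from the diff): the SizeVar named "batch_size" flows into the placeholder's shape and therefore into the printed signature of the emitted Relax function.

import tvm
from tvm import relax, tir
from tvm.relax.testing import nn

bb = relax.BlockBuilder()
batch_size = tir.SizeVar("batch_size", "int64")
with bb.function("decode_sketch"):
    # Placeholder whose leading dimension is the symbolic batch size.
    logits = nn.Placeholder((batch_size, 1, 32000), dtype="float32", name="logits")
    with bb.dataflow():
        out = bb.emit_output(relax.op.nn.softmax(logits, axis=-1))
    bb.emit_func_output(out, params=[logits])
print(bb.get())  # the signature shows "batch_size" rather than "nseq"
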
4 changes: 2 additions & 2 deletions python/mlc_chat/nn/kv_cache.py
@@ -294,7 +294,7 @@ def tir_kv_cache_transpose_append(
var_position_map: T.handle,
):
T.func_attr({"tir.noalias": T.bool(True)})
-ntoken = T.SizeVar("ntoken", "int64")
+ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype)
k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype)
@@ -333,7 +333,7 @@ def tir_kv_cache_debug_get_kv(
layer_id: T.int64,
):
T.func_attr({"tir.noalias": T.bool(True)})
-seqlen = T.SizeVar("seqlen", "int64")
+seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
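
For intuition about the pages buffer above, a hedged arithmetic sketch (all values illustrative): the axis of length 2 separates K from V, so each layer's cache provides num_pages * page_size token slots.

num_pages, page_size = 64, 16             # illustrative values
num_key_value_heads, head_dim = 8, 128
token_capacity = num_pages * page_size    # token slots per layer
elems_per_layer = num_pages * 2 * num_key_value_heads * page_size * head_dim
print(token_capacity, elems_per_layer)    # 1024, 2097152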