[Llama] Add descriptive names for symbolic variables #1684

Merged: 1 commit, Feb 14, 2024
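The commit replaces terse symbolic shape names ("n", "m", "nseq", "cache_len") with self-describing ones. As a rough illustration of the difference, the sketch below (not part of the diff; names and dtype chosen only to mirror the change) builds the same symbolic dimension with the old and the new naming. tir.SizeVar is used here, as in the model code, because a SizeVar additionally carries the assumption that the value is non-negative, which helps TVM's arithmetic simplifier.

from tvm import tir

# Before: a single-letter name leaks into every derived shape and error message.
seq_len_old = tir.SizeVar("n", "int64")

# After: the role of the dimension is visible wherever the shape is printed.
seq_len_new = tir.SizeVar("num_tokens_excluding_cache", "int64")

print(seq_len_old, seq_len_new)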
26 changes: 13 additions & 13 deletions mlc_llm/relax_model/llama.py
@@ -798,7 +798,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
-cache_len = te.var("cache_len", "int64")
+cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
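For context, the te.var renamed above is only a symbolic placeholder for the first dimension of the cached sin/cos tables. A minimal standalone sketch of the same pattern, with head_dim and dtype as assumed values rather than ones taken from the model config:

from tvm import te

head_dim = 128  # assumed for illustration; the model derives this from its config
cache_len = te.var("cached_rotary_embedding_len", "int64")

# The table length stays symbolic until the maximum sequence length is known at build time.
cos_cached = te.placeholder((cache_len, head_dim), dtype="float16", name="cos_cached")
sin_cached = te.placeholder((cache_len, head_dim), dtype="float16", name="sin_cached")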

@@ -859,7 +859,7 @@ def create_embed_func(
) -> None:
func_name = "embed"

-seq_len = tvm.tir.SizeVar("m", "int64")
+seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
with bb.function(func_name):
model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
@@ -886,8 +886,8 @@ def create_prefill_func_for_single_seq(
func_name = "prefill_with_embed" if sep_embed else "prefill"

bsz = 1
-seq_len = tvm.tir.SizeVar("n", "int64")
-all_seq_len = tvm.tir.SizeVar("m", "int64")
+seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -932,8 +932,8 @@ def create_prefill_func_for_batching(
) -> None:
func_name = "prefill_with_embed"

-bsz = tir.SizeVar("nseq", "int64")
-total_seq_len = tvm.tir.SizeVar("m", "int64")
+bsz = tir.SizeVar("batch_size", "int64")
+total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -971,7 +971,7 @@ def create_decoding_func_for_single_seq(
func_name = "decode"

bsz = 1
-all_seq_len = tvm.tir.SizeVar("m", "int64")
+all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
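The two token counts used throughout these signatures differ only by the number of entries already held in the KV cache: prefill sees just the new tokens, while decode attends over the full history. A plain-Python sketch of the relationship (helper and variable names are mine, chosen to mirror the diff):

def sequence_lengths(num_cached_tokens, num_new_tokens):
    num_tokens_excluding_cache = num_new_tokens
    num_tokens_including_cache = num_cached_tokens + num_new_tokens
    return num_tokens_excluding_cache, num_tokens_including_cache

# Prefill of a 7-token prompt with an empty cache, then a single decode step.
print(sequence_lengths(0, 7))  # (7, 7)
print(sequence_lengths(7, 1))  # (1, 8)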
@@ -1010,7 +1010,7 @@ def create_decoding_func_for_batching(
) -> None:
func_name = "decode_with_embed"

-bsz = tir.SizeVar("nseq", "int64")
+bsz = tir.SizeVar("batch_size", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1041,7 +1041,7 @@ def create_verification_func_for_batching(
) -> None:
func_name = "verify_with_embed"

-total_seq_len = tvm.tir.SizeVar("m", "int64")
+total_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1160,7 +1160,7 @@ def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConf

def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
with bb.function("softmax_with_temperature"):
-bsz = tvm.tir.SizeVar("nseq", "int64")
+bsz = tvm.tir.SizeVar("batch_size", "int64")
logits = nn.Placeholder(
(bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")),
dtype="float32",
@@ -1188,7 +1188,7 @@ def kv_cache_transpose_append(
var_v_data: T.handle,
var_position_map: T.handle,
):
-ntoken = T.SizeVar("ntoken", "int64")
+ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
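T.SizeVar plays the same role inside TVMScript kernels: the chosen string becomes part of the printed PrimFunc, so descriptive names make shape mismatches in the paged KV-cache kernels easier to read. A self-contained toy kernel following the same convention (buffer shapes and dtype are invented for the example and do not reflect the cache layout above):

from tvm.script import tir as T

@T.prim_func
def copy_new_tokens(var_src: T.handle, var_dst: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    num_tokens = T.SizeVar("num_tokens_excluding_cache", "int64")
    src = T.match_buffer(var_src, (num_tokens, 128), "float16")
    dst = T.match_buffer(var_dst, (num_tokens, 128), "float16")
    for t, d in T.grid(num_tokens, 128):
        with T.block("copy"):
            vt, vd = T.axis.remap("SS", [t, d])
            dst[vt, vd] = src[vt, vd]

print(copy_new_tokens.script())  # the symbolic extent prints under its descriptive name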

@@ -1451,10 +1451,10 @@ def get_model(args, hf_config):
mod = bb.get()

tir_bound_map = dict()
tir_bound_map["n"] = (
tir_bound_map["num_tokens_without_cache"] = (
args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length
)
tir_bound_map["m"] = config.max_sequence_length
tir_bound_map["num_tokens_with_cache"] = config.max_sequence_length
tir_bound_map["vocab_size"] = args.max_vocab_size
if enable_batching:
tir_bound_map["nseq"] = args.max_batch_size
12 changes: 6 additions & 6 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -285,7 +285,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
-cache_len = te.var("cache_len", "int64")
+cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
############ End ############
@@ -455,8 +455,8 @@ def create_evaluate_func(
"""Evaluate logits for the last token in each sequence. Same as prefill but without KV cache."""
func_name = "evaluate"

-num_token = tvm.tir.SizeVar("num_token", "int64")
-num_seq = tvm.tir.SizeVar("num_seq", "int64")
+num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+num_seq = tvm.tir.SizeVar("batch_size", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed)
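In the batched vLLM entry points the batch is flattened into a single token dimension, so batch_size counts sequences while num_tokens_excluding_cache counts new tokens summed across all of them. A small sketch with hypothetical prompt lengths:

# Three prompts of different lengths handled in one prefill/evaluate call.
prompt_lengths = [5, 12, 3]  # hypothetical per-sequence lengths

batch_size = len(prompt_lengths)                  # the "batch_size" SizeVar, here 3
num_tokens_excluding_cache = sum(prompt_lengths)  # the flattened token dimension, here 20

print(batch_size, num_tokens_excluding_cache)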
@@ -504,8 +504,8 @@ def create_encoding_func(
"""
func_name = "prefill_with_embed" if sep_embed else "prefill"

-num_token = tvm.tir.SizeVar("num_token", "int64")
-num_seq = tvm.tir.SizeVar("num_seq", "int64")
+num_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+num_seq = tvm.tir.SizeVar("batch_size", "int64")

num_inputs = 5

@@ -569,7 +569,7 @@ def create_decoding_func(
"""Batched decoding with vLLM paged KV cache."""
func_name = "decode"

-num_seq = tvm.tir.SizeVar("num_seq", "int64")
+num_seq = tvm.tir.SizeVar("batch_size", "int64")
max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64")

with bb.function(func_name):
4 changes: 2 additions & 2 deletions python/mlc_chat/nn/kv_cache.py
@@ -294,7 +294,7 @@ def tir_kv_cache_transpose_append(
var_position_map: T.handle,
):
T.func_attr({"tir.noalias": T.bool(True)})
-ntoken = T.SizeVar("ntoken", "int64")
+ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype)
k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype)
@@ -333,7 +333,7 @@ def tir_kv_cache_debug_get_kv(
layer_id: T.int64,
):
T.func_attr({"tir.noalias": T.bool(True)})
-seqlen = T.SizeVar("seqlen", "int64")
+seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
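The debug get_kv kernel reads from the paged layout declared in the match_buffer call above, where each sequence position resolves to a page and a slot within that page. A simplified sketch of that index arithmetic (a standalone helper, not the kernel's actual position_map lookup, which goes through a logical-to-physical page table):

def paged_kv_location(position, page_size):
    # Map a token position within a sequence to (logical page index, slot within the page).
    return position // page_size, position % page_size

# With 16 slots per page, token 35 of a sequence lands in logical page 2, slot 3.
print(paged_kv_location(35, 16))  # (2, 3)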