diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py
index 0a2c8f0b9c..98cbbe9c3e 100644
--- a/examples/python/run_llama_batched_vllm.py
+++ b/examples/python/run_llama_batched_vllm.py
@@ -20,29 +20,29 @@ class KVCache:
-    def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session):
-        if disco_session:
-            init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
-        else:
-            init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
-
+    def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, init_cache_func):
         self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks)
-
         self.block_tables = defaultdict(list)
         self.slot_mappings = defaultdict(list)
         self.block_size = block_size


 class CacheManager:
-    block_size: int = 16
-
     def __init__(
-        self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None
+        self,
+        num_blocks,
+        block_size,
+        num_layers,
+        num_heads,
+        head_size,
+        init_cache_func,
+        sliding_window=None,
     ):
+        self.block_size = block_size
         self.num_blocks = num_blocks
         self.free_blocks = list(range(num_blocks))
         self.kv_cache = KVCache(
-            num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session
+            num_blocks, self.block_size, num_layers, num_heads, head_size, init_cache_func
         )

         if sliding_window:
@@ -172,6 +172,7 @@ def _prepare_inputs(
     sliding_window,
     dev,
     is_prefill,
+    query_token_len=1,
 ):
     block_tables = []
     seq_lens = []
@@ -201,13 +202,16 @@ def _prepare_inputs(
             start_idx += prompt_len
         else:
-            input_ids.append(token_ids[-1])
-            pos = len(token_ids) - 1
-            positions.append(pos)
+            input_ids += token_ids[-query_token_len:]
+
+            for i in range(query_token_len):
+                positions.append(len(token_ids) - (query_token_len - i))
+
+            slot_mapping += all_slot_mappings[request_id][-query_token_len:]
+
             block_table = all_block_tables[request_id]
             max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
             block_tables.append(block_table)
-            slot_mapping.append(all_slot_mappings[request_id][-1])

         if sliding_window:
             seq_lens.append(min(len(token_ids), sliding_window))
@@ -316,7 +320,15 @@ class Model:
     def __init__(
-        self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window
+        self,
+        artifact_path,
+        model_name,
+        quant,
+        vocab_size,
+        num_shards,
+        dev,
+        sliding_window,
+        block_size,
     ):
         self.mod, self.params, self.disco_session = get_tvm_model(
             artifact_path, model_name, quant, num_shards, dev
         )
@@ -326,7 +338,7 @@ def __init__(
         self.sliding_window = sliding_window

         if sliding_window:
-            self.block_sliding_window = sliding_window // CacheManager.block_size
+            self.block_sliding_window = sliding_window // block_size
         else:
             self.block_sliding_window = None
@@ -409,6 +421,15 @@ def generate(
     ]


+def get_paged_kv_cache_type(model_artifact_path):
+    config_file_path = os.path.join(model_artifact_path, "build_config.json")
+    assert os.path.exists(config_file_path)
+
+    with open(config_file_path, mode="rt", encoding="utf-8") as f:
+        build_cfg = json.load(f)
+    return build_cfg["paged_kv_cache_type"]
+
+
 def parse_args():
     # Example
     # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention
@@ -444,6 +465,18 @@ def run(args):
     with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f:
         config = LlamaConfig(**json.load(i_f))

+    kv_type = get_paged_kv_cache_type(args.artifact_path)
+    use_flash_decoding = kv_type == "flash-decoding"
+
+    if use_flash_decoding:
+        allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
+        block_size = 256
+        num_blocks = 30
+    else:
+        allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
+        block_size = 16
+        num_blocks = 500
+
     model = Model(
         artifact_path,
         model_name,
@@ -452,20 +485,26 @@ def run(args):
         args.num_shards,
         dev,
         config.sliding_window,
+        block_size,
     )

     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)

     num_kv_heads = config.get_num_key_value_heads() // args.num_shards
     head_size = config.hidden_size // config.num_attention_heads
-    num_blocks = 500
+
+    if model.disco_session:
+        init_cache_func = model.disco_session.get_global_func(allocate_func_name)
+    else:
+        init_cache_func = tvm.get_global_func(allocate_func_name)

     cache_manager = CacheManager(
         num_blocks,
+        block_size,
         config.num_hidden_layers,
         num_kv_heads,
         head_size,
-        model.disco_session,
+        init_cache_func,
         sliding_window=config.sliding_window,
     )
     cache = cache_manager.get()
@@ -516,8 +555,28 @@ def run(args):
     for p, g in zip(prompts, generated):
         print("Prompt = '{}', generated text = '{}'".format(p, g))

-    query_token_lens = [4, 3, 5, 2]
+    if model.disco_session:
+        return
+
+    def verify_logits(logits, query_token_lens):
+        assert logits.shape[0] == sum(query_token_lens)
+
+        logits_offset = 0
+
+        for request_id, query_token_len in zip(request_ids, query_token_lens):
+            for i in range(query_token_len - 1):
+                # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
+                # Doing argmax over multi-timestep logits computed in parallel should yield the same
+                # tokens at the corresponding positions.
+                past_tokens = requests[request_id].token_ids[:-query_token_len]
+                assert (
+                    np.argmax(logits[logits_offset + i])
+                    == requests[request_id].token_ids[len(past_tokens) + i + 1]
+                )
+            logits_offset += query_token_len
+
+    query_token_lens = [4, 3, 5, 2]

     eval_query_requests = []

     for request_id, query_token_len in zip(request_ids, query_token_lens):
@@ -552,22 +611,47 @@ def run(args):
         model.params,
     )[0].numpy()

-    assert logits.shape[0] == sum(query_token_lens)
+    verify_logits(logits, query_token_lens)

-    logits_offset = 0
+    if not use_flash_decoding:
+        return

-    for request_id, query_token_len in zip(request_ids, query_token_lens):
-        for i in range(query_token_len - 1):
-            # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
-            # Doing argmax over multi-timestep logits computed in parallel should yield the same
-            # tokens at the corresponding positions.
-            past_tokens = requests[request_id].token_ids[:-query_token_len]
-            assert (
-                np.argmax(logits[logits_offset + i])
-                == requests[request_id].token_ids[len(past_tokens) + i + 1]
-            )
+    query_token_lens = [3, 3, 3, 3]
+    decode_multi_query_requests = requests
+    query_len = query_token_lens[0]
+
+    (
+        input_ids,
+        positions,
+        seq_lens,
+        slot_mapping,
+        _,
+        block_tables,
+    ) = _prepare_inputs(
+        decode_multi_query_requests,
+        cache.slot_mappings,
+        cache.block_tables,
+        model.sliding_window,
+        model.dev,
+        False,  # is_prefill
+        query_len,
+    )
+
+    input_ids = tvm.nd.array(np.reshape(input_ids.numpy(), [-1, query_len]), dev)
+
+    logits = model.mod["decode_multi_query"](
+        input_ids,
+        positions,
+        seq_lens,
+        cache.cache,
+        slot_mapping,
+        block_tables,
+        model.params,
+    )[0].numpy()
+
+    logits = np.reshape(logits, (-1, logits.shape[-1]))

-        logits_offset += query_token_len
+    verify_logits(logits, query_token_lens)


 if __name__ == "__main__":
diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index f7afbbb693..b834ff9c33 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -392,6 +392,7 @@ class BuildArgs:
             "action": "store_true",
         },
     )
+    # TODO(masahi): Remove the use of this option with paged_kv_cache_type
     use_vllm_attention: bool = field(
         default=False,
         metadata={
@@ -402,6 +403,10 @@
             "action": "store_true",
         },
     )
+    paged_kv_cache_type: str = field(
+        default="vllm",
+        metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
+    )

     @property
     def convert_weight_only(self):
@@ -595,6 +600,9 @@ def mod_transform_before_build(
             model_names.append("evaluate")
             model_names.append("evaluate_multi_query")

+            if args.paged_kv_cache_type == "flash-decoding":
+                model_names.append("decode_multi_query")
+
         if args.sep_embed:
             model_names = ["embed", "prefill_with_embed"] + model_names[1:]
         if args.enable_batching:
@@ -706,6 +714,7 @@ def dump_build_config(
     config: Dict[str, Any] = {
         "num_shards": args.num_shards,
        "quantization": args.quantization.name,
+        "paged_kv_cache_type": args.paged_kv_cache_type,
        "library_name": args.lib_name,
        "build_options": str(args)
    }
diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py
index 33b0966f54..97e7656c9b 100644
--- a/mlc_llm/relax_model/llama_batched_vllm.py
+++ b/mlc_llm/relax_model/llama_batched_vllm.py
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple, Union
+from enum import Enum, auto

 from dataclasses import dataclass
@@ -47,6 +48,11 @@ def rotary_compute(*idx):
     return q_embed, k_embed


+class KVCacheType(Enum):
+    VLLM = auto()
+    FlashDecoding = auto()
+
+
 @dataclass
 class PrefillAttentionInput:
     seq_start: Optional[relax.Expr]  # (num_seq + 1,)
@@ -80,25 +86,248 @@ class AttentionInput:
     aux_info: Union[PrefillAttentionInput, DecodeAttentionInput, EvaluateMultiQueryInput]


+class AttentionBackend:
+    def __init__(self, num_query_heads, num_key_value_heads, head_dim):
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+
+    def decode_attention(
+        self,
+        queries,
+        k_cache,
+        v_cache,
+        block_tables,
+        context_lens,
+        max_context_len,
+        num_seq,
+        seqlen_q,
+    ):
+        pass
+
+    def update_cache(self, keys, values, k_cache, v_cache, slot_mapping):
+        pass
+
+    def reconstruct_from_cache(self, k_cache, v_cache, past_slot_mapping):
+        pass
+
+
+class VllmAttention(AttentionBackend):
+    block_size: int = 16
+
+    def __init__(self, num_query_heads, num_key_value_heads, head_dim, max_context_length):
+        super().__init__(num_query_heads, num_key_value_heads, head_dim)
+
+        partition_size = 512  # partition_size in vLLM attention
+        self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size
+
+    def decode_attention(
+        self,
+        queries,
+        k_cache,
+        v_cache,
+        block_tables,
+        context_lens,
+        max_context_len,
+        num_seq,
+        seqlen_q,
+    ):
+        num_query_tokens = queries.struct_info.shape[0]
+        exp_sums = nn.emit(
+            relax.op.builtin.alloc_tensor(
+                relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)),
+                dtype="float32",
+                runtime_device_index=0,
+            )
+        )
+        max_logits = nn.emit(
+            relax.op.builtin.alloc_tensor(
+                relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)),
+                dtype="float32",
+                runtime_device_index=0,
+            )
+        )
+        tmp_out = nn.emit(
+            relax.op.builtin.alloc_tensor(
+                relax.ShapeExpr(
+                    (
+                        num_query_tokens,
+                        self.num_query_heads,
+                        self.max_num_partitions,
+                        self.head_dim,
+                    )
+                ),
+                dtype=queries.struct_info.dtype,
+                runtime_device_index=0,
+            )
+        )
+        return nn.emit(
+            relax.op.call_dps_packed(
+                "tvm.contrib.vllm.single_query_cached_kv_attention",
+                [
+                    queries,
+                    k_cache,
+                    v_cache,
+                    block_tables,
+                    context_lens,
+                    16,  # block_size
+                    max_context_len,
+                    exp_sums,
+                    max_logits,
+                    tmp_out,
+                ],
+                out_sinfo=queries.struct_info,
+            )
+        )
+
+    def update_cache(self, keys, values, k_cache, v_cache, slot_mapping):
+        return nn.emit(
+            relax.op.call_pure_packed(
+                "tvm.contrib.vllm.reshape_and_cache",
+                keys,
+                values,
+                k_cache,
+                v_cache,
+                slot_mapping,
+                sinfo_args=[k_cache.struct_info, v_cache.struct_info],
+            )
+        )
+
+    def reconstruct_from_cache(self, k_cache, v_cache, past_slot_mapping):
+        num_kv_head = v_cache.struct_info.shape[1]
+        head_size = v_cache.struct_info.shape[2]
+
+        num_past_token = past_slot_mapping.struct_info.shape[0]
+        kv_shape = (num_past_token, num_kv_head, head_size)
+        kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype)
+
+        return nn.emit(
+            relax.op.call_pure_packed(
+                "tvm.contrib.vllm.reconstruct_from_cache",
+                k_cache,
+                v_cache,
+                past_slot_mapping,
+                sinfo_args=[kv_sinfo, kv_sinfo],
+            )
+        )
+
+
+class FlashDecodingAttention(AttentionBackend):
+    block_size: int = 256
+
+    def __init__(self, num_query_heads, num_key_value_heads, head_dim):
+        super().__init__(num_query_heads, num_key_value_heads, head_dim)
+        self.max_num_partitions = 128
+
+    def decode_attention(
+        self,
+        queries,
+        k_cache,
+        v_cache,
+        block_tables,
+        context_lens,
+        max_context_len,
+        num_seq,
+        seqlen_q,
+    ):
+        queries = nn.emit(
+            reshape(queries, (num_seq, seqlen_q, self.num_query_heads, self.head_dim))
+        )
+
+        softmax_lse_accum = nn.emit(
+            relax.op.builtin.alloc_tensor(
+                relax.ShapeExpr((self.max_num_partitions, num_seq, self.num_query_heads, seqlen_q)),
+                dtype="float32",
+                runtime_device_index=0,
+            )
+        )
+        output_accum = nn.emit(
+            relax.op.builtin.alloc_tensor(
+                relax.ShapeExpr(
+                    (
+                        self.max_num_partitions,
+                        num_seq,
+                        self.num_query_heads,
+                        seqlen_q,
+                        self.head_dim,
+                    )
+                ),
+                dtype="float32",
+                runtime_device_index=0,
+            )
+        )
+
+        return R.call_dps_packed(
+            "tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache",
+            [
+                queries,
+                k_cache,
+                v_cache,
+                block_tables,
+                context_lens,
+                softmax_lse_accum,
+                output_accum,
+            ],
+            out_sinfo=queries.struct_info,
+        )
+
+    def update_cache(self, keys, values, k_cache, v_cache, slot_mapping):
+        return nn.emit(
+            relax.op.call_pure_packed(
+                "tvm.contrib.flash_attn.update_cache",
+                keys,
+                values,
+                k_cache,
+                v_cache,
+                slot_mapping,
+                sinfo_args=[k_cache.struct_info, v_cache.struct_info],
+            )
+        )
+
+    def reconstruct_from_cache(self, k_cache, v_cache, past_slot_mapping):
+        num_kv_head = v_cache.struct_info.shape[2]
+        head_size = v_cache.struct_info.shape[-1]
+
+        num_past_token = past_slot_mapping.struct_info.shape[0]
+        kv_shape = (num_past_token, num_kv_head, head_size)
+        kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype)
+
+        return nn.emit(
+            relax.op.call_pure_packed(
+                "tvm.contrib.flash_attn.reconstruct_from_cache",
+                k_cache,
+                v_cache,
+                past_slot_mapping,
+                sinfo_args=[kv_sinfo, kv_sinfo],
+            )
+        )
+
+
 class LlamaAttentionBatched(LlamaAttentionBase):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, kv_type: KVCacheType):
         super().__init__(config)

+        if kv_type == KVCacheType.VLLM:
+            max_context_length = config.sliding_window or config.max_sequence_length
+            self.attn_backend = VllmAttention(
+                self.num_query_heads, self.num_key_value_heads, self.head_dim, max_context_length
+            )
+        else:
+            self.attn_backend = FlashDecodingAttention(
+                self.num_query_heads, self.num_key_value_heads, self.head_dim
+            )
+
         self.sliding_window = None

         if config.sliding_window:
             self.sliding_window = T.IntImm("int32", config.sliding_window)

-        max_context_length = config.sliding_window or config.max_sequence_length
-        partition_size = 512  # partition_size in vLLM attention
-        self.max_num_partitions = (max_context_length + partition_size - 1) // partition_size
-
     def forward(
         self,
-        hidden_states: relax.Expr,  # (num_query_token, hidden_size)
+        hidden_states: relax.Expr,  # (num_query_token, hidden_size) or (num_seq, seqlen_q, hidden_size)
         positions: relax.Expr,  # (num_query_token,), for batched RoPE
         attn_input: AttentionInput,
     ):
-        num_query_tokens, _ = hidden_states.struct_info.shape
+        num_query_tokens = positions.struct_info.shape[0]

         queries, keys, values = self.project_qkv(
             hidden_states,
@@ -129,38 +358,18 @@ def forward(
             slot_mapping = attn_input.slot_mapping

             # kv caches are updated inplace, but make it look like a pure operation
-            kv = nn.emit(
-                relax.op.call_pure_packed(
-                    "tvm.contrib.vllm.reshape_and_cache",
-                    keys_to_cache,
-                    values_to_cache,
-                    k_cache,
-                    v_cache,
-                    slot_mapping,
-                    sinfo_args=[k_cache.struct_info, v_cache.struct_info],
-                )
+            kv = self.attn_backend.update_cache(
+                keys_to_cache, values_to_cache, k_cache, v_cache, slot_mapping
             )
-
             k_cache, v_cache = kv[0], kv[1]
         else:
             k_cache = v_cache = None

         if isinstance(attn_input.aux_info, EvaluateMultiQueryInput):
             assert k_cache and v_cache
-            num_kv_head = v_cache.struct_info.shape[1]
-            head_size = v_cache.struct_info.shape[2]
-            num_past_token = attn_input.aux_info.past_slot_mapping.struct_info.shape[0]
-            kv_shape = (num_past_token, num_kv_head, head_size)
-            kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype)
-
-            kv_tensors = nn.emit(
-                relax.op.call_pure_packed(
-                    "tvm.contrib.vllm.reconstruct_from_cache",
-                    k_cache,
-                    v_cache,
-                    attn_input.aux_info.past_slot_mapping,
-                    sinfo_args=[kv_sinfo, kv_sinfo],
-                )
+
+            kv_tensors = self.attn_backend.reconstruct_from_cache(
+                k_cache, v_cache, attn_input.aux_info.past_slot_mapping
             )
             keys_past, values_past = kv_tensors[0], kv_tensors[1]
             # Say we have past tokens [P1, P2, P3] and the current ones [C1, C2, C3].
@@ -211,72 +420,36 @@ def forward(
                 )
             )
         else:
-            # Decode, using vLLM kernel
+            # Decode, using vLLM or Flash-Decoding kernel
             assert isinstance(attn_input.aux_info, DecodeAttentionInput)

-            exp_sums = nn.emit(
-                relax.op.builtin.alloc_tensor(
-                    relax.ShapeExpr(
-                        (num_query_tokens, self.num_query_heads, self.max_num_partitions)
-                    ),
-                    dtype="float32",
-                    runtime_device_index=0,
-                )
-            )
-            max_logits = nn.emit(
-                relax.op.builtin.alloc_tensor(
-                    relax.ShapeExpr(
-                        (num_query_tokens, self.num_query_heads, self.max_num_partitions)
-                    ),
-                    dtype="float32",
-                    runtime_device_index=0,
-                )
-            )
-            tmp_out = nn.emit(
-                relax.op.builtin.alloc_tensor(
-                    relax.ShapeExpr(
-                        (
-                            num_query_tokens,
-                            self.num_query_heads,
-                            self.max_num_partitions,
-                            self.head_dim,
-                        )
-                    ),
-                    dtype=queries.struct_info.dtype,
-                    runtime_device_index=0,
-                )
-            )
-            attn_output = nn.emit(
-                relax.op.call_dps_packed(
-                    "tvm.contrib.vllm.single_query_cached_kv_attention",
-                    [
-                        queries,
-                        k_cache,
-                        v_cache,
-                        attn_input.aux_info.block_tables,
-                        attn_input.aux_info.seq_lens,
-                        16,  # block_size
-                        attn_input.max_seqlen,
-                        exp_sums,
-                        max_logits,
-                        tmp_out,
-                    ],
-                    out_sinfo=queries.struct_info,
-                )
+            if len(hidden_states.struct_info.shape) == 3:
+                num_seq, seqlen_q, _ = hidden_states.struct_info.shape
+            else:
+                num_seq = hidden_states.struct_info.shape[0]
+                seqlen_q = 1
+
+            attn_output = self.attn_backend.decode_attention(
+                queries,
+                k_cache,
+                v_cache,
+                attn_input.aux_info.block_tables,
+                attn_input.aux_info.seq_lens,
+                attn_input.max_seqlen,
+                num_seq,
+                seqlen_q,
             )

-        attn_output = nn.emit(
-            reshape(attn_output, (num_query_tokens, self.num_query_heads * self.head_dim))
-        )
+        attn_output = nn.emit(reshape(attn_output, hidden_states.struct_info.shape))

         attn_output = self.o_proj(attn_output)

         return attn_output, (k_cache, v_cache)


 class LlamaDecoderLayerBatched(LlamaDecoderLayer):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, kv_type: KVCacheType):
         super().__init__(config, False)
-        self.self_attn = LlamaAttentionBatched(config)
+        self.self_attn = LlamaAttentionBatched(config, kv_type)

     def forward(
         self,
@@ -315,6 +488,7 @@ def __init__(
         self,
         config: LlamaConfig,
         vocab_size_var: tvm.tir.Var,
+        kv_type: KVCacheType,
         sep_embed: bool = False,
     ):
         self.padding_idx = config.pad_token_id
@@ -324,7 +498,7 @@ def __init__(
         self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype)

         self.layers = ModuleList(
-            [LlamaDecoderLayerBatched(config) for _ in range(config.num_hidden_layers)]
+            [LlamaDecoderLayerBatched(config, kv_type) for _ in range(config.num_hidden_layers)]
         )
         self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps)
@@ -370,11 +544,12 @@ def __init__(
         config: LlamaConfig,
         cpu_device: VDevice,
         vocab_size_var: tvm.tir.SizeVar,
+        kv_type: KVCacheType,
         sep_embed: bool = False,
     ):
         self.num_shards = config.num_shards
         self.cpu_device = cpu_device
-        self.model = LlamaModel(config, vocab_size_var, sep_embed)
+        self.model = LlamaModel(config, vocab_size_var, kv_type, sep_embed)
         self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

         ############ Rotary embedding constants ############
@@ -390,7 +565,7 @@ def __init__(
     def forward(
         self,
-        input_ids: relax.Expr,  # (num_query_token,)
+        input_ids: relax.Expr,  # (num_query_token,) or (num_seq, seqlen_q)
         positions: relax.Expr,  # (num_query_token,), for batched RoPE
         seq_lens: relax.Expr,  # (num_seq,)
         kv_caches: Optional[relax.Expr],  # For prefill and decode, not needed for evaluate
@@ -504,35 +679,51 @@ def get_logits_last_tokens(x, seq_len_tensor, seq_start):


 def get_inputs(
-    num_query_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True
+    num_query_token,
+    num_seq,
+    input_shape,
+    config,
+    kv_type=None,
+    max_num_blocks_per_seq=None,
+    sep_embed=False,
 ):
     hidden_size = config.hidden_size

     inputs = (
-        nn.Placeholder((num_query_token, hidden_size), dtype=config.dtype, name="inputs_embeds")
+        nn.Placeholder(input_shape + (hidden_size,), dtype=config.dtype, name="inputs_embeds")
         if sep_embed
-        else nn.Placeholder((num_query_token,), dtype="int32", name="input_ids")
+        else nn.Placeholder(input_shape, dtype="int32", name="input_ids")
     )

     seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens")
     positions = nn.Placeholder((num_query_token,), dtype="int32", name="positions")

-    if need_cache:
+    if kv_type:
         num_blocks = tvm.tir.Var("num_blocks", "int64")
-        block_size = 16
-
-        vec_size = 8  # 128 bit, fp16 x 8
-        num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
-        head_size = hidden_size // config.num_attention_heads
-
-        k_cache_shape = (
-            num_blocks,
-            num_key_value_heads,
-            head_size // vec_size,
-            block_size,
-            vec_size,
-        )
-        v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size)
+
+        if kv_type == KVCacheType.VLLM:
+            block_size = VllmAttention.block_size
+
+            vec_size = 8  # 128 bit, fp16 x 8
+            num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
+            head_size = hidden_size // config.num_attention_heads
+
+            k_cache_shape = (
+                num_blocks,
+                num_key_value_heads,
+                head_size // vec_size,
+                block_size,
+                vec_size,
+            )
+            v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size)
+        else:
+            block_size = FlashDecodingAttention.block_size
+
+            num_key_value_heads = config.get_num_key_value_heads() // config.num_shards
+            head_size = hidden_size // config.num_attention_heads
+
+            k_cache_shape = (num_blocks, block_size, num_key_value_heads, head_size)
+            v_cache_shape = k_cache_shape

         get_cache_sinfo = lambda i: relax.TensorStructInfo(
             k_cache_shape if i % 2 == 0 else v_cache_shape, dtype="float16"
@@ -579,7 +770,7 @@ def create_evaluate_func(
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

         inputs, positions, seq_lens, _, _, _ = get_inputs(
-            num_query_token, num_seq, config, sep_embed=sep_embed
+            num_query_token, num_seq, (num_query_token,), config, sep_embed=sep_embed
         )

         with bb.dataflow():
@@ -612,6 +803,7 @@ def create_encoding_func(
     bb: relax.BlockBuilder,
     param_manager: ParamManager,
     config: LlamaConfig,
+    kv_type: KVCacheType,
     cpu_dev: VDevice,
     quant_scheme: QuantizationScheme,
     sep_embed: bool = False,
@@ -629,11 +821,13 @@
     num_inputs = 5

     with bb.function(func_name):
-        model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed)
+        model = LlamaForCausalLM(
+            config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type, sep_embed
+        )
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

         input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs(
-            num_query_token, num_seq, config, sep_embed=sep_embed
+            num_query_token, num_seq, (num_query_token,), config, kv_type, sep_embed=sep_embed
         )

         with bb.dataflow():
@@ -684,6 +878,7 @@ def create_decoding_func(
     bb: relax.BlockBuilder,
     param_manager: ParamManager,
     config: LlamaConfig,
+    kv_type: KVCacheType,
     cpu_dev: VDevice,
     quant_scheme: QuantizationScheme,
 ) -> None:
@@ -691,49 +886,67 @@ def create_decoding_func(
     func_name = "decode"

     num_seq = tvm.tir.SizeVar("num_seq", "int64")
-    max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64")

-    with bb.function(func_name):
-        inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs(
-            num_seq, num_seq, config, max_num_blocks_per_seq
-        )
+    func_names = ["decode"]

-        with bb.dataflow():
-            model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"))
-            param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
+    if kv_type == KVCacheType.FlashDecoding:
+        func_names.append("decode_multi_query")

-            logits, new_kvs = model(
-                inputs,
-                positions,
-                seq_lens,
-                past_key_values,
-                slot_mapping,
-                block_tables,
-                None,
-                None,
-                None,
-                None,
+    for func_name in func_names:
+        max_num_blocks_per_seq = tvm.tir.SizeVar("max_num_blocks_per_seq", "int64")
+
+        if func_name == "decode":
+            num_query_token = num_seq
+            input_shape = (num_query_token,)
+        else:
+            seqlen_q = tvm.tir.SizeVar("seqlen_q", "int64")
+            num_query_token = num_seq * seqlen_q
+            input_shape = (num_seq, seqlen_q)
+
+        with bb.function(func_name):
+            inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs(
+                num_query_token, num_seq, input_shape, config, kv_type, max_num_blocks_per_seq
             )
-            params = [
-                inputs,
-                positions,
-                seq_lens,
-                past_key_values,
-                slot_mapping,
-                block_tables,
-            ] + model.parameters()
-            gv = bb.emit_output((logits, relax.Tuple(new_kvs)))
-        bb.emit_func_output(gv, params)
-    mod = bb.get()
-    gv = mod.get_global_var(func_name)
-    bb.update_func(gv, mod[gv].with_attr("num_input", 6))
+            with bb.dataflow():
+                model = LlamaForCausalLM(
+                    config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), kv_type
+                )
+                param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
+
+                logits, new_kvs = model(
+                    inputs,
+                    positions,
+                    seq_lens,
+                    past_key_values,
+                    slot_mapping,
+                    block_tables,
+                    None,
+                    None,
+                    None,
+                    None,
+                )
+                params = [
+                    inputs,
+                    positions,
+                    seq_lens,
+                    past_key_values,
+                    slot_mapping,
+                    block_tables,
+                ] + model.parameters()
+                gv = bb.emit_output((logits, relax.Tuple(new_kvs)))
+            bb.emit_func_output(gv, params)
+
+        mod = bb.get()
+        gv = mod.get_global_var(func_name)
+        bb.update_func(gv, mod[gv].with_attr("num_input", 6))


 def create_evaluate_multi_query_func(
     bb: relax.BlockBuilder,
     param_manager: ParamManager,
     config: LlamaConfig,
+    kv_type: KVCacheType,
     cpu_dev: VDevice,
     quant_scheme: QuantizationScheme,
 ) -> None:
@@ -747,11 +960,13 @@
     num_inputs = 8

     with bb.function(func_name):
-        model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), False)
+        model = LlamaForCausalLM(
+            config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), kv_type, False
+        )
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

         input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs(
-            num_query_token, num_seq, config, sep_embed=False
+            num_query_token, num_seq, (num_query_token,), config, kv_type, sep_embed=False
         )

         query_lens = nn.Placeholder((num_seq,), dtype="int32", name="query_lens")
@@ -858,10 +1073,15 @@ def get_model(args, hf_config):
     # The CPU device to copy the result of relax.op.max(seq_lens) to CPU.
cpu_dev = VDevice("llvm", 0, "global") + if args.paged_kv_cache_type == "flash-decoding": + kv_type = KVCacheType.FlashDecoding + else: + kv_type = KVCacheType.VLLM + create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) - create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) - create_evaluate_multi_query_func(bb, param_manager, config, cpu_dev, args.quantization) + create_encoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization) + create_evaluate_multi_query_func(bb, param_manager, config, kv_type, cpu_dev, args.quantization) mod = bb.get()