diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py new file mode 100644 index 0000000000..a290eb892c --- /dev/null +++ b/examples/python/run_llama_batched_vllm.py @@ -0,0 +1,448 @@ +import argparse +import math +import os +import json +from collections import defaultdict +from typing import List +from dataclasses import dataclass + +import numpy as np + +import tvm +from tvm import relax +from tvm.runtime import disco as di + +import torch +from transformers import AutoTokenizer + +from mlc_llm.relax_model.llama import LlamaConfig +from mlc_llm import utils + + +class KVCache: + def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): + if disco_session: + init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + else: + init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + + self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks) + + self.block_tables = defaultdict(list) + self.slot_mappings = defaultdict(list) + self.block_size = block_size + + +class CacheManager: + block_size: int = 16 + + def __init__( + self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None + ): + self.num_blocks = num_blocks + self.free_blocks = list(range(num_blocks)) + self.kv_cache = KVCache( + num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session + ) + + if sliding_window: + assert sliding_window % self.kv_cache.block_size == 0 + self.block_sliding_window = sliding_window // self.kv_cache.block_size + else: + self.block_sliding_window = None + + def set_size(self, request_ids: List[int], target_sizes: List[int]): + for id, size in zip(request_ids, target_sizes): + num_needed_block = math.ceil(size / self.block_size) + + if self.block_sliding_window: + num_needed_block = min(num_needed_block, self.block_sliding_window) + + if id in self.kv_cache.block_tables and size == 0: + self.free_blocks.extend(self.kv_cache.block_tables[id]) + del self.kv_cache.block_tables[id] + del self.kv_cache.slot_mappings[id] + + elif id in self.kv_cache.block_tables: + # Decoding + if len(self.kv_cache.block_tables[id]) < num_needed_block: + # Need to allocate a new block for this request + assert len(self.kv_cache.block_tables[id]) + 1 == num_needed_block + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + pos = size - 1 + block_number = self.kv_cache.block_tables[id][-1] + + if self.block_sliding_window: + block_number = self.kv_cache.block_tables[id][ + (pos // self.block_size) % self.block_sliding_window + ] + else: + block_number = self.kv_cache.block_tables[id][-1] + + block_offset = pos % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + elif id not in self.kv_cache.block_tables: + assert len(self.free_blocks) >= num_needed_block, "Not enough free blocks." 
+ + for _ in range(num_needed_block): + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + for i in range(size): + block_idx = i // self.block_size + + if self.block_sliding_window: + block_idx %= self.block_sliding_window + + block_number = self.kv_cache.block_tables[id][block_idx] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + def get(self): + return self.kv_cache + + +@dataclass +class SequenceGenerationRequest: + request_id: int + token_ids: List[int] + + +@dataclass +class SequenceGenerationResponse: + request_id: int + token_id: int + + +def sample(logits): + logits = torch.from_dlpack(logits) + return torch.argmax(logits, -1).cpu().numpy() + + +def load_params_disco(artifact_path, lib_path, num_shards): + sess = di.ProcessSession(num_workers=num_shards) + devices = range(num_shards) + sess.init_ccl("nccl", *devices) + module = sess.load_vm_module(lib_path) + + loader_create = sess.get_global_func("runtime.disco.ShardLoader") + metadata_path = os.path.join(artifact_path, "params", "ndarray-cache.json") + with open(metadata_path, "r", encoding="utf-8") as f: + ndarray_cache_metadata = f.read() + + loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) + loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll") + params = loader_load(loader) + + return module, params, sess + + +def copy_to_worker_0(sess: di.Session, host_array): + x_array = sess.empty(host_array.shape, host_array.dtype) + sess.copy_to_worker_0(host_array, x_array) + return x_array + + +def get_tvm_model(artifact_path, model, quantization, num_shards, dev): + lib_path = os.path.join(artifact_path, f"{model}-{quantization}-cuda.so") + + if num_shards == 1: + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, dev) + params = utils.load_params(artifact_path, dev) + return vm.module, params, None + + return load_params_disco(artifact_path, lib_path, num_shards) + + +def _prepare_inputs( + requests, + all_slot_mappings, + all_block_tables, + sliding_window, + dev, + is_prefill, +): + block_tables = [] + seq_lens = [] + input_ids = [] + slot_mapping = [] + positions = [] + max_num_blocks_per_seq = 0 + indices_within_window = [] + start_idx = 0 + + for request in requests: + request_id = request.request_id + token_ids = request.token_ids + + if is_prefill: + input_ids += token_ids + prompt_len = len(token_ids) + seq_lens.append(prompt_len) + positions += range(prompt_len) + slot_mapping += all_slot_mappings[request_id] + + if sliding_window: + indices_within_window += range( + start_idx + max(0, prompt_len - sliding_window), + start_idx + prompt_len, + ) + start_idx += prompt_len + + else: + input_ids.append(token_ids[-1]) + pos = len(token_ids) - 1 + positions.append(pos) + block_table = all_block_tables[request_id] + max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) + block_tables.append(block_table) + slot_mapping.append(all_slot_mappings[request_id][-1]) + + if sliding_window: + seq_lens.append(min(len(token_ids), sliding_window)) + else: + seq_lens.append(len(token_ids)) + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + if is_prefill and sliding_window: + indices_within_window = 
tvm.nd.array(np.array(indices_within_window, dtype="int32"), dev) + else: + indices_within_window = None + + if not is_prefill: + + def _pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) + + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in block_tables + ] + + block_tables_np = np.vstack(padded_block_tables).astype("int32") + block_tables = tvm.nd.array(np.array(block_tables_np, dtype="int32"), dev) + else: + block_tables = None + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) + + +class Model: + def __init__( + self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window + ): + self.mod, self.params, self.disco_session = get_tvm_model( + artifact_path, model_name, quant, num_shards, dev + ) + self.dev = dev + self.vocab_size = vocab_size + self.sliding_window = sliding_window + + if sliding_window: + self.block_sliding_window = sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + def generate( + self, requests: List[SequenceGenerationRequest], cache: KVCache, is_prefill: bool + ) -> List[SequenceGenerationResponse]: + ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = _prepare_inputs( + requests, + cache.slot_mappings, + cache.block_tables, + self.sliding_window, + self.dev, + is_prefill, + ) + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + + kv_cache = cache.cache + + if is_prefill: + if self.sliding_window: + if self.disco_session: + indices_within_window = copy_to_worker_0( + self.disco_session, indices_within_window + ) + + out = self.mod["prefill"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + indices_within_window, + self.params, + ) + else: + out = self.mod["prefill"]( + input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] # Ignore returned KV cache since it is updated in-place anyway. 
+ else: + if self.disco_session: + block_tables = copy_to_worker_0(self.disco_session, block_tables) + + out = self.mod["decode"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + block_tables, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] + + next_tokens = sample(logits) + + return [ + SequenceGenerationResponse(request.request_id, new_token) + for request, new_token in zip(requests, next_tokens) + ] + + +def parse_args(): + # Example + # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention + # python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q4f16_ft + # + # For Disco: + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --build-model-only --num-shards 2 + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --convert-weight-only + # CUDA_VISIBLE_DEVICES=0,1 python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q0f16 --num-shards 2 + + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, required=True) + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--num-shards", type=int, default=1) + args.add_argument("--num-decode-steps", type=int, default=20) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def run(args): + quantization = args.quantization.name + artifact_path = args.artifact_path + model_name = args.model + model_path = f"dist/models/{model_name}" + + dev = tvm.device("cuda", 0) + + with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: + config = LlamaConfig(**json.load(i_f)) + + model = Model( + artifact_path, + model_name, + quantization, + config.vocab_size, + args.num_shards, + dev, + config.sliding_window, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + + num_kv_heads = config.get_num_key_value_heads() // args.num_shards + head_size = config.hidden_size // config.num_attention_heads + num_blocks = 500 + + cache_manager = CacheManager( + num_blocks, + config.num_hidden_layers, + num_kv_heads, + head_size, + model.disco_session, + sliding_window=config.sliding_window, + ) + cache = cache_manager.get() + + model.block_sliding_window = cache_manager.block_sliding_window + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + batched_token_ids = [tokenizer.encode(p) for p in prompts] + prompts_len = [len(ids) for ids in batched_token_ids] + request_ids = list(range(len(prompts))) + target_sizes = [] + requests = [] + + for token_ids, request_id in zip(batched_token_ids, request_ids): + request_ids.append(request_id) + target_sizes.append(len(token_ids)) + requests.append(SequenceGenerationRequest(request_id, token_ids)) + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, True) + + for _ in range(args.num_decode_steps): + for i, response in enumerate(out): + new_token_id = response.token_id + requests[i].token_ids.append(new_token_id) + 
target_sizes[i] += 1 + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, False) + + output_tokens = [ + tokenizer.convert_ids_to_tokens( + requests[i].token_ids[prompts_len[i] :], skip_special_tokens=True + ) + for i in range(len(requests)) + ] + + generated = [tokenizer.convert_tokens_to_string(tokens) for tokens in output_tokens] + + for p, g in zip(prompts, generated): + print("Prompt = '{}', generated text = '{}'".format(p, g)) + + +if __name__ == "__main__": + run(parse_args()) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index e720d19542..0b7d1c8c39 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -22,6 +22,7 @@ gpt_neox, gptj, llama, + llama_batched_vllm, minigpt, param_manager, rwkv, @@ -96,7 +97,7 @@ class BuildArgs: Disable offloading layer and RMS norm operations to CUTLASS. no_cublas: bool Disable the step that offloads matmul to cuBLAS. Without this flag, - matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or + matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. use_cuda_graph: bool Specifies whether to enable CUDA Graph for the decoder. MLP and QKV @@ -108,6 +109,8 @@ class BuildArgs: Offload multi-query attention workload to Flash Attention. pdb: bool If set, drop into a pdb debugger on error. + use_vllm_attention: bool + Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True. """ model: str = field( default="auto", @@ -279,6 +282,15 @@ class BuildArgs: "action": "store_true", }, ) + use_vllm_attention: bool = field( + default=False, + metadata={ + "help": ( + "Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True." + ), + "action": "store_true", + }, + ) def convert_build_args_to_argparser() -> argparse.ArgumentParser: @@ -315,6 +327,11 @@ def _parse_args(parsed) -> argparse.Namespace: utils.parse_target(parsed) utils.argparse_postproc_common(parsed) + if parsed.use_vllm_attention: + assert parsed.enable_batching, "--enable_batching is required for using vLLM attention." + assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA." + assert tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True), "TVM needs to be built with -DUSE_VLLM=ON." + parsed.artifact_path = os.path.join( parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" ) @@ -409,10 +426,19 @@ def mod_transform_before_build( model_names = [ "prefill", "decode", - "create_kv_cache", - "softmax_with_temperature", - "get_metadata", ] + + if not args.use_vllm_attention: + model_names += [ + "create_kv_cache", + "softmax_with_temperature", + "get_metadata", + ] + else: + # This is equivalent to prefill but without KV cache. It is used for + # determining the number of paged cache blocks that can be allocated. 
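            # (Illustrative aside, not part of this patch: with the fp16 cache layout
            #  defined in llama_batched_vllm.get_inputs, one cache block costs about
            #      2 * num_kv_heads * head_size * block_size * 2 bytes   per layer,
            #  so a rough upper bound on allocatable blocks is
            #      free_gpu_memory // (num_hidden_layers * per_layer_block_bytes),
            #  measured after running "evaluate" once on a worst-case batch. The example
            #  script in this patch simply hard-codes num_blocks = 500.)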
+ model_names.append("evaluate") + if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] if args.enable_batching: @@ -427,7 +453,8 @@ def mod_transform_before_build( mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) if ( - hasattr(config, "num_attention_heads") + not args.enable_batching + and hasattr(config, "num_attention_heads") and hasattr(config, "hidden_size") and hasattr(config, "position_embedding_base") and getattr(config, "dtype", "float16") == "float16" @@ -649,6 +676,10 @@ def build_model_from_args(args: argparse.Namespace): "chatglm": chatglm, } + if args.use_vllm_attention: + model_generators["llama"] = llama_batched_vllm + model_generators["mistral"] = llama_batched_vllm + assert args.model_category in model_generators, f"Model {args.model} not supported" mod, param_manager, params, model_config = model_generators[args.model_category].get_model( diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index e45a4a3e20..8294313324 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -38,6 +38,7 @@ def __init__( combine_matmul=True, build_model_only=False, num_shards=1, + sliding_window=None, **kwargs, ): self.dtype = dtype @@ -57,6 +58,8 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.position_embedding_base = position_embedding_base self.combine_matmul = combine_matmul + self.sliding_window = sliding_window + if build_model_only and num_shards > 1: self.num_shards = num_shards else: @@ -120,30 +123,50 @@ def f_rms_norm(x, weight): def f_square(x): return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x - k = te.reduce_axis((0, x.shape[2]), name="k") - square_sum = te.compute( - (x.shape[0], x.shape[1]), - lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), - name=x.op.name + "red_temp", - ) + def f_mul_cast(x, y): + value = x * y + if not is_float32: + value = tir.Cast(x.dtype, value) + return value - def f_div_cast(bsz, i, k): + def f_div_cast_2d(i, k): + x_val = x[i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[i] / x.shape[1] + self.variance_epsilon) + + def f_div_cast_3d(bsz, i, k): x_val = x[bsz, i, k] if not is_float32: x_val = tir.Cast("float32", x_val) return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) - def f_mul_cast(x, y): - value = x * y - if not is_float32: - value = tir.Cast(x.dtype, value) - return value + k = te.reduce_axis((0, x.shape[-1]), name="k") - return te.compute( - x.shape, - lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), - name="rms_norm", - ) + if len(x.shape) == 2: + square_sum = te.compute( + (x.shape[0],), + lambda i: te.sum(f_square(x[i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + return te.compute( + x.shape, + lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)), + name="rms_norm", + ) + else: + square_sum = te.compute( + (x.shape[0], x.shape[1]), + lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + return te.compute( + x.shape, + lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)), + name="rms_norm", + ) return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") @@ -186,28 +209,36 @@ def forward(self, x): return result +def rotary_modulate_by_freq(tensor, idx, pos, position_embedding_base): + head_dim = tensor.shape[-1] + dtype = tensor.dtype + n_feat_half = head_dim // 2 + feat_idx = idx[-1] + 
inv_freq = te.const(1, "float32") / ( + te.power( + te.const(position_embedding_base, "float32"), + ((2 * feat_idx) % head_dim).astype("float32") / head_dim.astype("float32"), + ) + ) + freq = pos * inv_freq + left_indices = idx[:-1] + (feat_idx - n_feat_half,) + right_indices = idx[:-1] + (feat_idx + n_feat_half,) + return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype(dtype) * tvm.tir.Select( + feat_idx >= n_feat_half, + tensor[(*left_indices,)], + -tensor[(*right_indices,)], + ) + + def apply_rotary_pos_emb(q, k, position_embedding_base, offset: int = 0): def f_rotary_embedding(tensor, offset): - dtype = tensor.dtype - head_dim = tensor.shape[-1] - n_feat_half = tensor.shape[-1] // 2 - def rotary_compute(*idx): - i, j = idx[-3], idx[-1] - pos = (offset + i).astype("float32") - inv_freq = te.const(1, "float32") / ( - te.power( - te.const(position_embedding_base, "float32"), - ((2 * j) % head_dim).astype("float32") / head_dim.astype("float32"), - ) - ) - freq = pos * inv_freq - return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype( - dtype - ) * tvm.tir.Select( - j >= n_feat_half, - tensor[idx[0], i, idx[2], j - n_feat_half], - -tensor[idx[0], i, idx[2], j + n_feat_half], + pos = (offset + idx[-3]).astype("float32") + return rotary_modulate_by_freq( + tensor, + idx, + pos, + position_embedding_base, ) return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") @@ -268,18 +299,9 @@ def __init__(self, config: LlamaConfig): self.o_proj.weight.shard_dim = 1 self.o_proj.weight.shard_strategy = "shard_o_proj_k" - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: Union[relax.Expr, Tuple[relax.Expr]], - layer_id: int, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: + def project_qkv(self, hidden_states, query_output_shape, kv_output_shape): from tvm.relax.op import reshape, split - bsz, q_len, _ = hidden_states.struct_info.shape - if self.combine_matmul: qkv_states = nn.emit( split( @@ -300,24 +322,35 @@ def forward( value_states = self.v_proj(hidden_states) query_states = nn.emit( - reshape( - query_states, - (bsz, q_len, self.num_query_heads, self.head_dim), - ), + reshape(query_states, query_output_shape), ) key_states = nn.emit( - reshape( - key_states, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), + reshape(key_states, kv_output_shape), ) value_states = nn.emit( - reshape( - value_states, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), + reshape(value_states, kv_output_shape), ) + return query_states, key_states, value_states + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: + bsz, q_len, _ = hidden_states.struct_info.shape + + query_states, key_states, value_states = self.project_qkv( + hidden_states, + (bsz, q_len, self.num_query_heads, self.head_dim), + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ) + + from tvm.relax.op import reshape + attn_output, past_key_values = self.attention_fwd( query_states, key_states, @@ -541,6 +574,29 @@ def __init__(self, config: LlamaConfig, enable_batching: bool): config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps ) + def post_self_attn(self, hidden_states, residual): + if self.self_attn.num_shards > 1: + 
residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + return hidden_states + def forward( self, hidden_states: relax.Expr, @@ -561,25 +617,7 @@ def forward( all_seq_len_shape=all_seq_len_shape, layer_id=layer_id, ) - if self.self_attn.num_shards > 1: - residual = nn.emit( - residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.self_attn.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + hidden_states = self.post_self_attn(hidden_states, residual) return hidden_states, present_key_value @@ -1164,6 +1202,91 @@ def kv_cache_transpose_append( bb.add_func(relax.extern("attention_func"), "attention") +def setup_params(mod, param_manager, dtype, config, args): + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.get_num_key_value_heads() + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + qkv = np.concatenate([q, k, v], axis=0).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + gate, up = torch_params + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + device = tvm.cpu() + param_list = [None] * param_manager.nparam_to_load + + head_dim = config.hidden_size / config.num_attention_heads + inv_freq = 1.0 / ( + config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) + ) + + # The following cos/sin values can be removed but **are kept for compatibility issues**. + t = np.arange(2048, dtype=inv_freq.dtype) + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate((freqs, freqs), axis=-1) + param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) + param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) + + return mod, param_manager, param_list, config + + def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype @@ -1174,7 +1297,7 @@ def get_model(args, hf_config): raise ValueError("`sep_embed` is required when batching is enabled.") position_embedding_base = 10000 - max_position_embeddings = 2048 + if "rope_theta" in hf_config: position_embedding_base = hf_config["rope_theta"] @@ -1249,85 +1372,4 @@ def get_model(args, hf_config): if args.build_model_only: return mod, param_manager, None, config - def f_convert_pname_fwd(pname: str) -> List[str]: - if not config.combine_matmul: - return [pname] - - qkv_str = "query_key_value_proj" - gate_up_str = "gate_up_proj" - if qkv_str in pname: - return [ - pname.replace(qkv_str, "q_proj"), - pname.replace(qkv_str, "k_proj"), - pname.replace(qkv_str, "v_proj"), - ] - elif gate_up_str in pname: - return [ - pname.replace(gate_up_str, "gate_proj"), - pname.replace(gate_up_str, "up_proj"), - ] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - if not config.combine_matmul: - return [(torch_pname, torch_param.astype(dtype))] - - combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] - if any([name in torch_pname for name in combined_layers]): - return None - return [(torch_pname, torch_param.astype(dtype))] - - def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): - # Expected to enter this function only for the combined linear matmul weights. - # Other weights are supposed to be loaded in `f_convert_param_bkwd` since - # each other relax param has a unique corresponding torch param. - if not config.combine_matmul: - # When matmul combination is not turned on, each relax param has a unique - # corresponding torch param, and this function is not expected to be entered. 
- raise NotImplementedError( - "Matmul combination is not turned on, and the function " - "is not expected to be entered" - ) - hidden_size = config.hidden_size - head_dim = config.hidden_size // config.num_attention_heads - - if "query_key_value_proj" in relax_pname: - q_heads = config.num_attention_heads - kv_heads = config.get_num_key_value_heads() - q, k, v = torch_params - assert q.shape == (q_heads * head_dim, hidden_size) - assert k.shape == (kv_heads * head_dim, hidden_size) - assert v.shape == (kv_heads * head_dim, hidden_size) - qkv = np.concatenate([q, k, v], axis=0).astype(dtype) - return qkv - if "gate_up_proj" in relax_pname: - gate, up = torch_params - gate_up = np.concatenate([gate, up], axis=0).astype(dtype) - return gate_up - raise ValueError("Unexpected param loading") - - param_manager.set_param_loading_func( - args.model_path, - args.use_safetensors, - f_convert_pname_fwd, - f_convert_param_bkwd, - f_compute_relax_param, - ) - - device = tvm.cpu() - param_list = [None] * param_manager.nparam_to_load - - head_dim = config.hidden_size / config.num_attention_heads - inv_freq = 1.0 / ( - config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) - ) - - # The following cos/sin values can be removed but **are kept for compatibility issues**. - t = np.arange(2048, dtype=inv_freq.dtype) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) - param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) - - return mod, param_manager, param_list, config + return setup_params(mod, param_manager, dtype, config, args) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py new file mode 100644 index 0000000000..2309bdd92e --- /dev/null +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -0,0 +1,661 @@ +from typing import Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, repeat, take +from tvm.relax.op.nn import attention_var_len +from tvm.relax.testing import nn +from tvm.ir import VDevice +from tvm.script import relax as R +from tvm.script.ir_builder import tir as T + +from ..quantization import QuantizationScheme +from .modules import ModuleList +from .param_manager import ParamManager +from .llama import ( + LlamaConfig, + Linear, + Embedding, + LlamaRMSNorm, + LlamaAttentionBase, + LlamaDecoderLayer, + get_param_quant_kind, + setup_params, + rotary_modulate_by_freq, +) + + +def apply_rotary_pos_emb(q, k, positions, position_embedding_base): + def f_rotary_embedding(tensor, pos_tensor): + def rotary_compute(*idx): + pos = pos_tensor[idx[0]].astype("float32") + return rotary_modulate_by_freq( + tensor, + idx, + pos, + position_embedding_base, + ) + + return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") + + q_embed = nn.emit_te(f_rotary_embedding, q, positions, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, positions, primfunc_name_hint="rotary_embedding") + return q_embed, k_embed + + +class LlamaAttentionBatched(LlamaAttentionBase): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config) + self.head_mapping = head_mapping # (num_heads,), used by vLLM for multi-query attention + self.sliding_window = None + + if config.sliding_window: + self.sliding_window = T.IntImm("int32", config.sliding_window) 
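    # Illustrative note, not part of this patch: head_mapping tells the vLLM decode
    # kernel which KV head each query head reads from under grouped-query attention.
    # It is constructed in LlamaModel.__init__ below as
    #     np.repeat(np.arange(num_key_value_heads, dtype="int32"), num_queries_per_kv)
    # e.g. 8 query heads sharing 2 KV heads gives head_mapping = [0, 0, 0, 0, 1, 1, 1, 1].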
+ + def forward( + self, + hidden_states: relax.Expr, # (num_token, hidden_size) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], + slot_mapping: Optional[relax.Expr], # (num_token,) + max_seqlen: Optional[relax.Expr], # (), must be on CPU + seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention + ): + num_tokens, _ = hidden_states.struct_info.shape + + queries, keys, values = self.project_qkv( + hidden_states, + (num_tokens, self.num_query_heads, self.head_dim), + (num_tokens, self.num_key_value_heads, self.head_dim), + ) + + queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) + + if kv_cache: + # Paged KV cache update + k_cache, v_cache = kv_cache + + if self.sliding_window is None or block_tables: + # For decode or prefill without sliding window, cache all keys / values. + keys_to_cache = keys + values_to_cache = values + else: + # Cache only the most recent keys and values within the window. + keys_to_cache = nn.emit(take(keys, indices_within_window, axis=0)) + values_to_cache = nn.emit(take(values, indices_within_window, axis=0)) + slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0)) + + # kv caches are updated inplace, but make it look like a pure operation + kv = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reshape_and_cache", + keys_to_cache, + values_to_cache, + k_cache, + v_cache, + slot_mapping, + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) + ) + + k_cache, v_cache = kv[0], kv[1] + else: + k_cache = v_cache = None + + if seqstart: + # Prefill, batched attention over variable sequence lengths + attn_output = nn.emit( + attention_var_len( + nn.emit(expand_dims(queries, axis=0)), + nn.emit(expand_dims(keys, axis=0)), + nn.emit(expand_dims(values, axis=0)), + seqstart_q=seqstart, + max_seqlen_q=max_seqlen, + causal_mask="BottomRight", + window_size=self.sliding_window, + ) + ) + else: + # Decode, using vLLM kernel + attn_output = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [ + queries, + k_cache, + v_cache, + self.head_mapping, + block_tables, + seq_lens, + 16, # block_size + max_seqlen, + ], + out_sinfo=queries.struct_info, + ) + ) + + attn_output = nn.emit( + reshape(attn_output, (num_tokens, self.num_query_heads * self.head_dim)) + ) + attn_output = self.o_proj(attn_output) + + return attn_output, (k_cache, v_cache) + + +class LlamaDecoderLayerBatched(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config, False) + self.self_attn = LlamaAttentionBatched(config, head_mapping) + + def forward( + self, + hidden_states: relax.Expr, + positions: relax.Expr, + seq_lens: relax.Expr, + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], + slot_mapping: Optional[relax.Expr], + max_seqlen: Optional[relax.Expr], + seqstart: Optional[relax.Expr], + block_tables: Optional[relax.Expr], + indices_within_window: Optional[relax.Expr], + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, new_kv = self.self_attn( + hidden_states=hidden_states, + 
positions=positions, + seq_lens=seq_lens, + kv_cache=kv_cache, + slot_mapping=slot_mapping, + max_seqlen=max_seqlen, + seqstart=seqstart, + block_tables=block_tables, + indices_within_window=indices_within_window, + ) + + hidden_states = self.post_self_attn(hidden_states, residual) + + return hidden_states, new_kv + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + cpu_device: VDevice, + vocab_size_var: tvm.tir.Var, + sep_embed: bool = False, + ): + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + num_query_heads = config.num_attention_heads // config.num_shards + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + num_queries_per_kv = num_query_heads // num_key_value_heads + head_mapping = relax.const( + tvm.nd.array( + np.repeat(np.arange(num_key_value_heads, dtype="int32"), num_queries_per_kv) + ) + ) + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [ + LlamaDecoderLayerBatched(config, head_mapping) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + + self.cpu_device = cpu_device + + def forward( + self, + inputs: relax.Expr, + positions: relax.Expr, + seq_lens: relax.Expr, + kv_caches: Optional[relax.Expr], + slot_mapping: Optional[relax.Expr], + seqstart: Optional[relax.Expr], + block_tables: Optional[relax.Expr], + indices_within_window: Optional[relax.Expr], + ): + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + + hidden_states = inputs_embeds + + # max_seqlen needs to be on CPU, so that vLLM and Flash Attention can directly get the + # integer length by max_seqlen->data[0]. Otherwise, we need to repeatedly do cudaMemcpy + # of a single int32. + max_seqlen = R.to_vdevice(R.max(seq_lens), self.cpu_device) + + new_kvs = () + + for idx, decoder_layer in enumerate(self.layers): + if kv_caches: + cache = (kv_caches[2 * idx], kv_caches[2 * idx + 1]) + else: + cache = None + + hidden_states, new_kv = decoder_layer( + hidden_states, + positions, + seq_lens, + cache, + slot_mapping, + max_seqlen, + seqstart, + block_tables, + indices_within_window, + ) + new_kvs += new_kv + + return self.norm(hidden_states), new_kvs + + +class LlamaForCausalLM(nn.Module): + def __init__( + self, + config: LlamaConfig, + cpu_device: VDevice, + vocab_size_var: tvm.tir.Var, + sep_embed: bool = False, + ): + self.num_shards = config.num_shards + self.model = LlamaModel(config, cpu_device, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + ############ Rotary embedding constants ############ + assert config.hidden_size % config.num_attention_heads == 0 + head_dim = config.hidden_size // config.num_attention_heads + + # Set the cached sin/cos to the maximum of 2048 and max seq len. + # This will be eliminated further with online rotary embedding calculation. 
+ cache_len = te.var("cache_len", "int64") + self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached") + self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached") + ############ End ############ + + def forward( + self, + input_ids: relax.Expr, # (num_token,) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) + kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate + slot_mapping: Optional[ + relax.Expr + ], # (num_token,), for prefill and decode, not needed for evaluate + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention + ): + """ + In vLLM, the paged KV cache is simply a pair of tensors, one for keys and the other + for values. The tensor has shape (num_blocks, num_kv_heads, head_size, block_size). + (In practice, the key cache has a slightly different shape for an efficiency reason, + but that's not important.) + + The mapping between sequences / tokens to blocks is specified by two inputs. + - block_tables: A list of block IDs allocated for the sequence. + - slot_mapping: A linear index into the 2D grid (num_blocks, block_size), for each token. + + Support for sliding-window attention is realized by making a block table a circular buffer. + So the length of a block table for each sequence is at most ceil(window_size / block_size). + + With sliding window, not all past K / V values need to be cached during prefill. + The last input, indices_within_window, tells which tokens among (num_token,) need to have + their K / V values cached. + """ + if self.num_shards > 1: + input_ids = nn.emit(ccl.broadcast_from_worker0(input_ids)) + positions = nn.emit(ccl.broadcast_from_worker0(positions)) + seq_lens = nn.emit(ccl.broadcast_from_worker0(seq_lens)) + + if slot_mapping: + slot_mapping = nn.emit(ccl.broadcast_from_worker0(slot_mapping)) + + if block_tables: + block_tables = nn.emit(ccl.broadcast_from_worker0(block_tables)) + + if indices_within_window: + indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) + + is_prompt = block_tables is None + + if is_prompt: # prefill and evaluate + # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust + cumsum = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + ) + ) + seqstart = nn.emit(concat([zeros((1,), "int32"), cumsum])) + else: + seqstart = None + + hidden_states, new_kvs = self.model( + input_ids, + positions, + seq_lens, + kv_caches, + slot_mapping, + seqstart, + block_tables, + indices_within_window, + ) + + if is_prompt: + # Extract logits for the last token in each sequence + + def get_logits_last_tokens(x, seq_len_tensor, seqstart): + return te.compute( + shape=(seq_len_tensor.shape[0], x.shape[-1]), + fcompute=lambda i, j: x[seqstart[i] + seq_len_tensor[i] - 1, j], + name="get_logits_last_tokens", + ) + + logits = self.lm_head( + nn.emit_te( + get_logits_last_tokens, + hidden_states, + seq_lens, + seqstart, + primfunc_name_hint="get_logits_last_tokens", + ) + ) + else: + logits = self.lm_head(hidden_states) + + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, new_kvs + + +def get_inputs( + num_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True 
+): + hidden_size = config.hidden_size + + inputs = ( + nn.Placeholder((num_token, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((num_token,), dtype="int32", name="input_ids") + ) + + seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") + positions = nn.Placeholder((num_token,), dtype="int32", name="positions") + + if need_cache: + num_blocks = tvm.tir.Var("num_blocks", "int64") + block_size = 16 + + vec_size = 8 # 128 bit, fp16 x 8 + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + head_size = hidden_size // config.num_attention_heads + + k_cache_shape = ( + num_blocks, + num_key_value_heads, + head_size // vec_size, + block_size, + vec_size, + ) + v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) + + get_cache_sinfo = lambda i: relax.TensorStructInfo( + k_cache_shape if i % 2 == 0 else v_cache_shape, dtype="float16" + ) + + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [get_cache_sinfo(i) for i in range(config.num_hidden_layers * 2)] + ), + ) + slot_mapping = nn.Placeholder((num_token,), dtype="int32", name="slot_mapping") + else: + past_key_values = None + slot_mapping = None + block_tables = None + + if max_num_blocks_per_seq is None: + block_tables = None + else: + block_tables = nn.Placeholder( + (num_seq, max_num_blocks_per_seq), dtype="int32", name="block_tables" + ) + + return inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables + + +def create_evaluate_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" + func_name = "evaluate" + + num_token = tvm.tir.Var("num_token", "int64") + num_seq = tvm.tir.Var("num_seq", "int64") + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs, positions, seq_lens, _, _, _ = get_inputs( + num_token, num_seq, config, sep_embed=sep_embed + ) + + with bb.dataflow(): + logits, _ = model( + inputs, + positions, + seq_lens, + kv_caches=None, + slot_mapping=None, + block_tables=None, + indices_within_window=None, + ) + params = [ + inputs, + positions, + seq_lens, + ] + model.parameters() + gv = bb.emit_output(logits) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + """Batched prefill with vLLM paged KV cache. + + The batched attention op is intended to be offloaded to CUTLASS or Flash Attention + via BYOC. 
+ """ + func_name = "prefill_with_embed" if sep_embed else "prefill" + + num_token = tvm.tir.Var("num_token", "int64") + num_seq = tvm.tir.Var("num_seq", "int64") + + num_inputs = 5 + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( + num_token, num_seq, config, sep_embed=sep_embed + ) + + with bb.dataflow(): + params = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + ] + + inputs = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + None, # block_tables + ] + + if config.sliding_window: + num_inputs += 1 + # The value of num_cached_total is between + # num_token (if seq_len < sliding_window for all seq) and + # num_seq * config.sliding_window (if seq_len > sliding_window for all seq) + num_cached_total = tvm.tir.Var("num_cached_total", "int64") + indices_within_window = nn.Placeholder( + (num_cached_total,), dtype="int32", name="indices_within_window" + ) + inputs.append(indices_within_window) + params.append(indices_within_window) + else: + inputs.append(None) + + logits, new_kvs = model(*inputs) + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + + bb.emit_func_output(gv, params + model.parameters()) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", num_inputs)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, +) -> None: + """Batched decoding with vLLM paged KV cache.""" + func_name = "decode" + + num_seq = tvm.tir.Var("num_seq", "int64") + max_num_blocks_per_seq = tvm.tir.Var("max_num_blocks_per_seq", "int64") + + with bb.function(func_name): + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( + num_seq, num_seq, config, max_num_blocks_per_seq + ) + + with bb.dataflow(): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + logits, new_kvs = model( + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables, None + ) + params = [ + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 6)) + + +def get_model(args, hf_config): + dtype = args.quantization.model_dtype + sep_embed = False + + position_embedding_base = 10000 + + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. 
+    if "max_sequence_length" in hf_config:
+        config = LlamaConfig(
+            **hf_config,
+            dtype=dtype,
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
+    elif "max_position_embeddings" in hf_config:
+        config = LlamaConfig(
+            **hf_config,
+            dtype=dtype,
+            max_sequence_length=hf_config["max_position_embeddings"],
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
+    else:
+        raise Exception(
+            "The model config should contain information about maximum sequence length."
+        )
+
+    # If there is a user-provided maximum sequence length, override hf config.
+    if args.max_seq_len != -1:
+        config.max_sequence_length = args.max_seq_len
+
+    param_manager = ParamManager()
+    bb = relax.BlockBuilder()
+
+    # The CPU device to which the result of relax.op.max(seq_lens) is copied.
+    cpu_dev = VDevice("llvm", 0, "global")
+
+    create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
+    create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed)
+    create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization)
+
+    mod = bb.get()
+
+    mod.update_global_info("vdevice", [cpu_dev])
+
+    if args.build_model_only:
+        return mod, param_manager, None, config
+
+    return setup_params(mod, param_manager, dtype, config, args)
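
The following is a minimal, self-contained sketch (not part of the patch) of how a token
position is mapped to a slot in the vLLM paged KV cache. It mirrors CacheManager.set_size
in examples/python/run_llama_batched_vllm.py and the block_tables / slot_mapping
description in LlamaForCausalLM.forward, including the circular block table used for
sliding-window attention.

BLOCK_SIZE = 16  # same default block size as CacheManager.block_size


def slot_for_position(pos, block_table, block_sliding_window=None):
    """Linear slot index for the token at position `pos` of one sequence.

    `block_table` lists the physical cache blocks owned by the sequence;
    the slot is `block_number * BLOCK_SIZE + block_offset`, exactly as
    computed in CacheManager.set_size.
    """
    block_idx = pos // BLOCK_SIZE
    if block_sliding_window:
        # With sliding-window attention the block table acts as a circular
        # buffer of at most ceil(sliding_window / BLOCK_SIZE) entries.
        block_idx %= block_sliding_window
    block_number = block_table[block_idx]
    return block_number * BLOCK_SIZE + pos % BLOCK_SIZE


# A sequence that owns physical blocks [7, 3]:
assert slot_for_position(0, [7, 3]) == 112  # block 7, offset 0
assert slot_for_position(20, [7, 3]) == 52  # block 3, offset 4

# With sliding_window = 32 (block_sliding_window = 2), position 40 wraps around
# and overwrites block 7:
assert slot_for_position(40, [7, 3], block_sliding_window=2) == 120  # block 7, offset 8

The slot_mapping tensor passed to tvm.contrib.vllm.reshape_and_cache is this index computed
for every token in the batch, which is why the cache update can be expressed as a flat
scatter into the (num_blocks, block_size) grid.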