From e0ef4c6f61ba2682039f81a13c31fc1730c0f92e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 00:22:51 +0000 Subject: [PATCH 01/39] add new model for evaluating logits over multiple queries using KV cache --- mlc_llm/relax_model/llama_batched_vllm.py | 299 +++++++++++++++++----- 1 file changed, 238 insertions(+), 61 deletions(-) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 67e0e12f90..0d1ad13bcd 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -1,9 +1,11 @@ from typing import Optional, Tuple +from dataclasses import dataclass + import numpy as np import tvm from tvm import relax, te -from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, repeat, take +from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, take, concat from tvm.relax.op.nn import attention_var_len from tvm.relax.testing import nn from tvm.ir import VDevice @@ -45,6 +47,17 @@ def rotary_compute(*idx): return q_embed, k_embed +@dataclass +class EvaluateMultiQueryInput: + query_start: relax.Expr # (num_query_token + 1,) + max_query_len: relax.Expr # (), must be on CPU + # The followings are only needed for our naive implementation of multi-query eval + # with paged KV cache. They can be replaced with block_tables when a proper attention + # kernel becomes available. + past_slot_mapping: relax.Expr # (num_past_token,) + permute_indices_after_concat: relax.Expr # (num_past_token + num_query_token,) + + class LlamaAttentionBatched(LlamaAttentionBase): def __init__(self, config: LlamaConfig): super().__init__(config) @@ -58,24 +71,25 @@ def __init__(self, config: LlamaConfig): def forward( self, - hidden_states: relax.Expr, # (num_token, hidden_size) - positions: relax.Expr, # (num_token,), for batched RoPE + hidden_states: relax.Expr, # (num_query_token, hidden_size) + positions: relax.Expr, # (num_query_token,), for batched RoPE seq_lens: relax.Expr, # (num_seq,) kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], - slot_mapping: Optional[relax.Expr], # (num_token,) + slot_mapping: Optional[relax.Expr], # (num_query_token,) max_seqlen: Optional[relax.Expr], # (), must be on CPU - seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill + seq_start: Optional[relax.Expr], # (num_seq + 1,), for prefill block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode indices_within_window: Optional[ relax.Expr - ], # (num_cached_total,), for prefill with sliding-window attention + ], # (num_cached_total,), for prefill with sliding-window attention, + eval_multi_input: Optional[EvaluateMultiQueryInput], ): - num_tokens, _ = hidden_states.struct_info.shape + num_query_tokens, _ = hidden_states.struct_info.shape queries, keys, values = self.project_qkv( hidden_states, - (num_tokens, self.num_query_heads, self.head_dim), - (num_tokens, self.num_key_value_heads, self.head_dim), + (num_query_tokens, self.num_query_heads, self.head_dim), + (num_query_tokens, self.num_key_value_heads, self.head_dim), ) queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) @@ -84,15 +98,15 @@ def forward( # Paged KV cache update k_cache, v_cache = kv_cache - if self.sliding_window is None or block_tables: - # For decode or prefill without sliding window, cache all keys / values. - keys_to_cache = keys - values_to_cache = values - else: + if indices_within_window: # Cache only the most recent keys and values within the window. 
keys_to_cache = nn.emit(take(keys, indices_within_window, axis=0)) values_to_cache = nn.emit(take(values, indices_within_window, axis=0)) slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0)) + else: + # For decode or prefill without sliding window, cache all keys / values. + keys_to_cache = keys + values_to_cache = values # kv caches are updated inplace, but make it look like a pure operation kv = nn.emit( @@ -111,15 +125,65 @@ def forward( else: k_cache = v_cache = None - if seqstart: - # Prefill, batched attention over variable sequence lengths + if eval_multi_input: + assert k_cache and v_cache + num_kv_head = v_cache.struct_info.shape[1] + head_size = v_cache.struct_info.shape[2] + num_past_token = eval_multi_input.past_slot_mapping.struct_info.shape[0] + kv_shape = (num_past_token, num_kv_head, head_size) + kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) + + kv_tensors = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reconstruct_from_cache", + k_cache, + v_cache, + eval_multi_input.past_slot_mapping, + sinfo_args=[kv_sinfo, kv_sinfo], + ) + ) + keys_past, values_past = kv_tensors[0], kv_tensors[1] + # Say we have past tokens [P1, P2, P3] and the current ones [C1, C2, C3]. + # Each of P1, C1 etc is a sequence of tokens. + # After concat, we have [P1, P2, P3, C1, C2, C3], but batched sequences need to + # be in the format [P1, C1, P2, C2, P3, C3]. This permutation is done by the take + # op and the provided permutation indices. + keys = nn.emit( + take( + concat([keys_past, keys]), eval_multi_input.permute_indices_after_concat, axis=0 + ) + ) + values = nn.emit( + take( + concat([values_past, values]), + eval_multi_input.permute_indices_after_concat, + axis=0, + ) + ) + seq_start_q = eval_multi_input.query_start + max_seqlen_q = eval_multi_input.max_query_len + seq_start_k = seq_start + max_seqlen_k = max_seqlen + elif seq_start: + # prefill + seq_start_q = seq_start_k = seq_start + max_seqlen_q = max_seqlen_k = max_seqlen + else: + # decode + seq_start_q = seq_start_k = None + max_seqlen_q = max_seqlen_k = None + + if seq_start_q: + # Prefill or multi-query evaluation, batched attention over variable sequence lengths attn_output = nn.emit( attention_var_len( nn.emit(expand_dims(queries, axis=0)), nn.emit(expand_dims(keys, axis=0)), nn.emit(expand_dims(values, axis=0)), - seqstart_q=seqstart, - max_seqlen_q=max_seqlen, + seq_start_q, + max_seqlen_q, + seq_start_k, + max_seqlen_k, causal_mask="BottomRight", window_size=self.sliding_window, ) @@ -128,14 +192,14 @@ def forward( # Decode, using vLLM kernel exp_sums = nn.emit( relax.op.builtin.alloc_tensor( - relax.ShapeExpr((num_tokens, self.num_query_heads, self.max_num_partitions)), + relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)), dtype="float32", runtime_device_index=0, ) ) max_logits = nn.emit( relax.op.builtin.alloc_tensor( - relax.ShapeExpr((num_tokens, self.num_query_heads, self.max_num_partitions)), + relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)), dtype="float32", runtime_device_index=0, ) @@ -143,7 +207,7 @@ def forward( tmp_out = nn.emit( relax.op.builtin.alloc_tensor( relax.ShapeExpr( - (num_tokens, self.num_query_heads, self.max_num_partitions, self.head_dim) + (num_query_tokens, self.num_query_heads, self.max_num_partitions, self.head_dim) ), dtype=queries.struct_info.dtype, runtime_device_index=0, @@ -169,7 +233,7 @@ def forward( ) attn_output = nn.emit( - reshape(attn_output, (num_tokens, 
self.num_query_heads * self.head_dim)) + reshape(attn_output, (num_query_tokens, self.num_query_heads * self.head_dim)) ) attn_output = self.o_proj(attn_output) @@ -189,9 +253,10 @@ def forward( kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], slot_mapping: Optional[relax.Expr], max_seqlen: Optional[relax.Expr], - seqstart: Optional[relax.Expr], + seq_start: Optional[relax.Expr], block_tables: Optional[relax.Expr], indices_within_window: Optional[relax.Expr], + eval_multi_input: Optional[EvaluateMultiQueryInput], ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: residual = hidden_states @@ -205,9 +270,10 @@ def forward( kv_cache=kv_cache, slot_mapping=slot_mapping, max_seqlen=max_seqlen, - seqstart=seqstart, + seq_start=seq_start, block_tables=block_tables, indices_within_window=indices_within_window, + eval_multi_input=eval_multi_input, ) hidden_states = self.post_self_attn(hidden_states, residual) @@ -215,12 +281,22 @@ def forward( return hidden_states, new_kv +def create_seq_start(seq_lens): + # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust + cumsum = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + ) + ) + return nn.emit(concat([zeros((1,), "int32"), cumsum])) + + class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, cpu_device: VDevice, - vocab_size_var: tvm.tir.SizeVar, + vocab_size_var: tvm.tir.Var, sep_embed: bool = False, ): self.padding_idx = config.pad_token_id @@ -247,9 +323,12 @@ def forward( seq_lens: relax.Expr, kv_caches: Optional[relax.Expr], slot_mapping: Optional[relax.Expr], - seqstart: Optional[relax.Expr], + seq_start: Optional[relax.Expr], block_tables: Optional[relax.Expr], indices_within_window: Optional[relax.Expr], + query_lens: Optional[relax.Expr], + past_slot_mapping: Optional[relax.Expr], + permute_indices_after_concat: Optional[relax.Expr], ): if self.embed_tokens: inputs_embeds = self.embed_tokens(inputs) @@ -265,6 +344,15 @@ def forward( new_kvs = () + if query_lens: + max_query_len = R.to_vdevice(R.max(query_lens), self.cpu_device) + query_start = create_seq_start(query_lens) + eval_multi_input = EvaluateMultiQueryInput( + query_start, max_query_len, past_slot_mapping, permute_indices_after_concat + ) + else: + eval_multi_input = None + for idx, decoder_layer in enumerate(self.layers): if kv_caches: cache = (kv_caches[2 * idx], kv_caches[2 * idx + 1]) @@ -278,9 +366,10 @@ def forward( cache, slot_mapping, max_seqlen, - seqstart, + seq_start, block_tables, indices_within_window, + eval_multi_input, ) new_kvs += new_kv @@ -312,17 +401,18 @@ def __init__( def forward( self, - input_ids: relax.Expr, # (num_token,) - positions: relax.Expr, # (num_token,), for batched RoPE + input_ids: relax.Expr, # (num_query_token,) + positions: relax.Expr, # (num_query_token,), for batched RoPE seq_lens: relax.Expr, # (num_seq,) kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate - slot_mapping: Optional[ - relax.Expr - ], # (num_token,), for prefill and decode, not needed for evaluate + slot_mapping: Optional[relax.Expr], # (num_query_token,), Not needed for evaluate block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode indices_within_window: Optional[ relax.Expr ], # (num_cached_total,), for prefill with sliding-window attention + query_lens: Optional[relax.Expr], + past_slot_mapping: Optional[relax.Expr], + permute_indices_after_concat: Optional[relax.Expr], ): """ In vLLM, the paged KV 
cache is simply a pair of tensors, one for keys and the other @@ -338,7 +428,7 @@ def forward( So the length of a block table for each sequence is at most ceil(window_size / block_size). With sliding window, not all past K / V values need to be cached during prefill. - The last input, indices_within_window, tells which tokens among (num_token,) need to have + The last input, indices_within_window, tells which tokens among (num_query_token,) need to have their K / V values cached. """ if self.num_shards > 1: @@ -355,18 +445,21 @@ def forward( if indices_within_window: indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) - is_prompt = block_tables is None - - if is_prompt: # prefill and evaluate - # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust - cumsum = nn.emit( - relax.op.call_dps_packed( - "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + if query_lens: + query_lens = nn.emit(ccl.broadcast_from_worker0(query_lens)) + past_slot_mapping = nn.emit(ccl.broadcast_from_worker0(past_slot_mapping)) + permute_indices_after_concat = nn.emit( + ccl.broadcast_from_worker0(permute_indices_after_concat) ) - ) - seqstart = nn.emit(concat([zeros((1,), "int32"), cumsum])) + + # TODO: Update this condition for evaluate multi + is_prompt = block_tables is None and query_lens is None + is_eval_multi = query_lens is not None + + if is_prompt or is_eval_multi: # prefill and evaluate + seq_start = create_seq_start(seq_lens) else: - seqstart = None + seq_start = None hidden_states, new_kvs = self.model( input_ids, @@ -374,18 +467,21 @@ def forward( seq_lens, kv_caches, slot_mapping, - seqstart, + seq_start, block_tables, indices_within_window, + query_lens, + past_slot_mapping, + permute_indices_after_concat, ) if is_prompt: # Extract logits for the last token in each sequence - def get_logits_last_tokens(x, seq_len_tensor, seqstart): + def get_logits_last_tokens(x, seq_len_tensor, seq_start): return te.compute( shape=(seq_len_tensor.shape[0], x.shape[-1]), - fcompute=lambda i, j: x[seqstart[i] + seq_len_tensor[i] - 1, j], + fcompute=lambda i, j: x[seq_start[i] + seq_len_tensor[i] - 1, j], name="get_logits_last_tokens", ) @@ -394,7 +490,7 @@ def get_logits_last_tokens(x, seq_len_tensor, seqstart): get_logits_last_tokens, hidden_states, seq_lens, - seqstart, + seq_start, primfunc_name_hint="get_logits_last_tokens", ) ) @@ -408,21 +504,21 @@ def get_logits_last_tokens(x, seq_len_tensor, seqstart): def get_inputs( - num_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True + num_query_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True ): hidden_size = config.hidden_size inputs = ( - nn.Placeholder((num_token, hidden_size), dtype=config.dtype, name="inputs_embeds") + nn.Placeholder((num_query_token, hidden_size), dtype=config.dtype, name="inputs_embeds") if sep_embed - else nn.Placeholder((num_token,), dtype="int32", name="input_ids") + else nn.Placeholder((num_query_token,), dtype="int32", name="input_ids") ) seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") - positions = nn.Placeholder((num_token,), dtype="int32", name="positions") + positions = nn.Placeholder((num_query_token,), dtype="int32", name="positions") if need_cache: - num_blocks = tvm.tir.SizeVar("num_blocks", "int64") + num_blocks = tvm.tir.Var("num_blocks", "int64") block_size = 16 vec_size = 8 # 128 bit, fp16 x 8 @@ -448,7 +544,7 @@ def get_inputs( [get_cache_sinfo(i) for i in 
range(config.num_hidden_layers * 2)] ), ) - slot_mapping = nn.Placeholder((num_token,), dtype="int32", name="slot_mapping") + slot_mapping = nn.Placeholder((num_query_token,), dtype="int32", name="slot_mapping") else: past_key_values = None slot_mapping = None @@ -475,15 +571,15 @@ def create_evaluate_func( """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" func_name = "evaluate" - num_token = tvm.tir.SizeVar("num_token", "int64") + num_query_token = tvm.tir.SizeVar("num_query_token", "int64") num_seq = tvm.tir.SizeVar("num_seq", "int64") with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs, positions, seq_lens, _, _, _ = get_inputs( - num_token, num_seq, config, sep_embed=sep_embed + num_query_token, num_seq, config, sep_embed=sep_embed ) with bb.dataflow(): @@ -495,6 +591,9 @@ def create_evaluate_func( slot_mapping=None, block_tables=None, indices_within_window=None, + query_lens=None, + past_slot_mapping=None, + permute_indices_after_concat=None, ) params = [ inputs, @@ -524,7 +623,7 @@ def create_encoding_func( """ func_name = "prefill_with_embed" if sep_embed else "prefill" - num_token = tvm.tir.SizeVar("num_token", "int64") + num_query_token = tvm.tir.SizeVar("num_query_token", "int64") num_seq = tvm.tir.SizeVar("num_seq", "int64") num_inputs = 5 @@ -534,7 +633,7 @@ def create_encoding_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( - num_token, num_seq, config, sep_embed=sep_embed + num_query_token, num_seq, config, sep_embed=sep_embed ) with bb.dataflow(): @@ -558,9 +657,9 @@ def create_encoding_func( if config.sliding_window: num_inputs += 1 # The value of num_cached_total is between - # num_token (if seq_len < sliding_window for all seq) and + # num_query_token (if seq_len < sliding_window for all seq) and # num_seq * config.sliding_window (if seq_len > sliding_window for all seq) - num_cached_total = tvm.tir.SizeVar("num_cached_total", "int64") + num_cached_total = tvm.tir.Var("num_cached_total", "int64") indices_within_window = nn.Placeholder( (num_cached_total,), dtype="int32", name="indices_within_window" ) @@ -569,6 +668,8 @@ def create_encoding_func( else: inputs.append(None) + inputs += [None, None, None] + logits, new_kvs = model(*inputs) gv = bb.emit_output((logits, relax.Tuple(new_kvs))) @@ -602,7 +703,16 @@ def create_decoding_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) logits, new_kvs = model( - inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables, None + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + None, + None, + None, + None, ) params = [ inputs, @@ -620,6 +730,72 @@ def create_decoding_func( bb.update_func(gv, mod[gv].with_attr("num_input", 6)) +def create_evaluate_multi_query_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "evaluate_multi_query" + + num_query_token = tvm.tir.SizeVar("num_query_token", "int64") + num_past_token = tvm.tir.SizeVar("num_past_token", "int64") + num_seq = tvm.tir.SizeVar("num_seq", "int64") + 
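    # Note: seq_lens_sum is assumed here to equal num_past_token + num_query_token, i.e. the
    # total length of the concatenated past/current K/V sequence that the
    # permute_indices_after_concat placeholder (declared below) indexes into.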
seq_lens_sum = tvm.tir.SizeVar("seq_lens_sum", "int64") + + num_inputs = 8 + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), False) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( + num_query_token, num_seq, config, sep_embed=False + ) + + query_lens = nn.Placeholder((num_seq,), dtype="int32", name="query_lens") + + # Replace them with block_tables when a proper attention kernel becomes available. + past_slot_mapping = nn.Placeholder( + (num_past_token,), dtype="int32", name="past_slot_mapping" + ) + permute_indices_after_concat = nn.Placeholder( + (seq_lens_sum,), dtype="int32", name="permute_indices_after_concat" + ) + + with bb.dataflow(): + params = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + ] + + inputs = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + None, # block_tables + None, # indices_within_window + ] + + inputs += [query_lens, past_slot_mapping, permute_indices_after_concat] + params += [query_lens, past_slot_mapping, permute_indices_after_concat] + + logits, new_kvs = model(*inputs) + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + + bb.emit_func_output(gv, params + model.parameters()) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", num_inputs)) + + def get_model(args, hf_config): dtype = args.quantization.model_dtype sep_embed = False @@ -685,6 +861,7 @@ def get_model(args, hf_config): create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) + create_evaluate_multi_query_func(bb, param_manager, config, cpu_dev, args.quantization) mod = bb.get() From 4ccbb27963d263158bb81c23f49cdfaf64770a29 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 00:24:08 +0000 Subject: [PATCH 02/39] add test --- examples/python/run_llama_batched_vllm.py | 130 ++++++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index a290eb892c..0307278d44 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -117,6 +117,13 @@ class SequenceGenerationResponse: token_id: int +@dataclass +class EvalQueryRequest: + request_id: int + past_token_ids: List[int] + query_token_ids: List[int] + + def sample(logits): logits = torch.from_dlpack(logits) return torch.argmax(logits, -1).cpu().numpy() @@ -241,6 +248,76 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]: ) +def _prepare_eval_queries( + requests: List[EvalQueryRequest], + all_slot_mappings, + sliding_window, + dev, +): + seq_lens = [] + query_lens = [] + input_ids = [] + slot_mapping = [] + past_slot_mapping = [] + positions = [] + permute_map = [] + + query_offset = sum([len(request.past_token_ids) for request in requests]) + past_offset = 0 + + for request in requests: + num_past_tokens = len(request.past_token_ids) + num_queries = len(request.query_token_ids) + query_lens.append(num_queries) + request_id = request.request_id + input_ids += request.query_token_ids + + positions += [num_past_tokens + i for i in range(num_queries)] + + if sliding_window: + 
seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) + # TODO: verify this + past_slot_mapping += all_slot_mappings[request_id][ + : min(num_past_tokens, sliding_window) + ] + slot_mapping += all_slot_mappings[request_id][ + min(num_past_tokens, sliding_window) : min(num_past_tokens, sliding_window) + + num_queries + ] + else: + seq_lens.append(num_past_tokens + num_queries) + past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] + slot_mapping += all_slot_mappings[request_id][ + num_past_tokens : num_past_tokens + num_queries + ] + + permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list( + range(query_offset, query_offset + num_queries) + ) + + query_offset += num_queries + past_offset += num_past_tokens + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev) + past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev) + permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev) + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) + + class Model: def __init__( self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window @@ -443,6 +520,59 @@ def run(args): for p, g in zip(prompts, generated): print("Prompt = '{}', generated text = '{}'".format(p, g)) + query_token_lens = [4, 3, 5, 2] + + eval_query_requests = [] + + for request_id, query_token_len in zip(request_ids, query_token_lens): + queries_to_eval = requests[request_id].token_ids[-query_token_len:] + past_tokens = requests[request_id].token_ids[:-query_token_len] + eval_query_requests.append(EvalQueryRequest(request_id, past_tokens, queries_to_eval)) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = _prepare_eval_queries( + eval_query_requests, + cache.slot_mappings, + None, + model.dev, + ) + + logits = model.mod["evaluate_multi_query"]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + model.params, + )[0].numpy() + + assert logits.shape[0] == sum(query_token_lens) + + logits_offset = 0 + + for request_id, query_token_len in zip(request_ids, query_token_lens): + for i in range(query_token_len - 1): + # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens. + # Doing argmax over multi-timestep logits computed in parallel should yield the same + # tokens at the corresponding positions. 
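            # The logits row at offset i was computed at position num_past + i, so its argmax
            # should predict the token stored at token_ids[len(past_tokens) + i + 1], which is
            # what the assertion below checks.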
+ past_tokens = requests[request_id].token_ids[:-query_token_len] + assert ( + np.argmax(logits[logits_offset + i]) + == requests[request_id].token_ids[len(past_tokens) + i + 1] + ) + + logits_offset += query_token_len + if __name__ == "__main__": run(parse_args()) From f1314a57dbb409158b3c7d88e2fc57af00e85941 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 00:31:17 +0000 Subject: [PATCH 03/39] clean --- examples/python/run_llama_batched_vllm.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 0307278d44..dc30a0cfa3 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -276,14 +276,9 @@ def _prepare_eval_queries( if sliding_window: seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) - # TODO: verify this - past_slot_mapping += all_slot_mappings[request_id][ - : min(num_past_tokens, sliding_window) - ] - slot_mapping += all_slot_mappings[request_id][ - min(num_past_tokens, sliding_window) : min(num_past_tokens, sliding_window) - + num_queries - ] + num_past = min(num_past_tokens, sliding_window) + past_slot_mapping += all_slot_mappings[request_id][num_past:] + slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries] else: seq_lens.append(num_past_tokens + num_queries) past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] From 2bee022917a1f586f029dcb4e7daf1a279550518 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 09:38:12 +0000 Subject: [PATCH 04/39] Only the number of past tokens is needed --- examples/python/run_llama_batched_vllm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index dc30a0cfa3..5cb7f52ae6 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -120,7 +120,7 @@ class SequenceGenerationResponse: @dataclass class EvalQueryRequest: request_id: int - past_token_ids: List[int] + num_past_tokens: int query_token_ids: List[int] @@ -262,11 +262,11 @@ def _prepare_eval_queries( positions = [] permute_map = [] - query_offset = sum([len(request.past_token_ids) for request in requests]) + query_offset = sum([request.num_past_tokens for request in requests]) past_offset = 0 for request in requests: - num_past_tokens = len(request.past_token_ids) + num_past_tokens = request.num_past_tokens num_queries = len(request.query_token_ids) query_lens.append(num_queries) request_id = request.request_id @@ -521,8 +521,8 @@ def run(args): for request_id, query_token_len in zip(request_ids, query_token_lens): queries_to_eval = requests[request_id].token_ids[-query_token_len:] - past_tokens = requests[request_id].token_ids[:-query_token_len] - eval_query_requests.append(EvalQueryRequest(request_id, past_tokens, queries_to_eval)) + num_past = len(requests[request_id].token_ids) - query_token_len + eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) ( input_ids, From 756b09f6bc54682c9221af62df1aaad2d2478c20 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 21:59:52 +0000 Subject: [PATCH 05/39] fix build --- mlc_llm/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 7a69562696..f7afbbb693 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -593,6 +593,7 @@ def 
mod_transform_before_build( # This is equivalent to prefill but without KV cache. It is used for # determining the number of paged cache blocks that can be allocated. model_names.append("evaluate") + model_names.append("evaluate_multi_query") if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] From 09ef5b3467fc6fcba8d23bb5754454e6eac250d2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 00:49:59 +0000 Subject: [PATCH 06/39] fix --- examples/python/run_llama_batched_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index 5cb7f52ae6..dcb16a878d 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -277,7 +277,7 @@ def _prepare_eval_queries( if sliding_window: seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) num_past = min(num_past_tokens, sliding_window) - past_slot_mapping += all_slot_mappings[request_id][num_past:] + past_slot_mapping += all_slot_mappings[request_id][:num_past] slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries] else: seq_lens.append(num_past_tokens + num_queries) From 7b67ba40c85e844dbce5c8d782d4d3d02b28d9a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 08:40:17 +0000 Subject: [PATCH 07/39] correctly handle num_past_tokens > sliding_window case --- examples/python/run_llama_batched_vllm.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index dcb16a878d..0a2c8f0b9c 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -274,17 +274,18 @@ def _prepare_eval_queries( positions += [num_past_tokens + i for i in range(num_queries)] - if sliding_window: - seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) - num_past = min(num_past_tokens, sliding_window) - past_slot_mapping += all_slot_mappings[request_id][:num_past] - slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries] + if sliding_window and num_past_tokens + num_queries >= sliding_window: + seq_lens.append(sliding_window) + past_slot_mapping += all_slot_mappings[request_id][ + num_past_tokens - (sliding_window - num_queries) : num_past_tokens + ] else: seq_lens.append(num_past_tokens + num_queries) past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] - slot_mapping += all_slot_mappings[request_id][ - num_past_tokens : num_past_tokens + num_queries - ] + + slot_mapping += all_slot_mappings[request_id][ + num_past_tokens : num_past_tokens + num_queries + ] permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list( range(query_offset, query_offset + num_queries) From e0517fd15220e11e31dd42abaa1fb300215a927c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 06:43:18 +0000 Subject: [PATCH 08/39] wip --- serve/mlc_serve/engine/model_module.py | 20 +- serve/mlc_serve/model/paged_cache_model.py | 314 ++++++++++++++------- 2 files changed, 233 insertions(+), 101 deletions(-) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 9b018c6cc4..b3bdb31064 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -28,11 +28,29 @@ class PrefillRequest: class DecodeRequest: sequence_id: SequenceId 
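    # Number of tokens in this sequence's prompt (token_ids below holds only decoded tokens).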
prompt_token_counts: int - # All tokens for this request, including prompt + # Decoded tokens for this sequence token_ids: List[int] sampling_params: SamplingParams +@dataclass +class DraftTokens: + token_ids: List[int] + + +@dataclass +class EvictedTokens: + token_ids: List[int] + + +@dataclass +class MultiQueryDecodeRequest: + sequence_id: SequenceId + past_token_ids: List[int] + queries: Union[DraftTokens, EvictedTokens] + sampling_params: SamplingParams + + @dataclass class TextGenerationResult: """ diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 64c291aa1c..5c1ffd475c 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -24,6 +24,7 @@ from ..engine.model_module import ( DecodeRequest, PrefillRequest, + MultiQueryDecodeRequest, TextGenerationResult, ) from ..engine.model_module import ModelModule @@ -277,6 +278,71 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]: ) +def prepare_multi_query_decode_inputs( + requests: List[MultiQueryDecodeRequest], + all_slot_mappings, + sliding_window, + dev, +): + seq_lens = [] + query_lens = [] + input_ids = [] + slot_mapping = [] + past_slot_mapping = [] + positions = [] + permute_map = [] + + query_offset = sum([len(request.past_token_ids) for request in requests]) + past_offset = 0 + + for request in requests: + num_past_tokens = len(request.past_token_ids) + num_queries = len(request.query_token_ids) + query_lens.append(num_queries) + request_id = request.request_id + input_ids += request.query_token_ids + + positions += [num_past_tokens + i for i in range(num_queries)] + + if sliding_window: + seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) + num_past = min(num_past_tokens, sliding_window) + past_slot_mapping += all_slot_mappings[request_id][num_past:] + slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries] + else: + seq_lens.append(num_past_tokens + num_queries) + past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] + slot_mapping += all_slot_mappings[request_id][ + num_past_tokens : num_past_tokens + num_queries + ] + + permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list( + range(query_offset, query_offset + num_queries) + ) + + query_offset += num_queries + past_offset += num_past_tokens + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev) + past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev) + permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev) + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) + + class Model: def __init__( self, @@ -352,32 +418,169 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() + def sample_from_logits(self, logits, sequence_ids, requests): + sampling_params = [req.sampling_params for req in requests] + + try: + next_tokens = sample(logits, sampling_params, self.vocab_size) + assert next_tokens is not None + outputs = [] + for i, (sequence_id, new_token) in enumerate( + zip(sequence_ids, next_tokens) + ): + if not new_token in sampling_params[i].appeared_tokens_freq: + 
requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + + for seq_id in range(requests[i].num_sequences): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) + ) + + return outputs + except RuntimeError: + # Fallback to per-token sampling in case some logits values are corrupted. + outputs = [] + err_msg = ( + "Error from sampling: probability tensor contains either `inf`, `nan`" + " or element < 0" + ) + + for i, (sequence_id, logits_per_token, sampling_param) in enumerate( + zip(sequence_ids, torch.from_dlpack(logits), sampling_params) + ): + maybe_new_token = sample( + torch.unsqueeze(logits_per_token, 0), + [sampling_param], + self.vocab_size, + check_safety=True, + ) + + if maybe_new_token is not None: + new_token = maybe_new_token[0] + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequences): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequences): + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[], + error=err_msg, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[], + error=err_msg, + ) + ) + + return outputs + + def generate_multi_query(self, requests:List[MultiQueryDecodeRequest], cache: KVCache) -> List[TextGenerationResult]: + sequence_ids = [] + for request in requests: + sequence_ids.append(request.sequence_id) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = prepare_multi_query_decode_inputs( + requests, + cache.slot_mappings, + None, + self.dev, + ) + + logits = self.mod["evaluate_multi_query"]( + input_ids, + positions, + seq_lens, + cache.cache_blocks, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + self.params, + )[0].numpy() + + return self.sample_from_logits(logits, sequence_ids, requests) + def generate( self, - requests: Union[List[PrefillRequest], List[DecodeRequest]], + requests: Union[List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest]], cache: KVCache, ) -> List[TextGenerationResult]: if len(requests) == 0: return [] is_prefill = isinstance(requests[0], PrefillRequest) + is_multi_query_decode = isinstance(requests[0], MultiQueryDecodeRequest) + if is_multi_query_decode: + return self.generate_multi_query(requests, cache) + + # Prefill or decode all_token_ids = [] - sampling_params = [] sequence_ids = [] 
prompt_lens = [] - num_sequences = [] for request in requests: if isinstance(request, PrefillRequest): sequence_ids.append(get_prompt_sequence_id(request.request_id)) - num_sequences.append(request.num_sequence) - else: + elif isinstance(request, DecodeRequest): sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) + assert not isinstance(request, MultiQueryDecodeRequest) all_token_ids.append(request.token_ids) - sampling_params.append(request.sampling_params) ( input_ids, @@ -478,99 +681,7 @@ def generate( self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) cache.pending_copy_from_to = [] - try: - next_tokens = sample(logits, sampling_params, self.vocab_size) - assert next_tokens is not None - outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) - ): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - ) - ) - - return outputs - except RuntimeError: - # Fallback to per-token sampling in case some logits values are corrupted. - outputs = [] - err_msg = ( - "Error from sampling: probability tensor contains either `inf`, `nan`" - " or element < 0" - ) - - for i, (sequence_id, logits_per_token, sampling_param) in enumerate( - zip(sequence_ids, torch.from_dlpack(logits), sampling_params) - ): - maybe_new_token = sample( - torch.unsqueeze(logits_per_token, 0), - [sampling_param], - self.vocab_size, - check_safety=True, - ) - - if maybe_new_token is not None: - new_token = maybe_new_token[0] - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[], - error=err_msg, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[], - error=err_msg, - ) - ) - - return outputs + return self.sample_from_logits(logits, sequence_ids, requests) def get_gpu_memory(gpu: int = 0) -> int: @@ -600,16 +711,19 @@ def __init__(self, model: Model): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest]], kv_cache + self, requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], kv_cache ) -> list[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in 
requests if isinstance(r, DecodeRequest)] + multi_query_decode_requests = [r for r in requests if isinstance(r, MultiQueryDecodeRequest)] out = [] if prefill_requests: out.extend(self.model.generate(prefill_requests, kv_cache)) if decode_requests: out.extend(self.model.generate(decode_requests, kv_cache)) + if multi_query_decode_requests: + out.extend(self.model.generate(multi_query_decode_requests, kv_cache)) return out From cf89a5bae21f1e48a623312658ea2d67fc649b68 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 06:44:00 +0000 Subject: [PATCH 09/39] blac --- serve/mlc_serve/model/paged_cache_model.py | 58 +++++++++++++++++----- 1 file changed, 45 insertions(+), 13 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 5c1ffd475c..fe067e1516 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -106,17 +106,33 @@ def _is_safe_to_sample(prob_like): # TODO(vvchernov): need to strictly define order of using penalties and logit bias or # prohibit simultaneous using of them. At the latter case it can be LogitProcessor - if (not param.presence_penalty == 0.0 or not param.frequency_penalty == 0) and bool(freq): - index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device) - src = torch.from_numpy(np.array(list(freq.values()))).type_as(logits).to(device=logits.device) - logits[i][index] -= src * param.frequency_penalty + param.presence_penalty + if ( + not param.presence_penalty == 0.0 or not param.frequency_penalty == 0 + ) and bool(freq): + index = torch.from_numpy(np.array(list(freq.keys()))).to( + device=logits.device + ) + src = ( + torch.from_numpy(np.array(list(freq.values()))) + .type_as(logits) + .to(device=logits.device) + ) + logits[i][index] -= ( + src * param.frequency_penalty + param.presence_penalty + ) if not param.repetition_penalty == 1.0 and bool(freq): - index = torch.from_numpy(np.array(list(freq.keys()))).to(device=logits.device) + index = torch.from_numpy(np.array(list(freq.keys()))).to( + device=logits.device + ) logits[i][index] /= param.repetition_penalty if param.logit_bias: - logits[i][param.logit_bias_index] += torch.Tensor(param.logit_bias_value).type_as(logits).to(device=logits.device) + logits[i][param.logit_bias_index] += ( + torch.Tensor(param.logit_bias_value) + .type_as(logits) + .to(device=logits.device) + ) logits_random = logits[mask_random] @@ -185,7 +201,10 @@ def get_tvm_model(config, dev): vm = relax.VirtualMachine(ex, dev) from tvm.contrib import tvmjs # pylint: disable=import-outside-toplevel - _params, _meta = tvmjs.load_ndarray_cache(f"{config.model_artifact_path}/params", dev) + + _params, _meta = tvmjs.load_ndarray_cache( + f"{config.model_artifact_path}/params", dev + ) params = [] for i in range(_meta["ParamSize"]): params.append(_params[f"param_{i}"]) @@ -308,7 +327,9 @@ def prepare_multi_query_decode_inputs( seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) num_past = min(num_past_tokens, sliding_window) past_slot_mapping += all_slot_mappings[request_id][num_past:] - slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries] + slot_mapping += all_slot_mappings[request_id][ + num_past : num_past + num_queries + ] else: seq_lens.append(num_past_tokens + num_queries) past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] @@ -472,7 +493,10 @@ def sample_from_logits(self, logits, sequence_ids, requests): if maybe_new_token is not None: 
new_token = maybe_new_token[0] - if not new_token in requests[i].sampling_params.appeared_tokens_freq: + if ( + not new_token + in requests[i].sampling_params.appeared_tokens_freq + ): requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: @@ -519,7 +543,9 @@ def sample_from_logits(self, logits, sequence_ids, requests): return outputs - def generate_multi_query(self, requests:List[MultiQueryDecodeRequest], cache: KVCache) -> List[TextGenerationResult]: + def generate_multi_query( + self, requests: List[MultiQueryDecodeRequest], cache: KVCache + ) -> List[TextGenerationResult]: sequence_ids = [] for request in requests: sequence_ids.append(request.sequence_id) @@ -555,7 +581,9 @@ def generate_multi_query(self, requests:List[MultiQueryDecodeRequest], cache: KV def generate( self, - requests: Union[List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest]], + requests: Union[ + List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest] + ], cache: KVCache, ) -> List[TextGenerationResult]: if len(requests) == 0: @@ -711,11 +739,15 @@ def __init__(self, model: Model): self.model = model def generate( - self, requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], kv_cache + self, + requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], + kv_cache, ) -> list[TextGenerationResult]: prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] decode_requests = [r for r in requests if isinstance(r, DecodeRequest)] - multi_query_decode_requests = [r for r in requests if isinstance(r, MultiQueryDecodeRequest)] + multi_query_decode_requests = [ + r for r in requests if isinstance(r, MultiQueryDecodeRequest) + ] out = [] if prefill_requests: From 9ca48064904aa4fb3ed48fe46f81690ec4890dbd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 07:08:23 +0000 Subject: [PATCH 10/39] wip --- serve/mlc_serve/engine/model_module.py | 8 ++++++++ serve/mlc_serve/model/paged_cache_model.py | 23 ++++++++++++++++++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index b3bdb31064..896d9d3ef1 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -37,11 +37,19 @@ class DecodeRequest: class DraftTokens: token_ids: List[int] + @property + def num_tokens(self): + return len(self.token_ids) + @dataclass class EvictedTokens: token_ids: List[int] + @property + def num_tokens(self): + return len(self.token_ids) + @dataclass class MultiQueryDecodeRequest: diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index fe067e1516..680e260a82 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -24,6 +24,7 @@ from ..engine.model_module import ( DecodeRequest, PrefillRequest, + DraftTokens, MultiQueryDecodeRequest, TextGenerationResult, ) @@ -440,6 +441,8 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() def sample_from_logits(self, logits, sequence_ids, requests): + assert logits.shape[0] == len(requests) + sampling_params = [req.sampling_params for req in requests] try: @@ -547,9 +550,18 @@ def generate_multi_query( self, requests: List[MultiQueryDecodeRequest], cache: KVCache ) -> List[TextGenerationResult]: sequence_ids = [] + 
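        # Offsets of the last token of each query span in the flattened logits; only these
        # rows are sampled from below (last_query_logits = logits[last_query_offsets]).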
last_query_offsets = [] for request in requests: + assert not isinstance(request.queries, DraftTokens) sequence_ids.append(request.sequence_id) + if len(last_query_offsets) == 0: + last_query_offsets.append(request.queries.num_tokens - 1) + else: + last_query_offsets.append( + last_query_offsets[-1] + request.queries.num_tokens + ) + ( input_ids, positions, @@ -565,6 +577,8 @@ def generate_multi_query( self.dev, ) + torch.cuda.nvtx.range_push(f"forward multi-query decode {input_ids.shape}") + logits = self.mod["evaluate_multi_query"]( input_ids, positions, @@ -577,7 +591,12 @@ def generate_multi_query( self.params, )[0].numpy() - return self.sample_from_logits(logits, sequence_ids, requests) + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + last_query_logits = logits[last_query_offsets] + + return self.sample_from_logits(last_query_logits, sequence_ids, requests) def generate( self, @@ -593,7 +612,7 @@ def generate( is_multi_query_decode = isinstance(requests[0], MultiQueryDecodeRequest) if is_multi_query_decode: - return self.generate_multi_query(requests, cache) + return self.generate_multi_query(requests, cache) # type: ignore # Prefill or decode all_token_ids = [] From 4541b4d78489c58cd87e8a203067e8be4852668d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 07:23:24 +0000 Subject: [PATCH 11/39] wip --- serve/mlc_serve/model/paged_cache_model.py | 29 +++++++++++++++------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 680e260a82..3d7019ad5c 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -317,24 +317,25 @@ def prepare_multi_query_decode_inputs( for request in requests: num_past_tokens = len(request.past_token_ids) - num_queries = len(request.query_token_ids) + num_queries = request.queries.num_tokens query_lens.append(num_queries) - request_id = request.request_id - input_ids += request.query_token_ids + input_ids += request.queries.token_ids positions += [num_past_tokens + i for i in range(num_queries)] if sliding_window: seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) num_past = min(num_past_tokens, sliding_window) - past_slot_mapping += all_slot_mappings[request_id][num_past:] - slot_mapping += all_slot_mappings[request_id][ + past_slot_mapping += all_slot_mappings[request.sequence_id][num_past:] + slot_mapping += all_slot_mappings[request.sequence_id][ num_past : num_past + num_queries ] else: seq_lens.append(num_past_tokens + num_queries) - past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] - slot_mapping += all_slot_mappings[request_id][ + past_slot_mapping += all_slot_mappings[request.sequence_id][ + :num_past_tokens + ] + slot_mapping += all_slot_mappings[request.sequence_id][ num_past_tokens : num_past_tokens + num_queries ] @@ -351,6 +352,8 @@ def prepare_multi_query_decode_inputs( slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev) + # TODO(masahi): These inputs need to be replaced by block_table when a proper attention kernel + # becomes available. 
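    # Illustration with hypothetical sizes: for two sequences with past lengths [3, 2] and
    # query lengths [2, 1], the concatenated K/V layout is [P0 P0 P0 P1 P1 | Q0 Q0 Q1] and
    # permute_map becomes [0, 1, 2, 5, 6, 3, 4, 7], which restores per-sequence order
    # [P0 P0 P0 Q0 Q0 P1 P1 Q1] after the take op inside the model.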
past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev) permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev) @@ -440,7 +443,14 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() - def sample_from_logits(self, logits, sequence_ids, requests): + def sample_from_logits( + self, + logits: Union[tvm.nd.NDArray, torch.Tensor], + sequence_ids: List[SequenceId], + requests: Union[ + List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest] + ], + ): assert logits.shape[0] == len(requests) sampling_params = [req.sampling_params for req in requests] @@ -457,7 +467,6 @@ def sample_from_logits(self, logits, sequence_ids, requests): requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequences): outputs.append( TextGenerationResult( @@ -577,6 +586,8 @@ def generate_multi_query( self.dev, ) + # TODO(masahi): Disco, sliding window + torch.cuda.nvtx.range_push(f"forward multi-query decode {input_ids.shape}") logits = self.mod["evaluate_multi_query"]( From 5d376d2932c4535e1e186c173083832936e6255c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 08:07:01 +0000 Subject: [PATCH 12/39] remove cancel call back in eviction --- serve/mlc_serve/engine/engine_common.py | 10 ++-------- serve/mlc_serve/engine/staging_engine_worker.py | 6 +----- serve/mlc_serve/engine/sync_engine.py | 6 +----- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 1bc252b48c..91f5bc45a6 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -320,7 +320,7 @@ def check_prompt_too_long(self, prompt_len: int, num_sequences: int = 1) -> bool < self.max_decode_steps * num_sequences ) - def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int: + def evict_request(self) -> int: # Must be called with the queue lock held num_eviction = 0 @@ -346,13 +346,7 @@ def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int: # TODO(masahi): Properly support evicting a multi-sequence request if self.current_batch[request_to_remove.request_id].num_sequences != 1: - cancell_callback(request_to_remove.request_id) - self.remove_request_from_batch(request_to_remove.request_id) - LOG.warn( - "Preempting a multi-sequence request is currently not supported," - f" cancelling request '{request_to_remove.request_id}'", - ) - continue + pass self.remove_request_from_batch(request_to_remove.request_id) request_to_remove.is_prefilled = False diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index e74a6181c8..c66f43ecca 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -305,11 +305,7 @@ def step(self) -> GenerationLoopWorkerOutput: def _adjust_batch(self): with self.queue_lock: - num_eviction = self.evict_request( - cancell_callback=lambda request_id: self.cancelled_requests.append( - self.current_batch[request_id] - ) - ) + num_eviction = self.evict_request() self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc(num_eviction) if self.cache_manager.get_max_new_tokens() <= self.max_decode_steps: diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index 5dcf80eda7..6b6fe004f8 
100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -252,11 +252,7 @@ def _adjust_batch(self): self.cache_manager.free_request(state) self.requests_to_be_cancelled.remove(request_id) - self.evict_request( - cancell_callback=lambda request_id: self.requests_to_be_cancelled.add( - request_id - ) - ) + self.evict_request() self._discard_cancelled_requests_from_queue() From 59c36ccce4dd2bb85b179a64894c71b3fd4e4507 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 09:26:32 +0000 Subject: [PATCH 13/39] Create MultiQueryDecodeRequest --- serve/mlc_serve/engine/engine_common.py | 58 ++++++++++++++++++++----- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 91f5bc45a6..dcb2ba0986 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -20,6 +20,8 @@ from .model_module import ( DecodeRequest, PrefillRequest, + MultiQueryDecodeRequest, + EvictedTokens, ConversationTemplate, KVCacheManager, ModelModule, @@ -185,18 +187,56 @@ def get_requests_to_process( token_counts = 0 + is_evicted_parallel_sampling_request = ( + lambda state: not state.is_prefilled + and state.num_sequences > 1 + and any( + len(gen_seq.generated_token_ids) > 0 + for gen_seq in state.generation_sequences + ) + ) + if is_prompt_batch: for state in current_states: - if not state.is_prefilled: + if is_evicted_parallel_sampling_request(state): + requests.append( + PrefillRequest( + request_id=state.request_id, + token_ids=state.prompt_token_ids, + num_sequence=state.num_sequences, + sampling_params=state.sampling_params, + ) + ) + + token_counts += len(state.prompt_token_ids) + + for gen_seq in state.generation_sequences: + requests.append( + MultiQueryDecodeRequest( + sequence_id=gen_seq.seq_id, + past_token_ids=state.prompt_token_ids, + queries=EvictedTokens(gen_seq.generated_token_ids), + sampling_params=state.sampling_params, + ) + ) + + # TODO(masahi): How to account for token counts in MultiQueryDecodeRequest in + # Prometheus metric? + elif not state.is_prefilled: + token_ids = state.prompt_token_ids + # generated_token_ids is added for the case where the request is + # recovering from cache eviction. + + if ( + state.num_sequences == 1 + and state.generation_sequences[0].generated_token_ids + ): + token_ids += state.generation_sequences[0].generated_token_ids + requests.append( - # generated_token_ids is added for the case where the request is - # recovering from cache eviction. - # TODO(masahi): This needs an update when we support evicting - # a parallel-sampling request. 
PrefillRequest( request_id=state.request_id, - token_ids=state.prompt_token_ids - + state.generation_sequences[0].generated_token_ids, + token_ids=token_ids, num_sequence=state.num_sequences, sampling_params=state.sampling_params, ) @@ -344,10 +384,6 @@ def evict_request(self) -> int: request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens) - # TODO(masahi): Properly support evicting a multi-sequence request - if self.current_batch[request_to_remove.request_id].num_sequences != 1: - pass - self.remove_request_from_batch(request_to_remove.request_id) request_to_remove.is_prefilled = False self.queue.appendleft(request_to_remove) From f58acf77ab42143ade84681858d5192594593f89 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 09:45:16 +0000 Subject: [PATCH 14/39] only the number of past tokens is needed --- serve/mlc_serve/engine/engine_common.py | 2 +- serve/mlc_serve/engine/model_module.py | 2 +- serve/mlc_serve/model/paged_cache_model.py | 23 +++++++++++----------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index dcb2ba0986..8ef9377806 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -214,7 +214,7 @@ def get_requests_to_process( requests.append( MultiQueryDecodeRequest( sequence_id=gen_seq.seq_id, - past_token_ids=state.prompt_token_ids, + num_past_tokens=state.prompt_len, queries=EvictedTokens(gen_seq.generated_token_ids), sampling_params=state.sampling_params, ) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 896d9d3ef1..a224245be0 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -54,7 +54,7 @@ def num_tokens(self): @dataclass class MultiQueryDecodeRequest: sequence_id: SequenceId - past_token_ids: List[int] + num_past_tokens: int queries: Union[DraftTokens, EvictedTokens] sampling_params: SamplingParams diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 3d7019ad5c..2a3fcd2121 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -312,39 +312,38 @@ def prepare_multi_query_decode_inputs( positions = [] permute_map = [] - query_offset = sum([len(request.past_token_ids) for request in requests]) + query_offset = sum([request.num_past_tokens for request in requests]) past_offset = 0 for request in requests: - num_past_tokens = len(request.past_token_ids) num_queries = request.queries.num_tokens query_lens.append(num_queries) input_ids += request.queries.token_ids - positions += [num_past_tokens + i for i in range(num_queries)] + positions += [request.num_past_tokens + i for i in range(num_queries)] if sliding_window: - seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) - num_past = min(num_past_tokens, sliding_window) + seq_lens.append(min(request.num_past_tokens + num_queries, sliding_window)) + num_past = min(request.num_past_tokens, sliding_window) past_slot_mapping += all_slot_mappings[request.sequence_id][num_past:] slot_mapping += all_slot_mappings[request.sequence_id][ num_past : num_past + num_queries ] else: - seq_lens.append(num_past_tokens + num_queries) + seq_lens.append(request.num_past_tokens + num_queries) past_slot_mapping += all_slot_mappings[request.sequence_id][ - :num_past_tokens + : request.num_past_tokens ] slot_mapping += 
all_slot_mappings[request.sequence_id][ - num_past_tokens : num_past_tokens + num_queries + request.num_past_tokens : request.num_past_tokens + num_queries ] - permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list( - range(query_offset, query_offset + num_queries) - ) + permute_map += list( + range(past_offset, past_offset + request.num_past_tokens) + ) + list(range(query_offset, query_offset + num_queries)) query_offset += num_queries - past_offset += num_past_tokens + past_offset += request.num_past_tokens input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) From d9dd2ca683aa91a0a4e63775a14ceaad3598f54f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 19:18:33 +0000 Subject: [PATCH 15/39] wip --- serve/mlc_serve/engine/engine_common.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 8ef9377806..7d81592e19 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -428,13 +428,8 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: num_new_batched_tokens ) = num_tokens = self.max_num_batched_tokens else: - # Evicting and recovering multi-sequence requests is not supported for now. - assert all( - gen_seq.next_start_position == state.prompt_len - for gen_seq in state.generation_sequences - ) - num_tokens = state.prompt_len - num_new_batched_tokens += num_tokens + num_tokens = state.prompt_len + sum([len(gen_seq.generated_token_ids) for gen_seq in state.generation_sequences]) + num_new_batched_tokens += state.prompt_len if num_new_batched_tokens > self.max_num_batched_tokens: LOG.debug( From cb11761752c4b6726ee932edace8e163c6cd797a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 20:43:23 +0000 Subject: [PATCH 16/39] wip --- serve/mlc_serve/engine/engine_common.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 7d81592e19..051a08c018 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -428,8 +428,19 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: num_new_batched_tokens ) = num_tokens = self.max_num_batched_tokens else: - num_tokens = state.prompt_len + sum([len(gen_seq.generated_token_ids) for gen_seq in state.generation_sequences]) - num_new_batched_tokens += state.prompt_len + prev_generated_token_counts = sum( + [ + len(gen_seq.generated_token_ids) + for gen_seq in state.generation_sequences + ] + ) + # Restoring an evicted parallel-sampling request with sliding-window attention is + # difficult to reason about, so we compute crude upper bounds below for now. + num_tokens = state.prompt_len + prev_generated_token_counts + # Restoring an evicted parallel-sampling request is done by separate + # Prefill and MultiQuery requests. The maximum below is an upper bound on the + # batch size increase due to this request. + num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts) if num_new_batched_tokens > self.max_num_batched_tokens: LOG.debug( @@ -453,7 +464,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: return None self.queue.popleft() - # TODO parallel sampling: Need update here when evicting multi-sequence requests is supported. 
self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences) self.current_batch[state.request_id] = state From 24f7bfa0ef26f69dd2972ff39cfd7e23e4cb8ea1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 20:43:32 +0000 Subject: [PATCH 17/39] wip --- serve/mlc_serve/model/paged_cache_model.py | 28 +++++++++++++++++----- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 2a3fcd2121..159fef59e0 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -323,6 +323,7 @@ def prepare_multi_query_decode_inputs( positions += [request.num_past_tokens + i for i in range(num_queries)] if sliding_window: + # TODO(masahi): Verify this code path seq_lens.append(min(request.num_past_tokens + num_queries, sliding_window)) num_past = min(request.num_past_tokens, sliding_window) past_slot_mapping += all_slot_mappings[request.sequence_id][num_past:] @@ -331,12 +332,27 @@ def prepare_multi_query_decode_inputs( ] else: seq_lens.append(request.num_past_tokens + num_queries) - past_slot_mapping += all_slot_mappings[request.sequence_id][ - : request.num_past_tokens - ] - slot_mapping += all_slot_mappings[request.sequence_id][ - request.num_past_tokens : request.num_past_tokens + num_queries - ] + prompt_seq_id = get_prompt_sequence_id(request.sequence_id.request_id) + prompt_slot_mappings = all_slot_mappings[prompt_seq_id] + + if request.num_past_tokens < len(prompt_slot_mappings): + raise RuntimeError( + "For MultiQueryDecodeRequest, the number of past tokens" + "smaller than the prompt length is not supported for now." + ) + elif request.num_past_tokens == len(prompt_slot_mappings): + # The case for restoring an evicted parallel-sampling request + past_slot_mapping += prompt_slot_mappings[: request.num_past_tokens] + slot_mapping += all_slot_mappings[request.sequence_id][:num_queries] + else: + query_begin_offset = request.num_past_tokens - len(prompt_slot_mappings) + past_slot_mapping += ( + prompt_slot_mappings + + all_slot_mappings[request.sequence_id][:query_begin_offset] + ) + slot_mapping += all_slot_mappings[request.sequence_id][ + query_begin_offset : query_begin_offset + num_queries + ] permute_map += list( range(past_offset, past_offset + request.num_past_tokens) From 34da2211c00a943aeb00a617b17ad8b454e7c7e3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 22:03:22 +0000 Subject: [PATCH 18/39] fix --- serve/mlc_serve/model/paged_cache_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 159fef59e0..d750b20a1b 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -482,7 +482,7 @@ def sample_from_logits( requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequences): + for seq_id in range(requests[i].num_sequence): outputs.append( TextGenerationResult( sequence_id=SequenceId(sequence_id.request_id, seq_id), @@ -528,7 +528,7 @@ def sample_from_logits( requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequences): + for 
seq_id in range(requests[i].num_sequence): outputs.append( TextGenerationResult( sequence_id=SequenceId( @@ -549,7 +549,7 @@ def sample_from_logits( else: if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequences): + for seq_id in range(requests[i].num_sequence): outputs.append( TextGenerationResult( sequence_id=SequenceId( From d94e9d87ce00508778732555499b78ba37be13f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 10 Jan 2024 23:37:05 +0000 Subject: [PATCH 19/39] wip --- serve/mlc_serve/engine/engine_common.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 051a08c018..ef7e1f99af 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -219,6 +219,10 @@ def get_requests_to_process( sampling_params=state.sampling_params, ) ) + cache_manager.extend( + gen_seq.seq_id, + len(gen_seq.generated_token_ids) + 1, + ) # TODO(masahi): How to account for token counts in MultiQueryDecodeRequest in # Prometheus metric? @@ -364,7 +368,7 @@ def evict_request(self) -> int: # Must be called with the queue lock held num_eviction = 0 - while self.cache_manager.get_max_new_tokens() < 1: + while self.cache_manager.get_max_new_tokens() < 9089: num_eviction += 1 single_sample_requests = [] @@ -384,6 +388,7 @@ def evict_request(self) -> int: request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens) + print("Evicting", request_to_remove.request_id) self.remove_request_from_batch(request_to_remove.request_id) request_to_remove.is_prefilled = False self.queue.appendleft(request_to_remove) @@ -427,6 +432,8 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: gen_seq.next_start_position = ( num_new_batched_tokens ) = num_tokens = self.max_num_batched_tokens + + num_kv_slots_needed = min(num_tokens, self.model_context_window_size) else: prev_generated_token_counts = sum( [ @@ -434,9 +441,15 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: for gen_seq in state.generation_sequences ] ) + + if prev_generated_token_counts > 0: + print("Restoring", state.request_id) + print("prev_generated_token_counts", prev_generated_token_counts) + # Restoring an evicted parallel-sampling request with sliding-window attention is - # difficult to reason about, so we compute crude upper bounds below for now. - num_tokens = state.prompt_len + prev_generated_token_counts + # difficult to reason about, so we use crude upper bounds below for now. + num_tokens = state.prompt_len + num_kv_slots_needed = state.prompt_len + prev_generated_token_counts # Restoring an evicted parallel-sampling request is done by separate # Prefill and MultiQuery requests. The maximum below is an upper bound on the # batch size increase due to this request. @@ -452,7 +465,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: # We make sure that the KV cache will have enough free space for this request to proceed # decoding for at least self.max_decode_steps steps. # See the comment in check_prompt_too_long for the optimization involving the window size. 
- num_kv_slots_needed = min(num_tokens, self.model_context_window_size) if (self.cache_manager.get_free_space() - num_kv_slots_needed) / ( len(self.current_batch) + 1 ) < self.max_decode_steps * state.num_sequences: From 4a3bb77f8cc0cc1f7f07196ad96c990f1a4b22a8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 00:34:26 +0000 Subject: [PATCH 20/39] wip --- serve/mlc_serve/model/paged_cache_manager.py | 82 ++++++++++++-------- serve/mlc_serve/model/paged_cache_model.py | 6 ++ 2 files changed, 55 insertions(+), 33 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 7cd24f5182..06062e8e01 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -175,50 +175,65 @@ def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]): elif id in self.kv_cache.decode_block_tables: decode_block_table = self.kv_cache.decode_block_tables[id] - if len(decode_block_table) < num_needed_block: + while len(decode_block_table) < num_needed_block: # Need to allocate a new block for this request - assert len(decode_block_table) + 1 == num_needed_block assert len(self.free_blocks) > 0 decode_block_table.append(self.free_blocks.pop()) - pos = size - 1 + prompt_seq_id = get_prompt_sequence_id(id.request_id) + allocated_slot_counts = len( + self.kv_cache.slot_mappings[prompt_seq_id] + ) + len(self.kv_cache.slot_mappings[id]) - def get_block_circular_index(token_pos): - assert self.block_sliding_window - return (token_pos // self.block_size) % self.block_sliding_window + for current_size in range(allocated_slot_counts + 1, size + 1): + pos = current_size - 1 - if ( - decode_block_table.prompt_shared - and self.sliding_window - and size >= self.sliding_window - ): - # Parallel sampling + SWA case - if decode_block_table.prompt_cursor == get_block_circular_index( - pos + def get_block_circular_index(token_pos): + assert self.block_sliding_window + return ( + token_pos // self.block_size + ) % self.block_sliding_window + + if ( + decode_block_table.prompt_shared + and self.sliding_window + and current_size >= self.sliding_window ): - # This sequence is trying to overwrite a prompt block shared with other sequences. - assert ( - len(self.free_blocks) > 0 - ), "No more free block in the cache." - - block_number = self.free_blocks.pop() - # Add a new decode block and advance the prompt cursor - decode_block_table.replace_head_prompt_block_with(block_number) - else: - # Write to the decode block allocated above - block_number = decode_block_table[-1] + # Parallel sampling + SWA case + if ( + decode_block_table.prompt_cursor + == get_block_circular_index(pos) + ): + # This sequence is trying to overwrite a prompt block shared with other sequences. + assert ( + len(self.free_blocks) > 0 + ), "No more free block in the cache." 
+ + block_number = self.free_blocks.pop() + # Add a new decode block and advance the prompt cursor + decode_block_table.replace_head_prompt_block_with( + block_number + ) + else: + # Write to the decode block allocated above + block_number = decode_block_table[-1] - else: - if self.block_sliding_window: - index = get_block_circular_index(pos) else: - index = -1 + if self.block_sliding_window: + index = get_block_circular_index(pos) + else: + index = -1 - block_number = decode_block_table[index] + block_number = decode_block_table[index] - block_offset = pos % self.block_size - slot = block_number * self.block_size + block_offset - self.kv_cache.slot_mappings[id].append(slot) + block_offset = pos % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + print( + "len(self.kv_cache.slot_mappings[id]", + len(self.kv_cache.slot_mappings[id]), + ) elif id not in self.kv_cache.prompt_block_tables: assert ( @@ -319,6 +334,7 @@ def extend(self, sequence_id: SequenceId, new_tokens: int): allocated = self.token_counts[sequence_id] self.set_size([sequence_id], [allocated + new_tokens]) self.token_counts[sequence_id] += new_tokens + print("sequence_id, allocated, new_tokens", sequence_id, allocated, new_tokens) def free(self, sequence_id: SequenceId): """ diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index d750b20a1b..944f8f4b76 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -344,6 +344,10 @@ def prepare_multi_query_decode_inputs( # The case for restoring an evicted parallel-sampling request past_slot_mapping += prompt_slot_mappings[: request.num_past_tokens] slot_mapping += all_slot_mappings[request.sequence_id][:num_queries] + print( + "len(all_slot_mappings[request.sequence_id]", + len(all_slot_mappings[request.sequence_id]), + ) else: query_begin_offset = request.num_past_tokens - len(prompt_slot_mappings) past_slot_mapping += ( @@ -797,9 +801,11 @@ def generate( out = [] if prefill_requests: out.extend(self.model.generate(prefill_requests, kv_cache)) + print("finished prefill") if decode_requests: out.extend(self.model.generate(decode_requests, kv_cache)) if multi_query_decode_requests: + print("doing multi query decode") out.extend(self.model.generate(multi_query_decode_requests, kv_cache)) return out From 0c6875e2be07d53bdd2ae096878c04fb9237e55f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 00:48:36 +0000 Subject: [PATCH 21/39] wip --- serve/mlc_serve/engine/engine_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index ef7e1f99af..e78b8f8602 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -349,6 +349,7 @@ def __init__(self, model_module: ModelModule): self.has_new_requests = Condition(lock=self.queue_lock) self.current_batch = dict[RequestId, RequestState]() + self.has_evicted = False def check_prompt_too_long(self, prompt_len: int, num_sequences: int = 1) -> bool: # We make sure that the KV cache will have enough free space for this request to proceed @@ -368,7 +369,9 @@ def evict_request(self) -> int: # Must be called with the queue lock held num_eviction = 0 - while self.cache_manager.get_max_new_tokens() < 9089: + mx = 1 if self.has_evicted else 9089 + while self.cache_manager.get_max_new_tokens() < mx: + self.has_evicted = 
True num_eviction += 1 single_sample_requests = [] From a46abe199e6500a9a780e55a661b48399d02a95c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 00:48:46 +0000 Subject: [PATCH 22/39] wip --- serve/mlc_serve/model/paged_cache_model.py | 8 +++----- serve/tests/test_engine.py | 11 ++++++----- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 944f8f4b76..de4d3194f6 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -326,7 +326,7 @@ def prepare_multi_query_decode_inputs( # TODO(masahi): Verify this code path seq_lens.append(min(request.num_past_tokens + num_queries, sliding_window)) num_past = min(request.num_past_tokens, sliding_window) - past_slot_mapping += all_slot_mappings[request.sequence_id][num_past:] + past_slot_mapping += all_slot_mappings[request.sequence_id][:num_past] slot_mapping += all_slot_mappings[request.sequence_id][ num_past : num_past + num_queries ] @@ -344,10 +344,6 @@ def prepare_multi_query_decode_inputs( # The case for restoring an evicted parallel-sampling request past_slot_mapping += prompt_slot_mappings[: request.num_past_tokens] slot_mapping += all_slot_mappings[request.sequence_id][:num_queries] - print( - "len(all_slot_mappings[request.sequence_id]", - len(all_slot_mappings[request.sequence_id]), - ) else: query_begin_offset = request.num_past_tokens - len(prompt_slot_mappings) past_slot_mapping += ( @@ -625,6 +621,7 @@ def generate_multi_query( torch.cuda.nvtx.range_pop() last_query_logits = logits[last_query_offsets] + print("last_query_logits.shape", last_query_logits.shape) return self.sample_from_logits(last_query_logits, sequence_ids, requests) @@ -803,6 +800,7 @@ def generate( out.extend(self.model.generate(prefill_requests, kv_cache)) print("finished prefill") if decode_requests: + print("doing decode") out.extend(self.model.generate(decode_requests, kv_cache)) if multi_query_decode_requests: print("doing multi query decode") diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 4355fec6cd..2c65f6d51b 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -70,8 +70,8 @@ def _test(args: argparse.Namespace): prompts = [ "Hello, my name is", "The capital of France is", - "The president of the United States is a powerful man. But he can also be", - "The future of AI is full of promise. But we need to carefully", + # "The president of the United States is a powerful man. But he can also be", + # "The future of AI is full of promise. But we need to carefully", ] for i, prompt in enumerate(prompts): @@ -100,10 +100,11 @@ def _test(args: argparse.Namespace): if any(seq.is_finished for seq in res.sequences): any_finished.add(res.request_id) - if res.request_id not in any_finished: - # If all sequences are still running, we should always get num_sequences samples back. - assert len(res.sequences) == num_sequences, res + # if res.request_id not in any_finished: + # # If all sequences are still running, we should always get num_sequences samples back. + # assert len(res.sequences) == num_sequences, res + print("len(res.sequences)", len(res.sequences)) for i, seq in enumerate(res.sequences): if seq.delta: generated[int(res.request_id)][i] += seq.delta From c80bea2544832ea496eda78a310e02f648f75794 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 02:32:32 +0000 Subject: [PATCH 23/39] working? 
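Restoring an evicted parallel-sampling request re-issues a PrefillRequest (to repopulate the prompt KV slots) together with one MultiQueryDecodeRequest per generated sequence. The prefill pass is then run only for its cache side effect, so its sampled tokens must not be surfaced to the engine. A rough sketch of the intended dispatch, using the request types from this series (the per-request-id filtering is tightened in a later patch):

    # Sketch only: mirrors the batching logic in paged_cache_model generate().
    out = []
    if prefill_requests:
        prefill_res = model.generate(prefill_requests, kv_cache)
        if not multi_query_decode_requests:
            # No restoration in flight, so prefill outputs are genuine first tokens.
            out.extend(prefill_res)
    if decode_requests:
        out.extend(model.generate(decode_requests, kv_cache))
    if multi_query_decode_requests:
        # Logits are evaluated over the evicted tokens; only the last query
        # position of each sequence yields a new token.
        out.extend(model.generate(multi_query_decode_requests, kv_cache))
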
--- serve/mlc_serve/model/paged_cache_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index de4d3194f6..19a4e7d29c 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -797,7 +797,9 @@ def generate( out = [] if prefill_requests: - out.extend(self.model.generate(prefill_requests, kv_cache)) + prefill_res = self.model.generate(prefill_requests, kv_cache) + if not multi_query_decode_requests: + out.extend(prefill_res) print("finished prefill") if decode_requests: print("doing decode") From 18239a45e6eefa81a1632cf4afeb0b952db11982 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 06:59:31 +0000 Subject: [PATCH 24/39] remove dbg print --- serve/mlc_serve/engine/engine_common.py | 1 - serve/mlc_serve/model/paged_cache_manager.py | 6 ------ serve/mlc_serve/model/paged_cache_model.py | 4 ---- serve/tests/test_engine.py | 7 +++---- 4 files changed, 3 insertions(+), 15 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index e78b8f8602..df337758db 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -447,7 +447,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: if prev_generated_token_counts > 0: print("Restoring", state.request_id) - print("prev_generated_token_counts", prev_generated_token_counts) # Restoring an evicted parallel-sampling request with sliding-window attention is # difficult to reason about, so we use crude upper bounds below for now. diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index 06062e8e01..78ea4bbb2c 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -230,11 +230,6 @@ def get_block_circular_index(token_pos): slot = block_number * self.block_size + block_offset self.kv_cache.slot_mappings[id].append(slot) - print( - "len(self.kv_cache.slot_mappings[id]", - len(self.kv_cache.slot_mappings[id]), - ) - elif id not in self.kv_cache.prompt_block_tables: assert ( len(self.free_blocks) >= num_needed_block @@ -334,7 +329,6 @@ def extend(self, sequence_id: SequenceId, new_tokens: int): allocated = self.token_counts[sequence_id] self.set_size([sequence_id], [allocated + new_tokens]) self.token_counts[sequence_id] += new_tokens - print("sequence_id, allocated, new_tokens", sequence_id, allocated, new_tokens) def free(self, sequence_id: SequenceId): """ diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 19a4e7d29c..ac9bdd7348 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -621,7 +621,6 @@ def generate_multi_query( torch.cuda.nvtx.range_pop() last_query_logits = logits[last_query_offsets] - print("last_query_logits.shape", last_query_logits.shape) return self.sample_from_logits(last_query_logits, sequence_ids, requests) @@ -800,12 +799,9 @@ def generate( prefill_res = self.model.generate(prefill_requests, kv_cache) if not multi_query_decode_requests: out.extend(prefill_res) - print("finished prefill") if decode_requests: - print("doing decode") out.extend(self.model.generate(decode_requests, kv_cache)) if multi_query_decode_requests: - print("doing multi query decode") out.extend(self.model.generate(multi_query_decode_requests, kv_cache)) 
return out diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 2c65f6d51b..93814b8cf4 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -100,11 +100,10 @@ def _test(args: argparse.Namespace): if any(seq.is_finished for seq in res.sequences): any_finished.add(res.request_id) - # if res.request_id not in any_finished: - # # If all sequences are still running, we should always get num_sequences samples back. - # assert len(res.sequences) == num_sequences, res + if res.request_id not in any_finished: + # If all sequences are still running, we should always get num_sequences samples back. + assert len(res.sequences) == num_sequences, res - print("len(res.sequences)", len(res.sequences)) for i, seq in enumerate(res.sequences): if seq.delta: generated[int(res.request_id)][i] += seq.delta From fd2b2bd3da44677baca77b80cce3005db3356e92 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 07:33:36 +0000 Subject: [PATCH 25/39] multi gpu works --- serve/mlc_serve/engine/engine_common.py | 6 ++--- serve/mlc_serve/model/paged_cache_model.py | 26 +++++++++++++++++----- serve/tests/test_engine.py | 4 ++-- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index df337758db..47cab99166 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -349,7 +349,6 @@ def __init__(self, model_module: ModelModule): self.has_new_requests = Condition(lock=self.queue_lock) self.current_batch = dict[RequestId, RequestState]() - self.has_evicted = False def check_prompt_too_long(self, prompt_len: int, num_sequences: int = 1) -> bool: # We make sure that the KV cache will have enough free space for this request to proceed @@ -369,9 +368,8 @@ def evict_request(self) -> int: # Must be called with the queue lock held num_eviction = 0 - mx = 1 if self.has_evicted else 9089 - while self.cache_manager.get_max_new_tokens() < mx: - self.has_evicted = True + # print("self.cache_manager.get_max_new_tokens()", self.cache_manager.get_max_new_tokens()) + while self.cache_manager.get_max_new_tokens() < 1: num_eviction += 1 single_sample_requests = [] diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index ac9bdd7348..92f91e2a66 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -601,11 +601,20 @@ def generate_multi_query( self.dev, ) - # TODO(masahi): Disco, sliding window - torch.cuda.nvtx.range_push(f"forward multi-query decode {input_ids.shape}") - logits = self.mod["evaluate_multi_query"]( + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + query_lens = copy_to_worker_0(self.disco_session, query_lens) + past_slot_mapping = copy_to_worker_0(self.disco_session, past_slot_mapping) + permute_map = copy_to_worker_0(self.disco_session, permute_map) + + # TODO(masahi): sliding window + + out = self.mod["evaluate_multi_query"]( input_ids, positions, seq_lens, @@ -615,12 +624,19 @@ def generate_multi_query( past_slot_mapping, permute_map, self.params, - )[0].numpy() + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[ + 0 + ] torch.cuda.synchronize() 
torch.cuda.nvtx.range_pop() - last_query_logits = logits[last_query_offsets] + last_query_logits = torch.from_dlpack(logits)[last_query_offsets] return self.sample_from_logits(last_query_logits, sequence_ids, requests) diff --git a/serve/tests/test_engine.py b/serve/tests/test_engine.py index 93814b8cf4..4355fec6cd 100644 --- a/serve/tests/test_engine.py +++ b/serve/tests/test_engine.py @@ -70,8 +70,8 @@ def _test(args: argparse.Namespace): prompts = [ "Hello, my name is", "The capital of France is", - # "The president of the United States is a powerful man. But he can also be", - # "The future of AI is full of promise. But we need to carefully", + "The president of the United States is a powerful man. But he can also be", + "The future of AI is full of promise. But we need to carefully", ] for i, prompt in enumerate(prompts): From 6ac292b017ce36aaa13e9d39211c7bf065270e0d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 09:16:03 +0000 Subject: [PATCH 26/39] fixed sliding window logic --- serve/mlc_serve/model/paged_cache_model.py | 30 ++++++++++++---------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 92f91e2a66..fe3f1385df 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -319,21 +319,25 @@ def prepare_multi_query_decode_inputs( num_queries = request.queries.num_tokens query_lens.append(num_queries) input_ids += request.queries.token_ids - positions += [request.num_past_tokens + i for i in range(num_queries)] - if sliding_window: - # TODO(masahi): Verify this code path - seq_lens.append(min(request.num_past_tokens + num_queries, sliding_window)) - num_past = min(request.num_past_tokens, sliding_window) - past_slot_mapping += all_slot_mappings[request.sequence_id][:num_past] - slot_mapping += all_slot_mappings[request.sequence_id][ - num_past : num_past + num_queries + prompt_seq_id = get_prompt_sequence_id(request.sequence_id.request_id) + prompt_slot_mappings = all_slot_mappings[prompt_seq_id] + + if sliding_window and request.num_past_tokens + num_queries >= sliding_window: + seq_lens.append(sliding_window) + prompt_and_decode_slot_mappings = ( + prompt_slot_mappings + all_slot_mappings[request.sequence_id] + ) + past_slot_mapping += prompt_and_decode_slot_mappings[ + request.num_past_tokens + - (sliding_window - num_queries) : request.num_past_tokens + ] + slot_mapping += prompt_and_decode_slot_mappings[ + request.num_past_tokens : request.num_past_tokens + num_queries ] else: seq_lens.append(request.num_past_tokens + num_queries) - prompt_seq_id = get_prompt_sequence_id(request.sequence_id.request_id) - prompt_slot_mappings = all_slot_mappings[prompt_seq_id] if request.num_past_tokens < len(prompt_slot_mappings): raise RuntimeError( @@ -342,7 +346,7 @@ def prepare_multi_query_decode_inputs( ) elif request.num_past_tokens == len(prompt_slot_mappings): # The case for restoring an evicted parallel-sampling request - past_slot_mapping += prompt_slot_mappings[: request.num_past_tokens] + past_slot_mapping += prompt_slot_mappings slot_mapping += all_slot_mappings[request.sequence_id][:num_queries] else: query_begin_offset = request.num_past_tokens - len(prompt_slot_mappings) @@ -629,9 +633,7 @@ def generate_multi_query( if self.disco_session: logits, _ = out.debug_get_from_remote(0) else: - logits = out[ - 0 - ] + logits = out[0] torch.cuda.synchronize() torch.cuda.nvtx.range_pop() From 
2f9d1f7ff313aefd1c9e2373e451096ea88fd9d4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 09:44:01 +0000 Subject: [PATCH 27/39] remove dbug print --- serve/mlc_serve/engine/engine_common.py | 5 ----- serve/mlc_serve/model/paged_cache_model.py | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 47cab99166..f94eda0d0e 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -368,7 +368,6 @@ def evict_request(self) -> int: # Must be called with the queue lock held num_eviction = 0 - # print("self.cache_manager.get_max_new_tokens()", self.cache_manager.get_max_new_tokens()) while self.cache_manager.get_max_new_tokens() < 1: num_eviction += 1 @@ -389,7 +388,6 @@ def evict_request(self) -> int: request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens) - print("Evicting", request_to_remove.request_id) self.remove_request_from_batch(request_to_remove.request_id) request_to_remove.is_prefilled = False self.queue.appendleft(request_to_remove) @@ -443,9 +441,6 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: ] ) - if prev_generated_token_counts > 0: - print("Restoring", state.request_id) - # Restoring an evicted parallel-sampling request with sliding-window attention is # difficult to reason about, so we use crude upper bounds below for now. num_tokens = state.prompt_len diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index fe3f1385df..90842bb278 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -298,7 +298,7 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]: ) -def prepare_multi_query_decode_inputs( +def _prepare_multi_query_decode_inputs( requests: List[MultiQueryDecodeRequest], all_slot_mappings, sliding_window, @@ -598,7 +598,7 @@ def generate_multi_query( query_lens, past_slot_mapping, permute_map, - ) = prepare_multi_query_decode_inputs( + ) = _prepare_multi_query_decode_inputs( requests, cache.slot_mappings, None, From 3a9f6d6eb2e2b796a6a02ff2ab8c19d023c36fee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 10:04:52 +0000 Subject: [PATCH 28/39] clean and fix --- serve/mlc_serve/model/paged_cache_model.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 90842bb278..da1841b577 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -469,7 +469,7 @@ def sample_from_logits( requests: Union[ List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest] ], - ): + ) -> List[TextGenerationResult]: assert logits.shape[0] == len(requests) sampling_params = [req.sampling_params for req in requests] @@ -616,8 +616,6 @@ def generate_multi_query( past_slot_mapping = copy_to_worker_0(self.disco_session, past_slot_mapping) permute_map = copy_to_worker_0(self.disco_session, permute_map) - # TODO(masahi): sliding window - out = self.mod["evaluate_multi_query"]( input_ids, positions, @@ -812,13 +810,31 @@ def generate( r for r in requests if isinstance(r, MultiQueryDecodeRequest) ] + multi_query_decode_requests = [] + multi_query_decode_request_ids = set() + + for r in requests: + if isinstance(r, MultiQueryDecodeRequest): + multi_query_decode_requests.append(r) + 
multi_query_decode_request_ids.add(r.sequence_id.request_id) + out = [] + if prefill_requests: prefill_res = self.model.generate(prefill_requests, kv_cache) + if not multi_query_decode_requests: out.extend(prefill_res) + else: + # Prefill requests from restoration of evicted parallel-sampling requests + # must not return outputs. + for res in prefill_res: + if res.sequence_id.request_id not in multi_query_decode_request_ids: + out.append(res) + if decode_requests: out.extend(self.model.generate(decode_requests, kv_cache)) + if multi_query_decode_requests: out.extend(self.model.generate(multi_query_decode_requests, kv_cache)) From 9fb9261b45b506cf0c8bd0ea61580a5d96824b7d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 11 Jan 2024 10:10:08 +0000 Subject: [PATCH 29/39] mypy --- serve/mlc_serve/engine/engine_common.py | 6 ++++-- serve/mlc_serve/engine/model_module.py | 2 +- serve/mlc_serve/model/paged_cache_model.py | 8 ++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index f94eda0d0e..97481dc9ba 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -179,8 +179,10 @@ def update_sequence( def get_requests_to_process( current_states: list[RequestState], cache_manager: KVCacheManager -) -> Tuple[list[Union[PrefillRequest, DecodeRequest]], bool, int]: - requests: list[Union[PrefillRequest, DecodeRequest]] = [] +) -> Tuple[ + list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], bool, int +]: + requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]] = [] # TODO: consider having hybrid batch if the underlying attention kernel supports # mixing prefill and decode. 
is_prompt_batch = any(not state.is_prefilled for state in current_states) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index a224245be0..912eb5ab44 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -143,7 +143,7 @@ class TextGenerator(Protocol): def generate( self, - requests: List[Union[PrefillRequest, DecodeRequest]], + requests: List[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], kv_cache: KVCache, ) -> List[TextGenerationResult]: """ diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index da1841b577..289f78ee6c 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -486,7 +486,7 @@ def sample_from_logits( requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): + for seq_id in range(requests[i].num_sequence): # type: ignore outputs.append( TextGenerationResult( sequence_id=SequenceId(sequence_id.request_id, seq_id), @@ -532,7 +532,7 @@ def sample_from_logits( requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): + for seq_id in range(requests[i].num_sequence): # type: ignore outputs.append( TextGenerationResult( sequence_id=SequenceId( @@ -553,7 +553,7 @@ def sample_from_logits( else: if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): + for seq_id in range(requests[i].num_sequence): # type: ignore outputs.append( TextGenerationResult( sequence_id=SequenceId( @@ -578,7 +578,7 @@ def generate_multi_query( self, requests: List[MultiQueryDecodeRequest], cache: KVCache ) -> List[TextGenerationResult]: sequence_ids = [] - last_query_offsets = [] + last_query_offsets: List[int] = [] for request in requests: assert not isinstance(request.queries, DraftTokens) sequence_ids.append(request.sequence_id) From 906b23b6cc1b2fc77f2a6aab079635db33030be9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 07:44:59 +0000 Subject: [PATCH 30/39] generate signature update --- serve/mlc_serve/engine/model_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 912eb5ab44..a7a72bd074 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -143,8 +143,8 @@ class TextGenerator(Protocol): def generate( self, - requests: List[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], - kv_cache: KVCache, + requests: Sequence[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], + kv_cache, ) -> List[TextGenerationResult]: """ A unified entrypoint for text generation. 
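The next patch moves generate_multi_query into the TVM model executor; as a reference for the index bookkeeping it relies on, here is the core of prepare_multi_query_decode_inputs in isolation. Past tokens of all sequences are gathered first, the query tokens of all sequences are appended after them, and permute_map reorders that concatenation into the per-sequence [past, query] layout that the variable-length attention kernel expects. A minimal, self-contained sketch with toy sizes (illustration only, not the serving code):

    # Two sequences: seq0 has 3 past and 2 query tokens, seq1 has 2 past and 1 query token.
    num_past = [3, 2]
    num_query = [2, 1]

    query_offset = sum(num_past)  # query tokens start after all past tokens: 5
    past_offset = 0
    permute_map = []
    for p, q in zip(num_past, num_query):
        permute_map += list(range(past_offset, past_offset + p))
        permute_map += list(range(query_offset, query_offset + q))
        past_offset += p
        query_offset += q

    # The concatenated layout is [past0, past1, query0, query1]; permute_map is
    # [0, 1, 2, 5, 6, 3, 4, 7], so take(concat(...), permute_map) yields
    # [past0, query0, past1, query1], i.e. each sequence's tokens are contiguous.
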
From b197e711557fa5de43683b9c5cd789f660dd21c4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 07:53:42 +0000 Subject: [PATCH 31/39] more --- serve/mlc_serve/model/tvm_model.py | 296 +++++++++++++++++++---------- 1 file changed, 193 insertions(+), 103 deletions(-) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index cb5683a5c9..bb2ca1c059 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -14,6 +14,7 @@ from .model_common import ( sample, prepare_inputs, + prepare_multi_query_decode_inputs, get_num_cache_blocks, ) @@ -26,6 +27,8 @@ from ..engine.model_module import ( DecodeRequest, PrefillRequest, + DraftTokens, + MultiQueryDecodeRequest, TextGenerationResult, TextGenerator, ) @@ -201,32 +204,214 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() + def sample_from_logits( + self, + logits: Union[tvm.nd.NDArray, torch.Tensor], + sequence_ids: List[SequenceId], + requests: Union[ + List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest] + ], + ) -> List[TextGenerationResult]: + assert logits.shape[0] == len(requests) + + sampling_params = [req.sampling_params for req in requests] + + try: + next_tokens = sample(logits, sampling_params, self.vocab_size) + assert next_tokens is not None + outputs = [] + for i, (sequence_id, new_token) in enumerate( + zip(sequence_ids, next_tokens) + ): + if not new_token in sampling_params[i].appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + ) + ) + + return outputs + except RuntimeError: + # Fallback to per-token sampling in case some logits values are corrupted. 
+ outputs = [] + err_msg = ( + "Error from sampling: probability tensor contains either `inf`, `nan`" + " or element < 0" + ) + + for i, (sequence_id, logits_per_token, sampling_param) in enumerate( + zip(sequence_ids, torch.from_dlpack(logits), sampling_params) + ): + maybe_new_token = sample( + torch.unsqueeze(logits_per_token, 0), + [sampling_param], + self.vocab_size, + check_safety=True, + ) + + if maybe_new_token is not None: + new_token = maybe_new_token[0] + if ( + not new_token + in requests[i].sampling_params.appeared_tokens_freq + ): + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], # type: ignore + error=None, + ) + ) + else: + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId( + sequence_id.request_id, seq_id + ), + generated_tokens=[], + error=err_msg, + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[], + error=err_msg, + ) + ) + + return outputs + + def generate_multi_query( + self, requests: List[MultiQueryDecodeRequest], cache: KVCache + ) -> List[TextGenerationResult]: + sequence_ids = [] + last_query_offsets: List[int] = [] + for request in requests: + assert not isinstance(request.queries, DraftTokens) + sequence_ids.append(request.sequence_id) + + if len(last_query_offsets) == 0: + last_query_offsets.append(request.queries.num_tokens - 1) + else: + last_query_offsets.append( + last_query_offsets[-1] + request.queries.num_tokens + ) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = prepare_multi_query_decode_inputs( + requests, + cache.slot_mappings, + None, + self.dev, + ) + + torch.cuda.nvtx.range_push(f"forward multi-query decode {input_ids.shape}") + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + query_lens = copy_to_worker_0(self.disco_session, query_lens) + past_slot_mapping = copy_to_worker_0(self.disco_session, past_slot_mapping) + permute_map = copy_to_worker_0(self.disco_session, permute_map) + + out = self.mod["evaluate_multi_query"]( + input_ids, + positions, + seq_lens, + cache.cache_blocks, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] + + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + last_query_logits = torch.from_dlpack(logits)[last_query_offsets] + + return self.sample_from_logits(last_query_logits, sequence_ids, requests) + def generate( self, - requests: Sequence[Union[PrefillRequest, DecodeRequest]], + requests: Union[ + List[PrefillRequest], 
List[DecodeRequest], List[MultiQueryDecodeRequest] + ], cache: KVCache, ) -> List[TextGenerationResult]: if len(requests) == 0: return [] is_prefill = isinstance(requests[0], PrefillRequest) + is_multi_query_decode = isinstance(requests[0], MultiQueryDecodeRequest) + if is_multi_query_decode: + return self.generate_multi_query(requests, cache) # type: ignore + + # Prefill or decode all_token_ids = [] - sampling_params = [] sequence_ids = [] prompt_lens = [] - num_sequences = [] for request in requests: if isinstance(request, PrefillRequest): sequence_ids.append(get_prompt_sequence_id(request.request_id)) - num_sequences.append(request.num_sequence) - else: + elif isinstance(request, DecodeRequest): sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) + assert not isinstance(request, MultiQueryDecodeRequest) all_token_ids.append(request.token_ids) - sampling_params.append(request.sampling_params) ( input_ids, @@ -235,7 +420,7 @@ def generate( slot_mapping, indices_within_window, block_tables, - ) = _prepare_inputs( + ) = prepare_inputs( sequence_ids, all_token_ids, prompt_lens, @@ -327,102 +512,7 @@ def generate( self.copy_cache_blocks_func(cache.cache_blocks, block_mapping) cache.pending_copy_from_to = [] - try: - next_tokens = sample(logits, sampling_params, self.vocab_size) - assert next_tokens is not None - outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) - ): - if not new_token in requests[i].sampling_params.appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - ) - ) - - return outputs - except RuntimeError: - # Fallback to per-token sampling in case some logits values are corrupted. 
- outputs = [] - err_msg = ( - "Error from sampling: probability tensor contains either `inf`, `nan`" - " or element < 0" - ) - - for i, (sequence_id, logits_per_token, sampling_param) in enumerate( - zip(sequence_ids, torch.from_dlpack(logits), sampling_params) - ): - maybe_new_token = sample( - torch.unsqueeze(logits_per_token, 0), - [sampling_param], - self.vocab_size, - check_safety=True, - ) - - if maybe_new_token is not None: - new_token = maybe_new_token[0] - if ( - not new_token - in requests[i].sampling_params.appeared_tokens_freq - ): - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], # type: ignore - error=None, - ) - ) - else: - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - for seq_id in range(num_sequences[i]): - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[], - error=err_msg, - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[], - error=err_msg, - ) - ) - - return outputs + return self.sample_from_logits(logits, sequence_ids, requests) def init_tvm_model( From 2dfa28d0761d43d3e6b0a53f921161410b79a5cf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 07:59:00 +0000 Subject: [PATCH 32/39] fix mypy --- serve/mlc_serve/model/tvm_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index bb2ca1c059..07ccccf336 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -208,8 +208,8 @@ def sample_from_logits( self, logits: Union[tvm.nd.NDArray, torch.Tensor], sequence_ids: List[SequenceId], - requests: Union[ - List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest] + requests: Sequence[ + Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest] ], ) -> List[TextGenerationResult]: assert logits.shape[0] == len(requests) @@ -384,8 +384,8 @@ def generate_multi_query( def generate( self, - requests: Union[ - List[PrefillRequest], List[DecodeRequest], List[MultiQueryDecodeRequest] + requests: Sequence[ + Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest] ], cache: KVCache, ) -> List[TextGenerationResult]: From e287c5fd6993a4b609298078367e8183b0a91ca9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 13 Jan 2024 08:10:31 +0000 Subject: [PATCH 33/39] fix --- serve/mlc_serve/model/tvm_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 07ccccf336..f2abecf387 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -420,7 +420,7 @@ def generate( slot_mapping, indices_within_window, block_tables, - ) = prepare_inputs( + ) = _prepare_inputs( sequence_ids, all_token_ids, prompt_lens, From c925c52e9228aad2c7bc4487366164cdf564db24 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 31 Jan 2024 22:40:03 +0000 Subject: [PATCH 34/39] fix --- 
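generate_multi_query evaluates logits for every evicted token, but only the last query position of each sequence is sampled. The offsets used for that gather are cumulative over query_lens, as in this small sketch (toy values, illustration only):

    # query_lens holds the number of evaluated tokens per sequence.
    query_lens = [2, 1, 3]
    last_query_offsets = []
    for n in query_lens:
        if not last_query_offsets:
            last_query_offsets.append(n - 1)
        else:
            last_query_offsets.append(last_query_offsets[-1] + n)
    # -> [1, 2, 5]; logits[last_query_offsets] keeps one row per sequence,
    # which is what sample_from_logits consumes.
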
serve/mlc_serve/model/paged_cache_manager.py | 4 ++-- serve/mlc_serve/model/tvm_model.py | 18 +++++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_manager.py b/serve/mlc_serve/model/paged_cache_manager.py index ef38033288..40ee1b4c68 100644 --- a/serve/mlc_serve/model/paged_cache_manager.py +++ b/serve/mlc_serve/model/paged_cache_manager.py @@ -179,8 +179,8 @@ def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]): prompt_seq_id = get_prompt_sequence_id(id.request_id) allocated_slot_counts = len( - self.kv_cache.slot_mappings[prompt_seq_id] - ) + len(self.kv_cache.slot_mappings[id]) + self.kv_cache_info.slot_mappings[prompt_seq_id] + ) + len(self.kv_cache_info.slot_mappings[id]) for current_size in range(allocated_slot_counts + 1, size + 1): pos = current_size - 1 diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 4c668ddec8..b186814ce1 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -229,7 +229,9 @@ def sample_from_logits( sampling_params = [req.sampling_params for req in requests] try: - next_tokens, logprob_infos = sample(logits, sampling_params, self.vocab_size) + next_tokens, logprob_infos = sample( + logits, sampling_params, self.vocab_size + ) assert next_tokens is not None outputs = [] for i, (sequence_id, new_token) in enumerate( @@ -296,7 +298,9 @@ def sample_from_logits( ), generated_tokens=[new_token], # type: ignore error=None, - logprob_info=self.get_logprob_infos(0, logprob_infos), + logprob_info=self.get_logprob_infos( + 0, logprob_infos + ), ) ) else: @@ -319,7 +323,9 @@ def sample_from_logits( ), generated_tokens=[], error=err_msg, - logprob_info=self.get_logprob_infos(0, logprob_infos), + logprob_info=self.get_logprob_infos( + 0, logprob_infos + ), ) ) else: @@ -335,7 +341,9 @@ def sample_from_logits( return outputs def generate_multi_query( - self, requests: List[MultiQueryDecodeRequest], cache: KVCache + self, + requests: List[MultiQueryDecodeRequest], + cache: KVCacheInfo, ) -> List[TextGenerationResult]: sequence_ids = [] last_query_offsets: List[int] = [] @@ -405,7 +413,7 @@ def generate( requests: Sequence[ Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest] ], - cache: KVCache, + cache: KVCacheInfo, ) -> List[TextGenerationResult]: if len(requests) == 0: return [] From a4d6e01c2c5e614d77e09d4045ce28e946e01ddc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 31 Jan 2024 22:44:50 +0000 Subject: [PATCH 35/39] mypy fix --- serve/mlc_serve/engine/base.py | 4 ++-- serve/mlc_serve/model/tvm_model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py index b66dea3479..a689e43888 100644 --- a/serve/mlc_serve/engine/base.py +++ b/serve/mlc_serve/engine/base.py @@ -19,8 +19,8 @@ class RawLogprobsInfo: current_token_id: int current_logprob: float - top_token_ids: Optional[np.array] - top_logprobs: Optional[np.array] + top_token_ids: Optional[np.ndarray] + top_logprobs: Optional[np.ndarray] RawLogprobsInfos = List[Optional[RawLogprobsInfo]] diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index b186814ce1..006541bab3 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -388,7 +388,7 @@ def generate_multi_query( input_ids, positions, seq_lens, - cache.cache_blocks, + self.cache_blocks, slot_mapping, query_lens, past_slot_mapping, From 
5dbf73e8d9ffc74b67b84aa609c65db8226577c4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Feb 2024 08:16:23 +0000 Subject: [PATCH 36/39] refactor --- serve/mlc_serve/model/model_common.py | 126 ++++++++++++++++++++++- serve/mlc_serve/model/tvm_model.py | 140 ++------------------------ 2 files changed, 129 insertions(+), 137 deletions(-) diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 175410a990..6e0b5aeb53 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Sequence import structlog import numpy as np @@ -13,8 +13,16 @@ LOGPROB_TOP_K_MAX, RawLogprobsInfo, RawLogprobsInfos, + PROMPT_SEQEUNCE_INDEX, + RawLogprobsInfos, + SequenceId, +) +from ..engine.model_module import ( + DecodeRequest, + PrefillRequest, + MultiQueryDecodeRequest, + TextGenerationResult, ) -from ..engine.model_module import MultiQueryDecodeRequest LOG = structlog.stdlib.get_logger(__name__) @@ -67,8 +75,8 @@ def get_raw_logprob_info( top_logprobs, top_tokens = torch.topk( logprobs, k=top_logprobs_num, dim=-1, largest=True, sorted=True ) - top_tokens=top_tokens.cpu().numpy() - top_logprobs=top_logprobs.cpu().numpy() + top_tokens = top_tokens.cpu().numpy() + top_logprobs = top_logprobs.cpu().numpy() # Set to raw logprob info return RawLogprobsInfo( @@ -108,7 +116,7 @@ def get_raw_logprob_infos( logits: torch.Tensor, token_ids: torch.Tensor, ) -> RawLogprobsInfos: - for (i, ind, top_logprobs) in indices: + for i, ind, top_logprobs in indices: logprob_infos[i] = get_raw_logprob_info( logits[ind], token_ids[ind], @@ -294,6 +302,114 @@ def _is_safe_to_sample(prob_like): return res, check_logprob_infos(logprob_infos) +def sample_from_logits( + logits: Union[tvm.nd.NDArray, torch.Tensor], + sequence_ids: List[SequenceId], + requests: Sequence[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], + vocab_size, +) -> List[TextGenerationResult]: + assert logits.shape[0] == len(requests) + + sampling_params = [req.sampling_params for req in requests] + + try: + next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size) + assert next_tokens is not None + outputs = [] + for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)): + if not new_token in sampling_params[i].appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], + error=None, + logprob_info=get_logprob_infos(i, logprob_infos), + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], + error=None, + logprob_info=get_logprob_infos(i, logprob_infos), + ) + ) + + return outputs + except RuntimeError: + # Fallback to per-token sampling in case some logits values are corrupted. 
+ outputs = [] + err_msg = ( + "Error from sampling: probability tensor contains either `inf`, `nan`" + " or element < 0" + ) + + for i, (sequence_id, logits_per_token, sampling_param) in enumerate( + zip(sequence_ids, torch.from_dlpack(logits), sampling_params) + ): + maybe_new_token, logprob_infos = sample( + torch.unsqueeze(logits_per_token, 0), + [sampling_param], + vocab_size, + check_safety=True, + ) + + if maybe_new_token is not None: + new_token = maybe_new_token[0] + if not new_token in requests[i].sampling_params.appeared_tokens_freq: + requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 + requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[new_token], # type: ignore + error=None, + logprob_info=get_logprob_infos(0, logprob_infos), + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[new_token], # type: ignore + error=None, + logprob_info=get_logprob_infos(0, logprob_infos), + ) + ) + else: + if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: + assert isinstance(requests[i], PrefillRequest) + for seq_id in range(requests[i].num_sequence): # type: ignore + outputs.append( + TextGenerationResult( + sequence_id=SequenceId(sequence_id.request_id, seq_id), + generated_tokens=[], + error=err_msg, + logprob_info=get_logprob_infos(0, logprob_infos), + ) + ) + else: + outputs.append( + TextGenerationResult( + sequence_id=sequence_id, + generated_tokens=[], + error=err_msg, + logprob_info=get_logprob_infos(0, logprob_infos), + ) + ) + + return outputs + + def prepare_inputs( sequence_ids, all_token_ids, diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index a13433e7d6..53fa02da74 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -1,6 +1,6 @@ import math import os -from typing import List, Optional, Union, Tuple, Sequence +from typing import List, Union, Tuple, Sequence import structlog import numpy as np @@ -12,17 +12,13 @@ from .base import ModelArtifactConfig from .paged_cache_manager import KVCacheInfo, CacheManager from .model_common import ( - sample, + sample_from_logits, prepare_inputs, prepare_multi_query_decode_inputs, - get_logprob_infos, get_num_cache_blocks, ) from ..engine import ( - PROMPT_SEQEUNCE_INDEX, - RawLogprobsInfos, - SequenceId, get_prompt_sequence_id, MLCServeEngineConfig, ) @@ -208,130 +204,6 @@ def profile_memory_usage(self, seq_lens): return self.get_used_memory() - def sample_from_logits( - self, - logits: Union[tvm.nd.NDArray, torch.Tensor], - sequence_ids: List[SequenceId], - requests: Sequence[ - Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest] - ], - ) -> List[TextGenerationResult]: - assert logits.shape[0] == len(requests) - - sampling_params = [req.sampling_params for req in requests] - - try: - next_tokens, logprob_infos = sample( - logits, sampling_params, self.vocab_size - ) - assert next_tokens is not None - outputs = [] - for i, (sequence_id, new_token) in enumerate( - zip(sequence_ids, next_tokens) - ): - if not new_token in sampling_params[i].appeared_tokens_freq: - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 
1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): # type: ignore - outputs.append( - TextGenerationResult( - sequence_id=SequenceId(sequence_id.request_id, seq_id), - generated_tokens=[new_token], - error=None, - logprob_info=get_logprob_infos(i, logprob_infos), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], - error=None, - logprob_info=get_logprob_infos(i, logprob_infos), - ) - ) - - return outputs - except RuntimeError: - # Fallback to per-token sampling in case some logits values are corrupted. - outputs = [] - err_msg = ( - "Error from sampling: probability tensor contains either `inf`, `nan`" - " or element < 0" - ) - - for i, (sequence_id, logits_per_token, sampling_param) in enumerate( - zip(sequence_ids, torch.from_dlpack(logits), sampling_params) - ): - maybe_new_token, logprob_infos = sample( - torch.unsqueeze(logits_per_token, 0), - [sampling_param], - self.vocab_size, - check_safety=True, - ) - - if maybe_new_token is not None: - new_token = maybe_new_token[0] - if ( - not new_token - in requests[i].sampling_params.appeared_tokens_freq - ): - requests[i].sampling_params.appeared_tokens_freq[new_token] = 0 - requests[i].sampling_params.appeared_tokens_freq[new_token] += 1 - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): # type: ignore - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[new_token], # type: ignore - error=None, - logprob_info=get_logprob_infos( - 0, logprob_infos - ), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[new_token], # type: ignore - error=None, - logprob_info=get_logprob_infos(0, logprob_infos), - ) - ) - else: - if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX: - assert isinstance(requests[i], PrefillRequest) - for seq_id in range(requests[i].num_sequence): # type: ignore - outputs.append( - TextGenerationResult( - sequence_id=SequenceId( - sequence_id.request_id, seq_id - ), - generated_tokens=[], - error=err_msg, - logprob_info=get_logprob_infos( - 0, logprob_infos - ), - ) - ) - else: - outputs.append( - TextGenerationResult( - sequence_id=sequence_id, - generated_tokens=[], - error=err_msg, - logprob_info=get_logprob_infos(0, logprob_infos), - ) - ) - - return outputs - def generate_multi_query( self, requests: List[MultiQueryDecodeRequest], @@ -376,6 +248,8 @@ def generate_multi_query( past_slot_mapping = copy_to_worker_0(self.disco_session, past_slot_mapping) permute_map = copy_to_worker_0(self.disco_session, permute_map) + print("evaluate_multi_query") + out = self.mod["evaluate_multi_query"]( input_ids, positions, @@ -398,7 +272,9 @@ def generate_multi_query( last_query_logits = torch.from_dlpack(logits)[last_query_offsets] - return self.sample_from_logits(last_query_logits, sequence_ids, requests) + return sample_from_logits( + last_query_logits, sequence_ids, requests, self.vocab_size + ) def generate( self, @@ -525,7 +401,7 @@ def generate( self.copy_cache_blocks_func(self.cache_blocks, block_mapping) cache.pending_copy_from_to = [] - return self.sample_from_logits(logits, sequence_ids, requests) + return sample_from_logits(logits, sequence_ids, requests, self.vocab_size) def init_tvm_model( From 
78a6f7724239b7d1da3506a1ed92e334954a41cc Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Feb 2024 08:25:49 +0000 Subject: [PATCH 37/39] fix --- serve/mlc_serve/model/paged_cache_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 1a5683763a..70e3200dbb 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -29,17 +29,17 @@ def generate( requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], kv_cache, ) -> list[TextGenerationResult]: - prefill_requests = [r for r in requests if isinstance(r, PrefillRequest)] - decode_requests = [r for r in requests if isinstance(r, DecodeRequest)] - multi_query_decode_requests = [ - r for r in requests if isinstance(r, MultiQueryDecodeRequest) - ] - + prefill_requests = [] + decode_requests = [] multi_query_decode_requests = [] multi_query_decode_request_ids = set() for r in requests: - if isinstance(r, MultiQueryDecodeRequest): + if isinstance(r, PrefillRequest): + prefill_requests.append(r) + elif isinstance(r, DecodeRequest): + decode_requests.append(r) + elif isinstance(r, MultiQueryDecodeRequest): multi_query_decode_requests.append(r) multi_query_decode_request_ids.add(r.sequence_id.request_id) From 9189697ddf85813d3d28bbe09cc89cdf12194325 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Feb 2024 08:32:12 +0000 Subject: [PATCH 38/39] rename --- serve/mlc_serve/engine/engine_common.py | 10 +++++----- serve/mlc_serve/engine/model_module.py | 4 ++-- serve/mlc_serve/model/model_common.py | 8 ++++---- serve/mlc_serve/model/paged_cache_model.py | 6 +++--- serve/mlc_serve/model/tvm_model.py | 10 +++++----- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index 784afb76eb..af9dfb9da0 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -22,7 +22,7 @@ from .model_module import ( DecodeRequest, PrefillRequest, - MultiQueryDecodeRequest, + EvalMultiQueryRequest, EvictedTokens, ConversationTemplate, KVCacheManager, @@ -229,9 +229,9 @@ def update_sequence( def get_requests_to_process( current_states: list[RequestState], cache_manager: KVCacheManager ) -> Tuple[ - list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], bool, int + list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], bool, int ]: - requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]] = [] + requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]] = [] # TODO: consider having hybrid batch if the underlying attention kernel supports # mixing prefill and decode. is_prompt_batch = any(not state.is_prefilled for state in current_states) @@ -263,7 +263,7 @@ def get_requests_to_process( for gen_seq in state.generation_sequences: requests.append( - MultiQueryDecodeRequest( + EvalMultiQueryRequest( sequence_id=gen_seq.seq_id, num_past_tokens=state.prompt_len, queries=EvictedTokens(gen_seq.generated_token_ids), @@ -275,7 +275,7 @@ def get_requests_to_process( len(gen_seq.generated_token_ids) + 1, ) - # TODO(masahi): How to account for token counts in MultiQueryDecodeRequest in + # TODO(masahi): How to account for token counts in EvalMultiQueryRequest in # Prometheus metric? 
elif not state.is_prefilled: token_ids = state.prompt_token_ids diff --git a/serve/mlc_serve/engine/model_module.py b/serve/mlc_serve/engine/model_module.py index 9594914230..00893efa44 100644 --- a/serve/mlc_serve/engine/model_module.py +++ b/serve/mlc_serve/engine/model_module.py @@ -59,7 +59,7 @@ def num_tokens(self): @dataclass -class MultiQueryDecodeRequest: +class EvalMultiQueryRequest: sequence_id: SequenceId num_past_tokens: int queries: Union[DraftTokens, EvictedTokens] @@ -151,7 +151,7 @@ class TextGenerator(Protocol): def generate( self, - requests: Sequence[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], + requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], kv_cache, ) -> List[TextGenerationResult]: """ diff --git a/serve/mlc_serve/model/model_common.py b/serve/mlc_serve/model/model_common.py index 6e0b5aeb53..fee952ad4d 100644 --- a/serve/mlc_serve/model/model_common.py +++ b/serve/mlc_serve/model/model_common.py @@ -20,7 +20,7 @@ from ..engine.model_module import ( DecodeRequest, PrefillRequest, - MultiQueryDecodeRequest, + EvalMultiQueryRequest, TextGenerationResult, ) @@ -305,7 +305,7 @@ def _is_safe_to_sample(prob_like): def sample_from_logits( logits: Union[tvm.nd.NDArray, torch.Tensor], sequence_ids: List[SequenceId], - requests: Sequence[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], + requests: Sequence[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], vocab_size, ) -> List[TextGenerationResult]: assert logits.shape[0] == len(requests) @@ -494,7 +494,7 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]: def prepare_multi_query_decode_inputs( - requests: List[MultiQueryDecodeRequest], + requests: List[EvalMultiQueryRequest], all_slot_mappings, sliding_window, dev, @@ -536,7 +536,7 @@ def prepare_multi_query_decode_inputs( if request.num_past_tokens < len(prompt_slot_mappings): raise RuntimeError( - "For MultiQueryDecodeRequest, the number of past tokens" + "For EvalMultiQueryRequest, the number of past tokens" "smaller than the prompt length is not supported for now." 
) elif request.num_past_tokens == len(prompt_slot_mappings): diff --git a/serve/mlc_serve/model/paged_cache_model.py b/serve/mlc_serve/model/paged_cache_model.py index 70e3200dbb..0b16ab0b3c 100644 --- a/serve/mlc_serve/model/paged_cache_model.py +++ b/serve/mlc_serve/model/paged_cache_model.py @@ -12,7 +12,7 @@ DecodeRequest, ModelModule, PrefillRequest, - MultiQueryDecodeRequest, + EvalMultiQueryRequest, TextGenerationResult, TextGenerator, ) @@ -26,7 +26,7 @@ def __init__(self, model: TextGenerator): def generate( self, - requests: list[Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest]], + requests: list[Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]], kv_cache, ) -> list[TextGenerationResult]: prefill_requests = [] @@ -39,7 +39,7 @@ def generate( prefill_requests.append(r) elif isinstance(r, DecodeRequest): decode_requests.append(r) - elif isinstance(r, MultiQueryDecodeRequest): + elif isinstance(r, EvalMultiQueryRequest): multi_query_decode_requests.append(r) multi_query_decode_request_ids.add(r.sequence_id.request_id) diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 53fa02da74..10757f722d 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -26,7 +26,7 @@ DecodeRequest, PrefillRequest, DraftTokens, - MultiQueryDecodeRequest, + EvalMultiQueryRequest, TextGenerationResult, TextGenerator, ) @@ -206,7 +206,7 @@ def profile_memory_usage(self, seq_lens): def generate_multi_query( self, - requests: List[MultiQueryDecodeRequest], + requests: List[EvalMultiQueryRequest], cache: KVCacheInfo, ) -> List[TextGenerationResult]: sequence_ids = [] @@ -279,7 +279,7 @@ def generate_multi_query( def generate( self, requests: Sequence[ - Union[PrefillRequest, DecodeRequest, MultiQueryDecodeRequest] + Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest] ], cache: KVCacheInfo, ) -> List[TextGenerationResult]: @@ -287,7 +287,7 @@ def generate( return [] is_prefill = isinstance(requests[0], PrefillRequest) - is_multi_query_decode = isinstance(requests[0], MultiQueryDecodeRequest) + is_multi_query_decode = isinstance(requests[0], EvalMultiQueryRequest) if is_multi_query_decode: return self.generate_multi_query(requests, cache) # type: ignore @@ -304,7 +304,7 @@ def generate( sequence_ids.append(request.sequence_id) prompt_lens.append(request.prompt_token_counts) - assert not isinstance(request, MultiQueryDecodeRequest) + assert not isinstance(request, EvalMultiQueryRequest) all_token_ids.append(request.token_ids) ( From d4fe2d72d45279706acd99955f7caed09fc74f0f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Feb 2024 09:59:05 +0000 Subject: [PATCH 39/39] Disallow preempting when a request has generated more than max_num_batched_tokens --- serve/mlc_serve/engine/engine_common.py | 27 ++++++++++++++++++- .../mlc_serve/engine/staging_engine_worker.py | 6 ++++- serve/mlc_serve/engine/sync_engine.py | 6 ++++- serve/mlc_serve/model/tvm_model.py | 2 -- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/serve/mlc_serve/engine/engine_common.py b/serve/mlc_serve/engine/engine_common.py index af9dfb9da0..205f60d7bb 100644 --- a/serve/mlc_serve/engine/engine_common.py +++ b/serve/mlc_serve/engine/engine_common.py @@ -415,7 +415,7 @@ def check_prompt_too_long(self, prompt_len: int, num_sequences: int = 1) -> bool < self.max_decode_steps * num_sequences ) - def evict_request(self) -> int: + def evict_request(self, cancell_callback: Callable[[RequestId], None]) -> int: # Must be called 
with the queue lock held num_eviction = 0 @@ -438,6 +438,28 @@ def evict_request(self) -> int: candidate_victims = parallel_sample_requests request_to_remove = min(candidate_victims, key=lambda s: s.num_total_tokens) + victim_state = self.current_batch[request_to_remove.request_id] + + if victim_state.num_sequences != 1: + prev_generated_token_counts = sum( + [ + len(gen_seq.generated_token_ids) + for gen_seq in victim_state.generation_sequences + ] + ) + # We could allow evicting and restoring a parallel-sampling request whose prev_generated_token_counts + # is > max_num_batched_tokens, by making the model split a list of EvalMultiQuery requests into parts, + # so that an inference on each part can be done with the max_num_batched_tokens budget. + # But this introduces an undesirable coupling between the engine and the model. + if prev_generated_token_counts >= self.max_num_batched_tokens: + cancell_callback(request_to_remove.request_id) + self.remove_request_from_batch(request_to_remove.request_id) + LOG.warn( + f"Cancelling a parallel-sampling request '{request_to_remove.request_id}'" + f"since it has generated more than {self.max_num_batched_tokens} tokens in total" + "and currently we do not support preempting such request.", + ) + continue self.remove_request_from_batch(request_to_remove.request_id) request_to_remove.is_prefilled = False @@ -499,6 +521,9 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]: # Restoring an evicted parallel-sampling request is done by separate # Prefill and MultiQuery requests. The maximum below is an upper bound on the # batch size increase due to this request. + # TODO(masahi): Prefill and EvalMultiQuery requests are handled separately by the model. + # So comparing the sum of their batched token counts against max_num_batched_tokens + # is not optimal. 
num_new_batched_tokens += max(state.prompt_len, prev_generated_token_counts) if num_new_batched_tokens > self.max_num_batched_tokens: diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py index 4957f3e046..6c02c0811c 100644 --- a/serve/mlc_serve/engine/staging_engine_worker.py +++ b/serve/mlc_serve/engine/staging_engine_worker.py @@ -309,7 +309,11 @@ def step(self) -> GenerationLoopWorkerOutput: def _adjust_batch(self): with self.queue_lock: - num_eviction = self.evict_request() + num_eviction = self.evict_request( + cancell_callback=lambda request_id: self.cancelled_requests.append( + self.current_batch[request_id] + ) + ) self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc(num_eviction) if self.cache_manager.get_max_new_tokens() <= self.max_decode_steps: diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py index cfb73541c2..c400ec6b4a 100644 --- a/serve/mlc_serve/engine/sync_engine.py +++ b/serve/mlc_serve/engine/sync_engine.py @@ -254,7 +254,11 @@ def _adjust_batch(self): self.cache_manager.free_request(state) self.requests_to_be_cancelled.remove(request_id) - self.evict_request() + self.evict_request( + cancell_callback=lambda request_id: self.requests_to_be_cancelled.add( + request_id + ) + ) self._discard_cancelled_requests_from_queue() diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index 10757f722d..202a04e30d 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -248,8 +248,6 @@ def generate_multi_query( past_slot_mapping = copy_to_worker_0(self.disco_session, past_slot_mapping) permute_map = copy_to_worker_0(self.disco_session, permute_map) - print("evaluate_multi_query") - out = self.mod["evaluate_multi_query"]( input_ids, positions,
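For reference, the eviction guard that PATCH 39 adds inside `evict_request` can be restated outside the diffs. The sketch below is illustrative only and uses hypothetical stand-in types (`VictimState`, `evict_or_cancel`); the real engine performs this check on its own `RequestState` objects with cache-manager bookkeeping, cancelling a parallel-sampling request instead of preempting it once its generated tokens exceed `max_num_batched_tokens`, because restoring it would need an EvalMultiQuery pass larger than the per-step token budget.

```python
# Minimal sketch of the PATCH 39 eviction rule; not code from the patches above.
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class VictimState:
    request_id: str
    num_sequences: int
    # Tokens generated so far by each parallel sequence of this request.
    generated_token_counts: List[int] = field(default_factory=list)


def evict_or_cancel(
    candidates: List[VictimState],
    max_num_batched_tokens: int,
    cancel_callback: Callable[[str], None],
) -> List[str]:
    """Return request ids that are safe to preempt; cancel those that cannot be restored."""
    evictable: List[str] = []
    for state in candidates:
        prev_generated = sum(state.generated_token_counts)
        if state.num_sequences != 1 and prev_generated >= max_num_batched_tokens:
            # Restoring this request would require evaluating more past tokens
            # than the per-step batching budget allows, so cancel it instead.
            cancel_callback(state.request_id)
        else:
            evictable.append(state.request_id)
    return evictable


# Example: a 4-way parallel-sampling request that already produced 5000 tokens is
# cancelled under a 4096-token budget, while a single-sequence request is evicted.
if __name__ == "__main__":
    cancelled: List[str] = []
    evicted = evict_or_cancel(
        [
            VictimState("req-a", num_sequences=4, generated_token_counts=[1250] * 4),
            VictimState("req-b", num_sequences=1, generated_token_counts=[300]),
        ],
        max_num_batched_tokens=4096,
        cancel_callback=cancelled.append,
    )
    print("evicted:", evicted, "cancelled:", cancelled)
```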