From 8e5bbb30c3c17a0a1787028c0e787c9ea82ed3c5 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 17 Feb 2025 06:03:36 +0000 Subject: [PATCH 01/19] merge prompt and decode --- examples/offline_inference/basic.py | 13 +- requirements-tpu.txt | 8 +- vllm/model_executor/models/qwen2.py | 20 +- vllm/v1/attention/backends/pallas.py | 283 +++----- vllm/v1/worker/tpu_model_runner.py | 1002 +++++++++----------------- vllm/v1/worker/tpu_worker.py | 3 +- 6 files changed, 459 insertions(+), 870 deletions(-) diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index a6e96c0bb433..c110349a0eb9 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -2,18 +2,19 @@ from vllm import LLM, SamplingParams +# TODO(xw32): should remove the change in this file before merging the PR. # Sample prompts. prompts = [ "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", + # "The president of the United States is", + # "The capital of France is", + # "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +sampling_params = SamplingParams() #temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) @@ -21,4 +22,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 1abde714af7c..400e20923d57 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -17,7 +17,7 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241216+cpu -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.7.0.dev20250212+cpu +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e3de6b64fbb3..a7da23987be0 100644 --- 
a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -258,15 +258,17 @@ def forward( return hidden_states, residual -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, - # otherwise (seq_len, ). - "positions": -1, - "intermediate_tensors": 0, - "inputs_embeds": 0, - }) +# TODO(xw32): revert the change before merging the code. +# xw32 turns off dynamo +# @support_torch_compile( +# dynamic_arg_dims={ +# "input_ids": 0, +# # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, +# # otherwise (seq_len, ). +# "positions": -1, +# "intermediate_tensors": 0, +# "inputs_embeds": 0, +# }) class Qwen2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 37bf33f6e3e9..b417180ede55 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -53,6 +53,7 @@ def copy_blocks( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], src_to_dists: Tuple[torch.Tensor, torch.Tensor], ) -> None: + assert False, "I assume this PallasAttentionBackend.copy_blocks function should not be used. But I could be wrong." # TODO(xw32): remove src_indices, dst_indices = src_to_dists for k_cache, v_cache in kv_caches: torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) @@ -61,33 +62,52 @@ def copy_blocks( v_cache[:, dst_indices] = v_cache[:, src_indices] +# Why remove the base class AttentionMetadata. Because the base requires +# to set num_prefills, num_prefill_tokens, and num_decode_tokens. @dataclass -class PallasMetadata(AttentionMetadata): +class PallasMetadata(): + # old begins # Currently, input sequences can only contain all prefills # or all decoding. - block_tables: Optional[torch.Tensor] = None - context_lens: Optional[torch.Tensor] = None - effective_query_lens: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["PallasMetadata"]: - if self.num_prefills == 0: - return None - - assert self.num_decode_tokens == 0 - return self - - @property - def decode_metadata(self) -> Optional["PallasMetadata"]: - if self.num_decode_tokens == 0: - return None - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.block_tables is not None - assert self.context_lens is not None - return self + # block_tables: Optional[torch.Tensor] = None + # context_lens: Optional[torch.Tensor] = None + # effective_query_lens: Optional[torch.Tensor] = None + + # @property + # def prefill_metadata(self) -> Optional["PallasMetadata"]: + # if self.num_prefills == 0: + # return None + + # assert self.num_decode_tokens == 0 + # return self + + # @property + # def decode_metadata(self) -> Optional["PallasMetadata"]: + # if self.num_decode_tokens == 0: + # return None + + # assert self.num_prefills == 0 + # assert self.num_prefill_tokens == 0 + # assert self.block_tables is not None + # assert self.context_lens is not None + # return self + # old ends + + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Used in the PallasAttentionBackendImpl + slot_mapping: torch.Tensor + block_tables: torch.Tensor + context_lens: torch.Tensor + query_start_loc: torch.Tensor + num_seqs: int class PallasAttentionBackendImpl(AttentionImpl): @@ -105,10 +125,13 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Paged attention Pallas kernel does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -126,25 +149,6 @@ def __init__( raise NotImplementedError( "Attention logits soft-capping is not supported.") - if torch_xla.tpu.version() < 4: - raise NotImplementedError("TPU version must be 4 or higher.") - - self.megacore_mode = None - tpu_env = torch_xla.tpu.get_tpu_env() - tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) - or tpu_env.get("TYPE", None) - or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) - assert tpu_type is not None - tpu_type = tpu_type.lower() - - if (("lite" not in tpu_type) and ("v6" not in tpu_type)): - if self.num_kv_heads % 2 == 0: - self.megacore_mode = "kv_head" - else: - # NOTE(woosuk): If the batch size is not a multiple of 2, the - # megacore mode will be None. - self.megacore_mode = "batch" - if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " @@ -164,17 +168,15 @@ def forward( """Forward pass with Pallas attention. Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] - kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] - NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor - with shape [0] for profiling run. + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_kv_heads, num_blocks, block_size, head_size] attn_metadata: Metadata for attention. 
Returns: - shape = [batch_size, seq_len, num_heads * head_size] + shape = [num_tokens, num_heads * head_size] """ + print(f'xw32 PallasAttentionBackendImpl.forward begins {query.shape=}, {key.shape=}') if attn_metadata is None: if output is None: @@ -182,117 +184,32 @@ def forward( return output assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - batch_size, seq_len, hidden_size = query.shape - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) - key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) - value = value.view(batch_size, seq_len, self.num_kv_heads, - self.head_size) + num_tokens, hidden_size = query.shape + query = query.view(num_tokens, self.num_heads, self.head_size) + key = key.view(num_tokens, self.num_kv_heads, self.head_size) + value = value.view(num_tokens, self.num_kv_heads, self.head_size) if kv_cache[0].numel() > 0: + print('xw32 write to kv cache') slot_mapping = attn_metadata.slot_mapping key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale - if attn_metadata.num_prefills > 0: - if attn_metadata.block_tables is None: - # Prefill without paged KV cache. - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") - - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, - dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, - dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, - self.head_size) - # FlashAttention kernel requires the input shape to be - # [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. - output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) - else: - # Prefill with paged KV cache. - # TODO(woosuk): Tune the below knobs. - num_kv_pages_per_compute_block = 16 - num_queries_per_compute_block = 16 - assert seq_len % num_queries_per_compute_block == 0 - output = torch.ops.xla.multi_queries_paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.effective_query_lens, - num_kv_pages_per_compute_block, - num_queries_per_compute_block, - use_kernel=True, - ) - else: - # Decoding run. - assert kv_cache[0].numel() > 0 - query = query.squeeze(dim=1) - pages_per_compute_block = 16 # TODO(woosuk): Tune this value. - - assert attn_metadata.block_tables is not None - assert attn_metadata.context_lens is not None - # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire - # block table in SMEM. Therefore, if the block table is too large, - # the kernel compilation will fail. To avoid this, we split the - # batch dimension into smaller chunks and run the kernel multiple - # times. - MAX_SMEM_USAGE = 512 * 1024 - size_per_seq = 4 * attn_metadata.block_tables.shape[1] - max_num_seq = MAX_SMEM_USAGE // size_per_seq - - if batch_size <= max_num_seq: - output = paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - pages_per_compute_block, - self.megacore_mode, - ) - else: - chunk_size = max_num_seq - # Make sure the chunk size is a multiple of 2. 
- chunk_size = chunk_size // 2 * 2 - num_chunks = (batch_size + chunk_size - 1) // chunk_size - - output = torch.empty_like(query) - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * chunk_size - chunk_end = chunk_start + chunk_size - # NOTE(woosuk): We skip this line because it causes Dynamo - # compilation error. Instead, we rely on the slice operation - # to handle the out-of-bound case. - # chunk_end = min(chunk_end, batch_size) - chunk_output = paged_attention( - query[chunk_start:chunk_end], - key_cache, - value_cache, - attn_metadata.context_lens[chunk_start:chunk_end], - attn_metadata.block_tables[chunk_start:chunk_end], - pages_per_compute_block, - self.megacore_mode, - ) - output[chunk_start:chunk_end] = chunk_output - - # Reshape the output tensor. - return output.reshape(batch_size, seq_len, hidden_size) + output = torch.ops.xla.ragged_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.query_start_loc, + attn_metadata.num_seqs, + num_kv_pages_per_block=16, + num_queries_per_block=128, + use_kernel=True, + ) + + return output def write_to_kv_cache( @@ -302,52 +219,24 @@ def write_to_kv_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, ) -> None: + """ Write the key and values to the KV cache. + + Args: + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + k_cache = [num_kv_heads, num_blocks, block_size, head_size] + v_cache = [num_kv_heads, num_blocks, block_size, head_size] + + """ + print(f'xw32 write_to_kv_cache {key.shape=}, {key_cache.shape=}, {slot_mapping=}', flush=True) torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) - key = key.flatten(0, 2) - value = value.flatten(0, 2) + # xw32: key = key.flatten(0, 1) or key = key.flatten(0, 2)? + # key = key.flatten(0, 1) because the key.shape has changed from [bs, seq_len, num_kv_heads, head_size] to [num_tokens, num_kv_heads, head_size] + key = key.flatten(0, 1) + value = value.flatten(0, 1) key_cache = key_cache.flatten(0, 2) value_cache = value_cache.flatten(0, 2) key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - pages_per_compute_block: int, - megacore_mode: Optional[str], -) -> torch.Tensor: - batch_size = query.shape[0] - if megacore_mode == "batch" and batch_size % 2 != 0: - megacore_mode = None - else: - megacore_mode = megacore_mode - - # NOTE(woosuk): A temporary workaround to avoid the error: - # "xla::paged_attention() Expected a value of type 'str' for - # argument 'megacore_mode' but instead found type 'NoneType'." 
- if megacore_mode is not None: - output = torch.ops.xla.paged_attention( - query, - key_cache, - value_cache, - context_lens, - block_tables, - pages_per_compute_block, - megacore_mode=megacore_mode, - ) - else: - output = torch.ops.xla.paged_attention( - query, - key_cache, - value_cache, - context_lens, - block_tables, - pages_per_compute_block, - ) - return output diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8635ffce7027..9b55ee563ab2 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -2,7 +2,7 @@ import enum import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from unittest.mock import patch import numpy as np @@ -38,39 +38,14 @@ # Here we utilize the behavior that out-of-bound index is ignored. # FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. +_MAX_NUM_SAMPLES = 128 -class ExecutionMode(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - PREFIX_PREFILL = enum.auto() - - def is_prefill(self) -> bool: - return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - - -@dataclass -class PromptDecodeInfo: - prompt_req_ids: List[str] - decode_req_ids: List[str] - prompt_scheduled_tokens: List[int] - - -@dataclass -class PromptData: - input_tokens: torch.Tensor - input_positions: torch.Tensor - attn_metadata: PallasMetadata - - -@dataclass -class DecodeData: - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - attn_metadata: Optional[PallasMetadata] = None - - -class TPUModelRunner: +class TPUModelRunner(): def __init__( self, @@ -135,50 +110,34 @@ def __init__( # KV caches for forward pass self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - # Cached torch/numpy tensors - self.num_swaps = 2 - self.cur_swap_id = 0 - self.input_ids_cpu = [] - self.input_ids_np = [] - self.input_positions_cpu = [] - self.input_positions_np = [] - self.slot_mapping_cpu = [] - self.slot_mapping_np = [] - self.prompt_context_lens_cpu = [] - self.prompt_effective_query_lens_cpu = [] - self.decode_context_lens_cpu = [] - self.decode_context_lens_np = [] - for _ in range(self.num_swaps): - self.input_ids_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.input_ids_np.append(self.input_ids_cpu[-1].numpy()) - - self.input_positions_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.input_positions_np.append( - self.input_positions_cpu[-1].numpy()) - - self.slot_mapping_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int64, - device="cpu")) - self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy()) - - self.prompt_context_lens_cpu.append( - torch.empty((1), dtype=torch.int32, device="cpu")) - self.prompt_effective_query_lens_cpu.append( - torch.empty((1), dtype=torch.int32, device="cpu")) - - self.decode_context_lens_cpu.append( - torch.empty(self.max_num_tokens, - dtype=torch.int32, - device="cpu")) - self.decode_context_lens_np.append( - self.decode_context_lens_cpu[-1].numpy()) + # Cached torch/numpy tensor + # xw32: what's the numpy array (eg input_ids_np) for? 
+ self.input_ids_cpu = torch.empty(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.input_ids_np = self.input_ids_cpu.numpy() + + self.positions_cpu = torch.empty(self.max_num_tokens, + dtype=torch.int32, + device="cpu") + self.positions_np = self.positions_cpu.numpy() + + # xw32: slot_mapping maps a token to its position in the block (=block_numbers * self.block_size+block_offset) + self.slot_mapping_cpu = torch.empty(self.max_num_tokens, + dtype=torch.int64, + device="cpu") + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + + self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() # Range tensor with values [0 .. self.max_num_tokens - 1]. # Used to initialize positions / context_lens / seq_lens @@ -304,9 +263,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.input_batch.condense(removed_req_indices) return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 - def swap_step(self): - self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps - def get_model(self) -> nn.Module: assert self.model is not None return self.model @@ -346,363 +302,275 @@ def get_kv_cache_spec(self) -> KVCacheSpec: return kv_cache_spec - def _get_prompts_and_decodes( - self, - scheduler_output: "SchedulerOutput", - ) -> PromptDecodeInfo: + def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + print(f'xw32 _prepare_inputs begins. {scheduler_output=}') total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # Traverse decodes first - decode_req_ids = [] - for i in range(num_reqs): - req_id = self.input_batch.req_ids[i] - assert req_id is not None - - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] - num_prompt_tokens = self.input_batch.num_prompt_tokens[i] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - - if num_computed_tokens < num_prompt_tokens: - # This is prompt - break - - # This is decode - assert num_scheduled_tokens == 1 - decode_req_ids.append(req_id) + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + # xw32q: Do we need this? + self.input_batch.block_table.commit(num_reqs) - # Traverse prompts - prompt_req_ids = [] - prompt_scheduled_tokens = [] - for i in range(len(decode_req_ids), num_reqs): - req_id = self.input_batch.req_ids[i] + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. 
+ num_scheduled_tokens_per_req = [] + max_num_scheduled_tokens_all_reqs = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None - - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i] - num_prompt_tokens = self.input_batch.num_prompt_tokens[i] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - - # Must be prompt - assert num_computed_tokens < num_prompt_tokens - - prompt_req_ids.append(req_id) - prompt_scheduled_tokens.append(num_scheduled_tokens) - - return PromptDecodeInfo(prompt_req_ids, decode_req_ids, - prompt_scheduled_tokens) - - def _prepare_prompt(self, req_index: int, - num_scheduled_tokens: int) -> PromptData: - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ - req_index] - num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index] - - # Must be prompt - assert num_computed_tokens < num_prompt_tokens - - # Prompt len - prompt_len = num_scheduled_tokens - padded_prompt_len = _get_padded_prompt_len(prompt_len) - assert padded_prompt_len <= self.max_model_len - - # Seq len - seq_len = num_computed_tokens + prompt_len - padded_seq_len = num_computed_tokens + padded_prompt_len - - # Input tokens - input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[ - req_index, num_computed_tokens:padded_seq_len] - input_tokens_cpu[prompt_len:] = 0 - - # Input positions - input_positions_np = self.input_positions_np[ - self.cur_swap_id][:padded_prompt_len] - np.add(num_computed_tokens, - self.arange_np[:padded_prompt_len], - out=input_positions_np) - input_positions_np[prompt_len:] = 0 - - # Slot mapping - block_table_np = \ - self.input_batch.block_table.get_numpy_array() - block_numbers_np = block_table_np[req_index, input_positions_np // - self.block_size] - block_offsets_np = input_positions_np % self.block_size - - slot_mapping_np = self.slot_mapping_np[ - self.cur_swap_id][:padded_prompt_len] - np.add(block_numbers_np * self.block_size, - block_offsets_np, - out=slot_mapping_np) - slot_mapping_np[prompt_len:] = _PAD_SLOT_ID - - # Block table - block_table_cpu = None - if num_computed_tokens > 0: - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_table_cpu = block_table_cpu[req_index] - - # Context len - self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0 - if num_computed_tokens > 0: - self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len - - # Effective query len - self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len - - # Get final tensors - input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device) - input_positions = self.input_positions_cpu[ - self.cur_swap_id][:padded_prompt_len].reshape(1, - -1).to(self.device) - slot_mapping = self.slot_mapping_cpu[ - self.cur_swap_id][:padded_prompt_len].reshape(1, - -1).to(self.device) - block_table = block_table_cpu.reshape(1, -1).to( - self.device) if block_table_cpu is not None else None - - context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to( - self.device) - effective_query_lens = self.prompt_effective_query_lens_cpu[ - self.cur_swap_id].to(self.device) - - self.swap_step() - - # Attn metadata - attn_metadata = PallasMetadata( - num_prefills=1, - num_prefill_tokens=0, # NOTE: This is not used. 
- num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - - return PromptData(input_tokens, input_positions, attn_metadata) - - def _prepare_decode( - self, - decode_req_ids: List[str], - ) -> DecodeData: - # Batch size - batch_size = len(decode_req_ids) - padded_batch_size = _get_padded_batch_size(batch_size) - assert padded_batch_size <= self.max_model_len - - # Init [0 .. batch_size - 1] - req_indices_np = self.arange_np[:padded_batch_size] - - # Input positions - input_positions_np = self.input_positions_np[ - self.cur_swap_id][:padded_batch_size] - np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], - 0, - out=input_positions_np) - input_positions_np[batch_size:] = 0 - input_positions_cpu = self.input_positions_cpu[ - self.cur_swap_id][:padded_batch_size] - - # Input tokens - token_indices_np = ( - input_positions_np + - req_indices_np * self.input_batch.token_ids_cpu.shape[1]) - input_tokens_cpu = self.input_ids_cpu[ - self.cur_swap_id][:padded_batch_size] + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # print(f'xw32 TPUModelRunner.prepare_input line148. {req_id=}, {num_tokens=}') + # xw32 TPUModelRunner.prepare_input line148. req_id='0', num_tokens=5 + num_scheduled_tokens_per_req.append(num_tokens) + max_num_scheduled_tokens_all_reqs = max(max_num_scheduled_tokens_all_reqs, + num_tokens) + num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, dtype=np.int32) + assert max_num_scheduled_tokens_all_reqs > 0 + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens_per_req) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = np.concatenate( + [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + + # Get positions. + # TODO(xw32): add an example of the output positions_np. + positions_np = self.positions_np[:total_num_scheduled_tokens] + # print(f'xw32 TPUModelRunner.prepare_input. {total_num_scheduled_tokens=}, {self.input_batch.num_computed_tokens_cpu=}, {self.input_batch.num_reqs=}, {self.model_config.uses_mrope=}') + # xw32 TPUModelRunner.prepare_input. total_num_scheduled_tokens=5, self.input_batch.num_computed_tokens_cpu=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), self.input_batch.num_reqs=1, self.model_config.uses_mrope=False + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # xw32: Do we need to check self.model_config.uses_mrope? + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + # print(f'xw32 TPUModelRunner.prepare_input line148. {positions_np=}, {req_indices=}, {self.input_batch.token_ids_cpu.shape=}, {token_indices}') + # xw32 TPUModelRunner.prepare_input line148. positions_np=array([0, 1, 2, 3, 4]), req_indices=array([0, 0, 0, 0, 0], dtype=int32), self.input_batch.token_ids_cpu.shape=(16, 512), [0 1 2 3 4] + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + + # print(f'xw32 TPUModelRunner.prepare_input line189 . 
{self.input_batch.token_ids_cpu_tensor=}') # prints a 2d tensor torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices_np), - out=input_tokens_cpu) - input_tokens_cpu[batch_size:] = 0 - - # Slot mapping - block_table_indices_np = ( - req_indices_np * self.max_num_blocks_per_req + - input_positions_np // self.block_size) - + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens]) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. block_table_cpu = self.input_batch.block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # Prepare the attention metadata. + print(f'xw32 TPUModelRunner.prepare_input line214. {self.query_start_loc_np.shape=}') + self.query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens_per_req, + out=self.query_start_loc_np[1:num_reqs + 1]) + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req) + + # Copy the tensors to the TPU. + self.input_ids = self.input_ids_cpu[:total_num_scheduled_tokens].to(self.device) + self.position_ids = self.positions_cpu[:total_num_scheduled_tokens].to(self.device) + query_start_loc = self.query_start_loc_cpu[:total_num_scheduled_tokens+1].to(self.device) + seq_lens = self.seq_lens_cpu[:total_num_scheduled_tokens].to(self.device) + slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(self.device) + print(f'xw32 TPUModelRunner.prepare_input line230 . 
{self.input_batch.block_table.get_device_tensor().shape=}, {total_num_scheduled_tokens=}, {num_reqs=}') - block_numbers_np = block_table_cpu.flatten( - )[block_table_indices_np].numpy() - - block_offsets_np = input_positions_np % self.block_size - - slot_mapping_np = self.slot_mapping_np[ - self.cur_swap_id][:padded_batch_size] - np.add(block_numbers_np * self.block_size, - block_offsets_np, - out=slot_mapping_np) - slot_mapping_np[batch_size:] = _PAD_SLOT_ID - - block_table_cpu = block_table_cpu[:padded_batch_size] - - # Context lens - context_lens_np = self.decode_context_lens_np[ - self.cur_swap_id][:padded_batch_size] - np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], - 1, - out=context_lens_np) - context_lens_np[batch_size:] = 0 - - # Get final tensors - input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device) - input_positions = input_positions_cpu.reshape(-1, 1).to(self.device) - slot_mapping = self.slot_mapping_cpu[ - self.cur_swap_id][:padded_batch_size].reshape(-1, - 1).to(self.device) - block_table = block_table_cpu.to(self.device) - context_lens = self.decode_context_lens_cpu[ - self.cur_swap_id][:padded_batch_size].to(self.device) - - self.swap_step() - - # Attn metadata attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=padded_batch_size, slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_table, - context_lens=context_lens, - effective_query_lens=None, + block_tables=( + self.input_batch.block_table.get_device_tensor()[:total_num_scheduled_tokens]), + context_lens=seq_lens, + query_start_loc=query_start_loc, + num_seqs=num_reqs, + # num_actual_tokens=total_num_scheduled_tokens, + # max_query_len=max_num_scheduled_tokens, + # max_seq_len=max_seq_len, ) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + return attn_metadata, logits_indices - return DecodeData(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata) @torch.no_grad() def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: + logger.info(f"xw32 TPUModelRunner.execute_model. 
{scheduler_output=}") + # Update cached state self._update_states(scheduler_output) - # If necessary, swap decodes/prompts to have all decodes on the start - ensure_decodes_first(self.input_batch) - - # Prepare prompts/decodes info - pd_info = self._get_prompts_and_decodes(scheduler_output) - - # Init - num_prompts = len(pd_info.prompt_req_ids) - num_decodes = len(pd_info.decode_req_ids) - decode_data = None - sampled_token_ids = [0] * self.input_batch.num_reqs - - # Run each prompt individually - is_first = True - for i in range(num_prompts): - req_id = pd_info.prompt_req_ids[i] - req_index = num_decodes + i - assert req_index == self.input_batch.req_id_to_index[ - req_id] # TODO: Remove - req_state = self.requests[req_id] - num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i] - prompt_len = num_scheduled_tokens - seq_len = req_state.num_computed_tokens + num_scheduled_tokens - - # Prepare first prompt - if is_first: - prompt_data = self._prepare_prompt(req_index, - num_scheduled_tokens) - is_first = False - - # Run forward pass - with set_forward_context(prompt_data.attn_metadata, - self.vllm_config): - assert self.model is not None - selected_token_ids = self.model(prompt_data.input_tokens, - prompt_data.input_positions, - prompt_data.attn_metadata, - self.kv_caches) - - # In parallel to TPU execution, prepare the next iteration - if i < num_prompts - 1: - # There is next prompt => prepare it - prompt_data = self._prepare_prompt( - req_index + 1, pd_info.prompt_scheduled_tokens[i + 1]) - elif i == num_prompts - 1 and num_decodes > 0: - # There is next decode => prepare it - decode_data = self._prepare_decode(pd_info.decode_req_ids) - - # Update cached state (if prompt is fully done) - if seq_len >= len(req_state.prompt_token_ids): - # Transfer sampled tokens from TPU to CPU - selected_token_ids_cpu = selected_token_ids.cpu() - - # Get output token - token_id = selected_token_ids_cpu[prompt_len - 1].item() - sampled_token_ids[req_index] = token_id - - # Add output token to the request - self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - self.input_batch.num_tokens[req_index] += 1 - req_state.output_token_ids.append(token_id) + # Prepare inputs + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + num_input_tokens = num_scheduled_tokens + + input_ids = self.input_ids[:num_input_tokens] + + # Run the decoder + with set_forward_context(attn_metadata, self.vllm_config): + positions = self.position_ids[:num_input_tokens] + selected_token_ids = self.model( + token_ids=input_ids, + position_ids=positions, + kv_caches=self.kv_caches, + # xw32q: Why in gpu_model_runner.py, attn_metadata is None https://github.com/vllm-project/vllm/blob/46fe9b46d83e733130ce952eb3967a9c96713583/vllm/v1/worker/gpu_model_runner.py#L455? 
+ attn_metadata=attn_metadata, + ) - # Run decodes (a single batch) - if num_decodes > 0: - - # Prepare decode (if was not yet prepared) - if decode_data is None: - decode_data = self._prepare_decode(pd_info.decode_req_ids) - - # Run forward pass - with set_forward_context(decode_data.attn_metadata, - self.vllm_config): - assert self.model is not None - selected_token_ids = self.model(decode_data.input_tokens, - decode_data.input_positions, - decode_data.attn_metadata, - self.kv_caches) - - # Transfer sampled tokens from TPU to CPU - decode_token_ids_cpu = selected_token_ids.cpu() - # Convert to list - decode_token_ids_list = decode_token_ids_cpu.tolist() - - # Update cached state for each decode request - for i in range(num_decodes): - req_id = pd_info.decode_req_ids[i] - req_index = i - assert req_index == self.input_batch.req_id_to_index[ - req_id] # TODO: Remove - req_state = self.requests[req_id] - seq_len = req_state.num_computed_tokens + 1 - - token_id = decode_token_ids_list[i] - sampled_token_ids[req_index] = token_id - - self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - self.input_batch.num_tokens[req_index] += 1 + # Then, let's update the cache state. + num_reqs = self.input_batch.num_reqs + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + assert req_id is not None + req_state = self.requests[req_id] + seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + assert seq_len <= req_state.num_tokens + if seq_len == req_state.num_tokens: + # Append the sampled token to the output token ids. + token_id = selected_token_ids[i] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + self.input_batch.num_tokens[i] += 1 req_state.output_token_ids.append(token_id) - - # Create output. - all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids - prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} - for req_id in all_req_ids: - prompt_logprobs_dict[req_id] = None + else: + # xw32q: what are these from gpu_model_runner.py? I don't understand. + # Ignore the sampled token from the partial request. + # Rewind the generator state as if the token was not sampled. 
+ generator = self.input_batch.generators.get(i) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + # num_reqs entries should be non-None + assert all( + req_id is not None for req_id in + self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) model_runner_output = ModelRunnerOutput( - req_ids=all_req_ids, + req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[[token_id] for token_id in sampled_token_ids], - logprobs=None, - prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type] + sampled_token_ids=selected_token_ids, + logprob_token_ids_cpu=None, + logprobs_cpu=None, ) - return model_runner_output + # old code begins + # # Init + # num_reqs = self.input_batch.num_reqs + # assert num_reqs > 0 + # sampled_token_ids_list = [0] * num_reqs + + # # Run decodes (a single batch) + # if len(decode_data.req_ids) > 0: + # # Forward + # with set_forward_context(decode_data.attn_metadata, + # self.vllm_config): + # assert self.model is not None + # selected_token_ids = self.model(decode_data.input_tokens, + # decode_data.input_positions, + # decode_data.attn_metadata, + # self.kv_caches) + + # # Transfer sampled tokens from TPU to CPU + # selected_token_ids_list = selected_token_ids.cpu().tolist() + + # # Update cached state + # for i, req_id in enumerate(decode_data.req_ids): + # # xw32: what is the difference between req_index and req_id? + # req_index = self.input_batch.req_id_to_index[req_id] + # req_state = self.requests[req_id] + + # seq_len = (req_state.num_computed_tokens + + # scheduler_output.num_scheduled_tokens[req_id]) + + # token_id = selected_token_ids_list[i] + + # self.input_batch.token_ids_cpu[req_index, seq_len] = token_id + # self.input_batch.num_tokens[req_index] += 1 + # req_state.output_token_ids.append(token_id) + + # sampled_token_ids_list[req_index] = token_id + + # # Run each prompt + # for (req_id, prompt_len, input_tokens, input_positions, + # attn_metadata) in prompt_data.zipped(): + # assert req_id is not None + # req_state = self.requests[req_id] + # req_index = self.input_batch.req_id_to_index[req_id] + + # # Forward + # with set_forward_context(attn_metadata, self.vllm_config): + # assert self.model is not None + # selected_token_ids = self.model(input_tokens, input_positions, + # attn_metadata, self.kv_caches) + + # seq_len = (req_state.num_computed_tokens + + # scheduler_output.num_scheduled_tokens[req_id]) + # if seq_len >= len(req_state.prompt_token_ids): + # # Transfer sampled tokens from TPU to CPU + # token_id = selected_token_ids.cpu()[prompt_len - 1].item() + # sampled_token_ids_list[req_index] = token_id + + # # Update cached state + # self.input_batch.token_ids_cpu[req_index, seq_len] = token_id + # self.input_batch.num_tokens[req_index] += 1 + # req_state.output_token_ids.append(token_id) + + # # Get req_ids + # assert all( + # req_id is not None for req_id in + # self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + # req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + + # model_runner_output = ModelRunnerOutput( + # req_ids=req_ids, + # req_id_to_index=self.input_batch.req_id_to_index, + # sampled_token_ids=sampled_token_ids_list, + # logprob_token_ids_cpu=None, + # logprobs_cpu=None, + # ) + + # return model_runner_output + # old code ends + def load_model(self) -> None: + logger.info("xw32 TPUModelRunner.load_model 
begins.") self.device = self.device_config.device # NOTE(woosuk): While the executor assigns the TP ranks to the worker @@ -724,193 +592,82 @@ def load_model(self) -> None: xm.mark_step() xm.wait_device_ops() model = ModelWrapperV1(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) - + # TODO(xw32): turn on dynamo. + # xw32 turns off dynamo + self.model = model + # self.model = torch.compile(model, + # backend="openxla", + # fullgraph=True, + # dynamic=False) + logger.info("xw32 TPUModelRunner.load_model ends.") + + # @torch.inference_mode() fails so I disabled it. + # It's also not in the original v1 tpu_model_runner.py + # @torch.inference_mode() def dummy_run( self, kv_caches, num_tokens: int, seq_len: Optional[int] = None, - exec_mode: Optional[ExecutionMode] = None, ) -> None: - assert seq_len is not None - assert exec_mode is not None - - exec_mode = ExecutionMode(exec_mode) - if exec_mode.is_prefill(): - seq_len = (seq_len + 15) // 16 * 16 - token_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((num_tokens, seq_len), - dtype=torch.int64, - device=self.device) - if exec_mode == ExecutionMode.PREFILL: - attn_metadata = PallasMetadata( - num_prefills=num_tokens, - num_prefill_tokens=num_tokens * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=None, - context_lens=None, - effective_query_lens=None, - ) - - else: - context_lens = torch.ones((num_tokens, ), - dtype=torch.int32, - device=self.device) - - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - - effective_query_lens = torch.ones_like(context_lens) - - attn_metadata = PallasMetadata( - num_prefills=num_tokens, - num_prefill_tokens=num_tokens * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_tables, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - else: - assert seq_len == 1 - token_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((num_tokens, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((num_tokens, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - context_lens = torch.ones((num_tokens, ), - dtype=torch.int32, - device=self.device) - attn_metadata = PallasMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=num_tokens * seq_len, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - block_tables=block_tables, - context_lens=context_lens, - ) - - # NOTE(woosuk): There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). 
- if exec_mode.is_prefill(): - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - - with set_forward_context(attn_metadata, self.vllm_config, 0): + # logger.info(f"xw32 TPUModelRunner.dummy_run. {self.input_ids_cpu.shape=}, {self.positions_cpu.shape=}, {num_tokens=}, {self.input_ids_cpu.device=}") + # xw32 qq: what are input_ids and positions and slot_mapping? What are their shapes? Here is the answer: + # xw32 TPUModelRunner.dummy_run. self.input_ids.shape=torch.Size([8192]), self.positions.shape=torch.Size([8192]), num_tokens=16, 32, ..., self.input_ids.device=device(type='xla', index=0) + input_ids = torch.zeros(num_tokens, + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros(num_tokens, + dtype=torch.int64, + device=self.device) + slot_mapping = torch.zeros(num_tokens, + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (num_tokens, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((num_tokens, ), + dtype=torch.int32, + device=self.device) + block_tables = torch.zeros( + (num_tokens, self.max_num_blocks_per_req), + dtype=torch.int32, + device=self.device) + query_start_loc = torch.zeros(num_tokens+1, dtype=torch.int32, device=self.device) + # how do I set torch._dynamo.mark_dynamic? + # The attn_metadata is used in torch._dynamo.mark_dynamic. + attn_metadata = PallasMetadata( + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + query_start_loc=query_start_loc, + num_seqs=num_tokens, # xw32: is it correct? + ) + with set_forward_context(None, self.vllm_config): assert self.model is not None - self.model(token_ids, position_ids, attn_metadata, kv_caches) + logger.info(f"xw32 TPUModelRunner.dummy_run. before calling self.model, {input_ids.shape=}, {position_ids.shape=}") + self.model(input_ids, position_ids, None, kv_caches) + logger.info(f"xw32 TPUModelRunner.dummy_run. 
after calling self.model") def capture_model(self) -> None: """Compile the model.""" - # Prefill - logger.info( - "Compiling the model with different input shapes for prefill:") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFILL) - xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, - seq_len) - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 - - end = time.time() - logger.info(" -- Compilation for prefill done in %.2f [secs].", - end - start) + logger.info("xw32 TPUModelRunner.capture_model.") + logger.info("Compiling the model with different input shapes.") - # Prefix prefill - if self.scheduler_config.enable_chunked_prefill: - logger.info("Compiling the model with different input shapes for " - "prefix prefill:") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.PREFIX_PREFILL) - xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, - seq_len) - num_tokens = batch_size * seq_len - if (num_tokens - >= self.scheduler_config.max_num_batched_tokens): - break - seq_len = seq_len * 2 - end = time.time() - logger.info( - " -- Compilation for prefix prefill done in %.2f [secs].", - end - start) - - # Decode - logger.info( - "Compiling the model with different input shapes for decode:") - start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() + # xw32 qq: is the compilation here for both torch.compile and the XLA compile? + # xw32: may need to compile for num_seqs. + start = time.perf_counter() + num_tokens = 16 while True: - self.dummy_run(self.kv_caches, - batch_size, - seq_len, - exec_mode=ExecutionMode.DECODE) + self.dummy_run(self.kv_caches, num_tokens) xm.wait_device_ops() - logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: + logger.info(" -- num_tokens: %d", num_tokens) + if num_tokens >= self.scheduler_config.max_num_batched_tokens: break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info(" -- Compilation for decode done in %.2f [secs].", + num_tokens *= 2 + end = time.perf_counter() + logger.info("Compilation finished in in %.2f [secs].", end - start) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: @@ -920,6 +677,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ + logger.info(f"xw32 TPUModelRunner.initialize_kv_cache. {kv_cache_config=}") if len(kv_cache_config.groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " @@ -979,12 +737,19 @@ def forward( memory profiling at initialization. """ # Skip this in memory profiling at initialization. 
- if attn_metadata is not None and kv_caches[0][0].numel() > 0: + logger.info("xw32 ModelWrapperV1.forward.") + print(f'xw32 ModelWrapperV1.forward', flush=True) + print(f'xw32 ModelWrapperV1.forward {token_ids=}') + print(f'xw32 ModelWrapperV1.forward {position_ids=}') + print(f'xw32 ModelWrapperV1.forward {attn_metadata=}') + print(f'xw32 ModelWrapperV1.forward {len(kv_caches)=}, {kv_caches[0][0].shape=}') + if attn_metadata is not None: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it # work, we need to flatten the first three dimensions and modify # the slot_mapping accordingly. + # kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape slot_mapping = attn_metadata.slot_mapping slot_mapping = slot_mapping.flatten() @@ -1000,97 +765,30 @@ def forward( attn_metadata.slot_mapping = slot_mapping assert self.model is not None + print(f'xw32 ModelWrapperV1.forward, right before calling self.model, {token_ids=}', flush=True) hidden_states = self.model( token_ids, position_ids, kv_caches, attn_metadata, ) + print(f'xw32 ModelWrapperV1.forward, right after calling self.model, {hidden_states.shape=}', flush=True) - hidden_states = hidden_states.flatten(0, 1) + # hidden_states = hidden_states.flatten(0, 1) is not needed because previously hidden_states has shape [bs, T, C] and we need to combine the first 2 dimensions. + # hidden_states = hidden_states.flatten(0, 1) + print(f'xw32 ModelWrapperV1.forward, right after calling hidden_states.flatten, {hidden_states.shape=}', flush=True) logits = self.model.compute_logits(hidden_states, None) + print(f'xw32 ModelWrapperV1.forward, right after calling self.model.compute_logits', flush=True) # Greedy sampling. 
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + print(f'xw32 ModelWrapperV1.forward, right after calling torch.argmax', flush=True) argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + print(f'xw32 ModelWrapperV1.forward, right after calling argmax_token_ids.squeeze', flush=True) return argmax_token_ids -def swap_positions(b: InputBatch, id_1, id_2): - assert id_1 != id_2 - req_id_1 = b.req_ids[id_1] - req_id_2 = b.req_ids[id_2] - assert req_id_1 is not None - assert req_id_2 is not None - assert id_1 == b.req_id_to_index[req_id_1] - assert id_2 == b.req_id_to_index[req_id_2] - - b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1] - b.req_id_to_index[req_id_1], b.req_id_to_index[ - req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1] - - ids = [id_1, id_2] - rev_ids = [id_2, id_1] - b.num_tokens[ids] = b.num_tokens[rev_ids] - b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids] - b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids] - b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids] - - b.block_table.swap_row(id_1, id_2) - - b.temperature_cpu[ids] = b.temperature_cpu[rev_ids] - b.top_p_cpu[ids] = b.top_p_cpu[rev_ids] - b.top_k_cpu[ids] = b.top_k_cpu[rev_ids] - b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids] - b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids] - b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids] - - b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[ - id_1] - b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[ - id_2], b.stop_token_ids[id_1] - - gen_1 = b.generators.pop(id_1, None) - gen_2 = b.generators.pop(id_2, None) - if gen_1 is not None: - b.generators[id_2] = gen_1 - if gen_2 is not None: - b.generators[id_1] = gen_2 - - -def ensure_decodes_first(b: InputBatch): - num_reqs = b.num_reqs - while True: - # Find the first prompt index - first_prompt_index = None - for i in range(num_reqs): - if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]: - first_prompt_index = i - break - if first_prompt_index is None: - break - - # Find the last decode index - last_decode_index = None - for i in reversed(range(num_reqs)): - if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]: - last_decode_index = i - break - if last_decode_index is None: - break - - # Sanity - assert first_prompt_index != last_decode_index - - # Check if done - if first_prompt_index > last_decode_index: - break - - # Swap - swap_positions(b, first_prompt_index, last_decode_index) - - -def _get_padded_prompt_len(x: int) -> int: +def _get_padded_prefill_len(x: int) -> int: # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence # length to be a multiple of 16. We pad the prompt length to the nearest # multiple of 16. This is also good for performance. 
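Note: the indexing scheme introduced in `_prepare_inputs` above (tpu_model_runner.py) can be hard to follow from the diff alone. The standalone NumPy sketch below walks through the same `req_indices` / `positions` / `slot_mapping` derivation on a toy batch; the concrete sizes (3 requests scheduling 2, 5 and 3 tokens, a block_size of 2, 4 blocks per request, and the block-table contents) are made up for illustration only and are not taken from the patch.

import numpy as np

# Toy setup (all sizes made up): 3 requests scheduled with 2, 5 and 3 new
# tokens, no previously computed tokens, block_size = 2, and a block table
# with 4 blocks per request.
num_scheduled = np.array([2, 5, 3], dtype=np.int32)
num_computed = np.array([0, 0, 0], dtype=np.int32)
block_size = 2
max_num_blocks_per_req = 4
arange = np.arange(16, dtype=np.int32)

# Which request each flattened token belongs to:
# [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(arange[:3], num_scheduled)

# Position of each token inside its request:
# [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange_per_req = np.concatenate([arange[:n] for n in num_scheduled])
positions = num_computed[req_indices] + arange_per_req

# Flat lookup into a (num_reqs, max_num_blocks_per_req) block table,
# then convert (block_number, block_offset) into a flat slot index,
# mirroring the slot_mapping computation in _prepare_inputs.
block_table = np.arange(3 * max_num_blocks_per_req,
                        dtype=np.int32).reshape(3, max_num_blocks_per_req)
block_table_indices = (req_indices * max_num_blocks_per_req
                       + positions // block_size)
block_numbers = block_table.flatten()[block_table_indices]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)  # e.g. [ 0  1  8  9 10 11 12 16 17 18] for this toy block table
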
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index f29edd34ede3..d5c71a21b4ba 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -21,7 +21,7 @@ KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner +from vllm.v1.worker.tpu_model_runner import TPUModelRunner logger = init_logger(__name__) @@ -127,7 +127,6 @@ def determine_available_memory(self) -> int: runner_kv_caches, num_tokens=1, seq_len=self.scheduler_config.max_num_batched_tokens, - exec_mode=ExecutionMode.PREFILL, ) # Synchronize before measuring the memory usage. From d0eac0ffb5ec69fa87dba5a2572b8f82d8e610af Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 17 Feb 2025 21:56:29 +0000 Subject: [PATCH 02/19] add more comments --- vllm/v1/worker/tpu_model_runner.py | 8 ++++++-- vllm/v1/worker/tpu_worker.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 9b55ee563ab2..5817939dd870 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -110,8 +110,9 @@ def __init__( # KV caches for forward pass self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] + # xw32: do the swap thing later. Use the synchronous way now as baseline. # Cached torch/numpy tensor - # xw32: what's the numpy array (eg input_ids_np) for? + # The pytorch tensor and numpy array share the same buffer. Sometimes the numpy op is faster. self.input_ids_cpu = torch.empty(self.max_num_tokens, dtype=torch.int32, device="cpu") @@ -141,6 +142,7 @@ def __init__( # Range tensor with values [0 .. self.max_num_tokens - 1]. # Used to initialize positions / context_lens / seq_lens + # TODO(xw32): may need to replace max_num_tokens with max_model_len. self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: @@ -312,7 +314,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. # xw32q: Do we need this? - self.input_batch.block_table.commit(num_reqs) + # TODO(xw32): check if TPU support async copy. Similar to the pined_memory + # self.input_batch.block_table.commit(num_reqs) # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. @@ -643,6 +646,7 @@ def dummy_run( query_start_loc=query_start_loc, num_seqs=num_tokens, # xw32: is it correct? ) + # TODO(xw32): work with Alex to fix the issue later. with set_forward_context(None, self.vllm_config): assert self.model is not None logger.info(f"xw32 TPUModelRunner.dummy_run. 
before calling self.model, {input_ids.shape=}, {position_ids.shape=}") diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index d5c71a21b4ba..8ddb245ba108 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -102,6 +102,7 @@ def init_device(self): self.model_runner = TPUModelRunner(self.vllm_config, self.device) def determine_available_memory(self) -> int: + # TODO(xw32): may need to follow gpu_worker's determine_available_memory kv_caches: Dict[str, torch.Tensor] = {} kv_cache_spec = self.model_runner.get_kv_cache_spec() for layer_name, layer_spec in kv_cache_spec.items(): @@ -123,6 +124,7 @@ def determine_available_memory(self) -> int: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) + # TODO(xw32): change here. self.model_runner.dummy_run( runner_kv_caches, num_tokens=1, From 2830ed41cb56dea29611bd59409e5aadead0f1e8 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Tue, 18 Feb 2025 06:19:43 +0000 Subject: [PATCH 03/19] cleaned up a bit --- vllm/v1/attention/backends/pallas.py | 54 ++---- vllm/v1/core/scheduler.py | 1 + vllm/v1/worker/tpu_model_runner.py | 249 +++++++++++---------------- 3 files changed, 117 insertions(+), 187 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index b417180ede55..e4d44214115a 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -12,6 +12,9 @@ from vllm.attention.backends.utils import CommonAttentionState +NUM_QUERIES_PER_BLOCK = 128 + + class PallasAttentionBackend(AttentionBackend): @staticmethod @@ -53,7 +56,7 @@ def copy_blocks( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], src_to_dists: Tuple[torch.Tensor, torch.Tensor], ) -> None: - assert False, "I assume this PallasAttentionBackend.copy_blocks function should not be used. But I could be wrong." # TODO(xw32): remove + assert False, "I assume this PallasAttentionBackend.copy_blocks function should not be used. But I could be wrong." # TODO(xw32): If it turns out all tests passed, remove this method. src_indices, dst_indices = src_to_dists for k_cache, v_cache in kv_caches: torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) @@ -62,38 +65,8 @@ def copy_blocks( v_cache[:, dst_indices] = v_cache[:, src_indices] -# Why remove the base class AttentionMetadata. Because the base requires -# to set num_prefills, num_prefill_tokens, and num_decode_tokens. @dataclass class PallasMetadata(): - - # old begins - # Currently, input sequences can only contain all prefills - # or all decoding. - # block_tables: Optional[torch.Tensor] = None - # context_lens: Optional[torch.Tensor] = None - # effective_query_lens: Optional[torch.Tensor] = None - - # @property - # def prefill_metadata(self) -> Optional["PallasMetadata"]: - # if self.num_prefills == 0: - # return None - - # assert self.num_decode_tokens == 0 - # return self - - # @property - # def decode_metadata(self) -> Optional["PallasMetadata"]: - # if self.num_decode_tokens == 0: - # return None - - # assert self.num_prefills == 0 - # assert self.num_prefill_tokens == 0 - # assert self.block_tables is not None - # assert self.context_lens is not None - # return self - # old ends - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
# |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -109,6 +82,8 @@ class PallasMetadata(): query_start_loc: torch.Tensor num_seqs: int + total_num_scheduled_tokens: int # TODO(xw32): remove it before merging the PR. + class PallasAttentionBackendImpl(AttentionImpl): @@ -171,12 +146,13 @@ def forward( query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_kv_heads, num_blocks, block_size, head_size] + kv_cache = ([num_kv_heads, num_blocks, block_size, head_size], [num_kv_heads, num_blocks, block_size, head_size]) attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - print(f'xw32 PallasAttentionBackendImpl.forward begins {query.shape=}, {key.shape=}') + # xw32: kv_cache[0].shape=torch.Size([2, 57599, 16, 128]) + print(f'xw32 PallasAttentionBackendImpl.forward begins {query.shape=}, {key.shape=}, {len(kv_cache)=}, {kv_cache[0].shape=}') if attn_metadata is None: if output is None: @@ -193,9 +169,10 @@ def forward( print('xw32 write to kv cache') slot_mapping = attn_metadata.slot_mapping key_cache, value_cache = kv_cache - write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping, attn_metadata.total_num_scheduled_tokens) query = query * self.scale + print(f'xw32 xw32 PallasAttentionBackendImpl.forward: {query.shape=}, {key_cache.shape=}, {value_cache.shape=}, {attn_metadata.context_lens.shape=}, {attn_metadata.block_tables.shape=}, {attn_metadata.query_start_loc.shape=}, {attn_metadata.num_seqs=}', flush=True) output = torch.ops.xla.ragged_paged_attention( query, key_cache, @@ -205,11 +182,12 @@ def forward( attn_metadata.query_start_loc, attn_metadata.num_seqs, num_kv_pages_per_block=16, - num_queries_per_block=128, + num_queries_per_block=NUM_QUERIES_PER_BLOCK, use_kernel=True, ) + print(f'xw32 PallasAttentionBackendImpl.forward finished', flush=True) - return output + return output.reshape(num_tokens, hidden_size) def write_to_kv_cache( @@ -218,6 +196,7 @@ def write_to_kv_cache( key_cache: torch.Tensor, value_cache: torch.Tensor, slot_mapping: torch.Tensor, + total_num_scheduled_tokens: int, ) -> None: """ Write the key and values to the KV cache. 
@@ -228,7 +207,7 @@ def write_to_kv_cache( v_cache = [num_kv_heads, num_blocks, block_size, head_size] """ - print(f'xw32 write_to_kv_cache {key.shape=}, {key_cache.shape=}, {slot_mapping=}', flush=True) + print(f'xw32 write_to_kv_cache {key.shape=}, {key_cache.shape=}, {slot_mapping.shape=}', flush=True) torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) @@ -240,3 +219,4 @@ def write_to_kv_cache( value_cache = value_cache.flatten(0, 2) key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) + print(f'xw32 write_to_kv_cache finished', flush=True) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 82c4b307d48b..f4904fa94e95 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -481,6 +481,7 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] + print(f'xw32 update_from_output {req_index=}, {len(generated_token_ids)=}') if req_id not in scheduler_output.scheduled_spec_decode_tokens: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5817939dd870..13a52161d573 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -23,7 +23,8 @@ from vllm.sampling_params import SamplingType from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasMetadata) + PallasMetadata, + NUM_QUERIES_PER_BLOCK) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -43,6 +44,7 @@ # FIXME(woosuk): A temporary hack to support `n > 1`. # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 +INVALID_TOKEN_ID = -1 class TPUModelRunner(): @@ -77,8 +79,8 @@ def __init__( self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens - self.max_num_reqs = scheduler_config.max_num_seqs + self.max_num_tokens = scheduler_config.max_num_batched_tokens # 8192 + self.max_num_reqs = scheduler_config.max_num_seqs # 16 # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( @@ -92,6 +94,7 @@ def __init__( self.model: Optional[nn.Module] = None # Persistent batch. + # self.max_model_len=512, self.max_num_tokens=8192 self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -113,21 +116,26 @@ def __init__( # xw32: do the swap thing later. Use the synchronous way now as baseline. # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. Sometimes the numpy op is faster. 
- self.input_ids_cpu = torch.empty(self.max_num_tokens, + self.input_ids_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu") self.input_ids_np = self.input_ids_cpu.numpy() - self.positions_cpu = torch.empty(self.max_num_tokens, + self.positions_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu") self.positions_np = self.positions_cpu.numpy() - # xw32: slot_mapping maps a token to its position in the block (=block_numbers * self.block_size+block_offset) - self.slot_mapping_cpu = torch.empty(self.max_num_tokens, + # xw32: slot_mapping maps a token to its position in the kvcache (=block_numbers * self.block_size+block_offset) + self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() + # self.input_batch.block_table has shape of [max_num_reqs, max_num_blocks_per_req]. + # Because I want the block_table.shape[0] to be num_token, so I did this way. + self.block_table_cpu = torch.zeros((self.max_num_tokens, self.input_batch.block_table.get_cpu_tensor().shape[1]), + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + device="cpu") self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, dtype=torch.int32, @@ -142,7 +150,6 @@ def __init__( # Range tensor with values [0 .. self.max_num_tokens - 1]. # Used to initialize positions / context_lens / seq_lens - # TODO(xw32): may need to replace max_num_tokens with max_model_len. self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32) def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: @@ -318,13 +325,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # self.input_batch.block_table.commit(num_reqs) # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. num_scheduled_tokens_per_req = [] max_num_scheduled_tokens_all_reqs = 0 for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # print(f'xw32 TPUModelRunner.prepare_input line148. {req_id=}, {num_tokens=}') # xw32 TPUModelRunner.prepare_input line148. req_id='0', num_tokens=5 num_scheduled_tokens_per_req.append(num_tokens) max_num_scheduled_tokens_all_reqs = max(max_num_scheduled_tokens_all_reqs, @@ -334,38 +339,38 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + # For each scheduled token, what are the corresponding req index. req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req) # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # For each scheduled token, what is its position in the corresponding req. arange = np.concatenate( [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) # Get positions. # TODO(xw32): add an example of the output positions_np. + # self.positions_np.shape=(8192,) self.positions_np=array([281337856, 0, 32768, ..., 0, 281734336, 0], dtype=int32), the value in self.positions_np because it's initialize as torch.empty. positions_np = self.positions_np[:total_num_scheduled_tokens] - # print(f'xw32 TPUModelRunner.prepare_input. {total_num_scheduled_tokens=}, {self.input_batch.num_computed_tokens_cpu=}, {self.input_batch.num_reqs=}, {self.model_config.uses_mrope=}') # xw32 TPUModelRunner.prepare_input. 
total_num_scheduled_tokens=5, self.input_batch.num_computed_tokens_cpu=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), self.input_batch.num_reqs=1, self.model_config.uses_mrope=False np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) + # xw32 TPUModelRunner.prepare_input. line355 self.input_batch.num_computed_tokens_cpu[req_indices]=array([0, 0, 0, 0, 0], dtype=int32), arange=array([0, 1, 2, 3, 4], dtype=int32), positions_np=array([0, 1, 2, 3, 4], dtype=int32) - # xw32: Do we need to check self.model_config.uses_mrope? - # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - # print(f'xw32 TPUModelRunner.prepare_input line148. {positions_np=}, {req_indices=}, {self.input_batch.token_ids_cpu.shape=}, {token_indices}') # xw32 TPUModelRunner.prepare_input line148. positions_np=array([0, 1, 2, 3, 4]), req_indices=array([0, 0, 0, 0, 0], dtype=int32), self.input_batch.token_ids_cpu.shape=(16, 512), [0 1 2 3 4] + # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - - # print(f'xw32 TPUModelRunner.prepare_input line189 . {self.input_batch.token_ids_cpu_tensor=}') # prints a 2d tensor + # xw32 note, self.input_batch.token_ids_cpu_tensor is a 2d tensor torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), @@ -388,9 +393,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) + # xw32 TPUModelRunner.prepare_input line401 . self.slot_mapping_cpu.shape=torch.Size([8192]), self.slot_mapping_cpu=tensor([0, 1, 2, ..., 0, 0, 0]) # Prepare the attention metadata. - print(f'xw32 TPUModelRunner.prepare_input line214. {self.query_start_loc_np.shape=}') self.query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens_per_req, out=self.query_start_loc_np[1:num_reqs + 1]) @@ -400,23 +405,27 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): num_scheduled_tokens_per_req) # Copy the tensors to the TPU. - self.input_ids = self.input_ids_cpu[:total_num_scheduled_tokens].to(self.device) - self.position_ids = self.positions_cpu[:total_num_scheduled_tokens].to(self.device) - query_start_loc = self.query_start_loc_cpu[:total_num_scheduled_tokens+1].to(self.device) - seq_lens = self.seq_lens_cpu[:total_num_scheduled_tokens].to(self.device) - slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(self.device) - print(f'xw32 TPUModelRunner.prepare_input line230 . 
{self.input_batch.block_table.get_device_tensor().shape=}, {total_num_scheduled_tokens=}, {num_reqs=}') - + padded_total_num_scheduled_tokens = _get_padded_number(total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) + self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(self.device) + self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(self.device) + self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID + slot_mapping = self.slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(self.device) + block_table = self.block_table_cpu[:padded_total_num_scheduled_tokens] + block_table[:num_reqs] = self.input_batch.block_table.get_cpu_tensor()[:num_reqs] + block_table = block_table.to(self.device) + query_start_loc = self.query_start_loc_cpu[:padded_total_num_scheduled_tokens+1].to(self.device) + seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(self.device) + + # print(f'xw32 TPUModelRunner.prepare_input line421 . {self.input_batch.block_table.get_device_tensor().shape=}') # self.input_batch.block_table.get_device_tensor().shape=torch.Size([16, 32]=(max_num_reqs, max_num_blocks_per_req) + # block_table.get_device_tensor()=tensor([[0, 0, 0, 0, 0...], [0]]) + # slot_mapping=tensor([ 0, 1, 2, 3, 4, 1000000000, ...]) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, - block_tables=( - self.input_batch.block_table.get_device_tensor()[:total_num_scheduled_tokens]), + block_tables=block_table, context_lens=seq_lens, query_start_loc=query_start_loc, num_seqs=num_reqs, - # num_actual_tokens=total_num_scheduled_tokens, - # max_query_len=max_num_scheduled_tokens, - # max_seq_len=max_seq_len, + total_num_scheduled_tokens=total_num_scheduled_tokens, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -440,37 +449,32 @@ def execute_model( # Prepare inputs attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = num_scheduled_tokens - input_ids = self.input_ids[:num_input_tokens] + input_ids = self.input_ids + print(f'xw32 TPUModelRunner.execute_model line459 {input_ids.shape=}, {num_scheduled_tokens=}') # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): - positions = self.position_ids[:num_input_tokens] selected_token_ids = self.model( - token_ids=input_ids, - position_ids=positions, + token_ids=self.input_ids, + position_ids=self.position_ids, kv_caches=self.kv_caches, - # xw32q: Why in gpu_model_runner.py, attn_metadata is None https://github.com/vllm-project/vllm/blob/46fe9b46d83e733130ce952eb3967a9c96713583/vllm/v1/worker/gpu_model_runner.py#L455? attn_metadata=attn_metadata, + logits_indices=logits_indices, ) + print(f'xw32 TPUModelRunner.execute_model line470 {selected_token_ids.shape=}') # Then, let's update the cache state. num_reqs = self.input_batch.num_reqs - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] + for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) - assert seq_len <= req_state.num_tokens - if seq_len == req_state.num_tokens: - # Append the sampled token to the output token ids. 
- token_id = selected_token_ids[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - self.input_batch.num_tokens[i] += 1 - req_state.output_token_ids.append(token_id) + if seq_len >= req_state.num_tokens: + request_seq_lens.append((i, req_state, seq_len)) else: - # xw32q: what are these from gpu_model_runner.py? I don't understand. # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. generator = self.input_batch.generators.get(i) @@ -484,94 +488,43 @@ def execute_model( self.input_batch.req_ids[:num_reqs]), "req_ids contains None" req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} + for req_id in self.input_batch.req_ids[:num_reqs]: + prompt_logprobs_dict[req_id] = None + + print(f'xw32 TPUModelRunner.execute_model line496 {selected_token_ids.shape=}') + max_gen_len = selected_token_ids.shape[-1] + if max_gen_len == 1: + valid_sampled_token_ids = selected_token_ids.tolist() + for i, req_state, seq_len in request_seq_lens: + token_id = valid_sampled_token_ids[i][0] + self.input_batch.token_ids_cpu[i, seq_len] = token_id + req_state.output_token_ids.append(token_id) + self.input_batch.num_tokens[i] += 1 + else: + valid_mask = selected_token_ids != INVALID_TOKEN_ID + gen_lens = valid_mask.sum(dim=1).tolist() + valid_sampled_token_ids = [ + seq.tolist() + for seq in selected_token_ids[valid_mask].split(gen_lens) + ] + self.input_batch.num_tokens[:num_reqs] += gen_lens + for i, req_state, seq_len in request_seq_lens: + target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) + self.input_batch.token_ids_cpu[ + i, target_slice] = valid_sampled_token_ids[i] + req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + + # print(f'xw32 TPUModelRunner.execute_model line518 execute_model, {len(req_ids)=}, {len(self.input_batch.req_id_to_index)=}, {selected_token_ids.shape=}, {selected_token_ids=} {num_reqs=}') model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=selected_token_ids, - logprob_token_ids_cpu=None, - logprobs_cpu=None, + sampled_token_ids=valid_sampled_token_ids, + logprobs=None, + prompt_logprobs_dict=prompt_logprobs_dict, ) return model_runner_output - # old code begins - # # Init - # num_reqs = self.input_batch.num_reqs - # assert num_reqs > 0 - # sampled_token_ids_list = [0] * num_reqs - - # # Run decodes (a single batch) - # if len(decode_data.req_ids) > 0: - # # Forward - # with set_forward_context(decode_data.attn_metadata, - # self.vllm_config): - # assert self.model is not None - # selected_token_ids = self.model(decode_data.input_tokens, - # decode_data.input_positions, - # decode_data.attn_metadata, - # self.kv_caches) - - # # Transfer sampled tokens from TPU to CPU - # selected_token_ids_list = selected_token_ids.cpu().tolist() - - # # Update cached state - # for i, req_id in enumerate(decode_data.req_ids): - # # xw32: what is the difference between req_index and req_id? 
- # req_index = self.input_batch.req_id_to_index[req_id] - # req_state = self.requests[req_id] - - # seq_len = (req_state.num_computed_tokens + - # scheduler_output.num_scheduled_tokens[req_id]) - - # token_id = selected_token_ids_list[i] - - # self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - # self.input_batch.num_tokens[req_index] += 1 - # req_state.output_token_ids.append(token_id) - - # sampled_token_ids_list[req_index] = token_id - - # # Run each prompt - # for (req_id, prompt_len, input_tokens, input_positions, - # attn_metadata) in prompt_data.zipped(): - # assert req_id is not None - # req_state = self.requests[req_id] - # req_index = self.input_batch.req_id_to_index[req_id] - - # # Forward - # with set_forward_context(attn_metadata, self.vllm_config): - # assert self.model is not None - # selected_token_ids = self.model(input_tokens, input_positions, - # attn_metadata, self.kv_caches) - - # seq_len = (req_state.num_computed_tokens + - # scheduler_output.num_scheduled_tokens[req_id]) - # if seq_len >= len(req_state.prompt_token_ids): - # # Transfer sampled tokens from TPU to CPU - # token_id = selected_token_ids.cpu()[prompt_len - 1].item() - # sampled_token_ids_list[req_index] = token_id - - # # Update cached state - # self.input_batch.token_ids_cpu[req_index, seq_len] = token_id - # self.input_batch.num_tokens[req_index] += 1 - # req_state.output_token_ids.append(token_id) - - # # Get req_ids - # assert all( - # req_id is not None for req_id in - # self.input_batch.req_ids[:num_reqs]), "req_ids contains None" - # req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - - # model_runner_output = ModelRunnerOutput( - # req_ids=req_ids, - # req_id_to_index=self.input_batch.req_id_to_index, - # sampled_token_ids=sampled_token_ids_list, - # logprob_token_ids_cpu=None, - # logprobs_cpu=None, - # ) - - # return model_runner_output - # old code ends - def load_model(self) -> None: logger.info("xw32 TPUModelRunner.load_model begins.") self.device = self.device_config.device @@ -613,9 +566,6 @@ def dummy_run( num_tokens: int, seq_len: Optional[int] = None, ) -> None: - # logger.info(f"xw32 TPUModelRunner.dummy_run. {self.input_ids_cpu.shape=}, {self.positions_cpu.shape=}, {num_tokens=}, {self.input_ids_cpu.device=}") - # xw32 qq: what are input_ids and positions and slot_mapping? What are their shapes? Here is the answer: - # xw32 TPUModelRunner.dummy_run. self.input_ids.shape=torch.Size([8192]), self.positions.shape=torch.Size([8192]), num_tokens=16, 32, ..., self.input_ids.device=device(type='xla', index=0) input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=self.device) @@ -629,15 +579,11 @@ def dummy_run( (num_tokens, self.max_num_blocks_per_req), dtype=torch.int32, device=self.device) + query_start_loc = torch.zeros(num_tokens+1, dtype=torch.int32, device=self.device) context_lens = torch.ones((num_tokens, ), dtype=torch.int32, device=self.device) - block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), - dtype=torch.int32, - device=self.device) - query_start_loc = torch.zeros(num_tokens+1, dtype=torch.int32, device=self.device) - # how do I set torch._dynamo.mark_dynamic? + # TODO(xw32): how do I set torch._dynamo.mark_dynamic? # The attn_metadata is used in torch._dynamo.mark_dynamic. attn_metadata = PallasMetadata( slot_mapping=slot_mapping, @@ -645,12 +591,13 @@ def dummy_run( context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_tokens, # xw32: is it correct? 
+ total_num_scheduled_tokens=num_tokens, ) # TODO(xw32): work with Alex to fix the issue later. with set_forward_context(None, self.vllm_config): assert self.model is not None logger.info(f"xw32 TPUModelRunner.dummy_run. before calling self.model, {input_ids.shape=}, {position_ids.shape=}") - self.model(input_ids, position_ids, None, kv_caches) + self.model(input_ids, position_ids, None, kv_caches, None) logger.info(f"xw32 TPUModelRunner.dummy_run. after calling self.model") def capture_model(self) -> None: @@ -659,7 +606,6 @@ def capture_model(self) -> None: logger.info("xw32 TPUModelRunner.capture_model.") logger.info("Compiling the model with different input shapes.") - # xw32 qq: is the compilation here for both torch.compile and the XLA compile? # xw32: may need to compile for num_seqs. start = time.perf_counter() num_tokens = 16 @@ -681,7 +627,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config: Configuration for the KV cache, including the KV cache size of each layer """ - logger.info(f"xw32 TPUModelRunner.initialize_kv_cache. {kv_cache_config=}") if len(kv_cache_config.groups) > 1: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " @@ -726,6 +671,7 @@ def forward( position_ids: torch.Tensor, attn_metadata: AttentionMetadata, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + logits_indices: torch.Tensor, ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -742,11 +688,8 @@ def forward( """ # Skip this in memory profiling at initialization. logger.info("xw32 ModelWrapperV1.forward.") - print(f'xw32 ModelWrapperV1.forward', flush=True) - print(f'xw32 ModelWrapperV1.forward {token_ids=}') - print(f'xw32 ModelWrapperV1.forward {position_ids=}') - print(f'xw32 ModelWrapperV1.forward {attn_metadata=}') - print(f'xw32 ModelWrapperV1.forward {len(kv_caches)=}, {kv_caches[0][0].shape=}') + # token_ids=tensor([9707, 11, 847, 829, 374, 0...0] + # position_ids=tensor([0, 1, 2, 3, 4, 0, ..., 0] if attn_metadata is not None: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape @@ -769,26 +712,29 @@ def forward( attn_metadata.slot_mapping = slot_mapping assert self.model is not None - print(f'xw32 ModelWrapperV1.forward, right before calling self.model, {token_ids=}', flush=True) hidden_states = self.model( token_ids, position_ids, kv_caches, attn_metadata, ) - print(f'xw32 ModelWrapperV1.forward, right after calling self.model, {hidden_states.shape=}', flush=True) + # TODO(xw32): should unconditionally run hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens]. Same for logits_indices + if attn_metadata is not None: + print(f'xw32 ModelWrapperV1.forward line724 {attn_metadata.total_num_scheduled_tokens=}, {hidden_states.shape=}') + hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens] + if logits_indices is not None: + logits_indices = logits_indices[:attn_metadata.num_seqs] + hidden_states = hidden_states[logits_indices] + print(f'xw32 ModelWrapperV1.forward line728 {logits_indices=}, {hidden_states.shape=}') # hidden_states = hidden_states.flatten(0, 1) is not needed because previously hidden_states has shape [bs, T, C] and we need to combine the first 2 dimensions. 
# hidden_states = hidden_states.flatten(0, 1) - print(f'xw32 ModelWrapperV1.forward, right after calling hidden_states.flatten, {hidden_states.shape=}', flush=True) logits = self.model.compute_logits(hidden_states, None) - print(f'xw32 ModelWrapperV1.forward, right after calling self.model.compute_logits', flush=True) # Greedy sampling. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - print(f'xw32 ModelWrapperV1.forward, right after calling torch.argmax', flush=True) - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - print(f'xw32 ModelWrapperV1.forward, right after calling argmax_token_ids.squeeze', flush=True) + print(f'xw32 line728 {argmax_token_ids.shape=}') + # argmax_token_ids = argmax_token_ids.squeeze(dim=-1) return argmax_token_ids @@ -809,3 +755,6 @@ def _get_padded_batch_size(batch_size: int) -> int: return 8 else: return ((batch_size + 15) // 16) * 16 + +def _get_padded_number(n: int, multiple: int) -> int: + return ((n + multiple - 1) // multiple) * multiple From 2316f143b05b7959a6f001952606c42170bcafb2 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 19 Feb 2025 17:50:10 +0000 Subject: [PATCH 04/19] disable print, enable torch.compile --- examples/offline_inference/basic.py | 6 ++--- vllm/model_executor/models/qwen2.py | 18 ++++++------- vllm/v1/attention/backends/pallas.py | 12 ++++----- vllm/v1/core/scheduler.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 39 ++++++++++++++-------------- 5 files changed, 39 insertions(+), 38 deletions(-) diff --git a/examples/offline_inference/basic.py b/examples/offline_inference/basic.py index c110349a0eb9..22fe702de9c0 100644 --- a/examples/offline_inference/basic.py +++ b/examples/offline_inference/basic.py @@ -6,9 +6,9 @@ # Sample prompts. prompts = [ "Hello, my name is", - # "The president of the United States is", - # "The capital of France is", - # "The future of AI is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", ] # Create a sampling params object. sampling_params = SamplingParams() #temperature=0.8, top_p=0.95) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index a7da23987be0..c171b607eaf2 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -260,15 +260,15 @@ def forward( # TODO(xw32): revert the change before merging the code. # xw32 turns off dynamo -# @support_torch_compile( -# dynamic_arg_dims={ -# "input_ids": 0, -# # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, -# # otherwise (seq_len, ). -# "positions": -1, -# "intermediate_tensors": 0, -# "inputs_embeds": 0, -# }) +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) class Qwen2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index e4d44214115a..79e1632efbae 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -152,7 +152,7 @@ def forward( shape = [num_tokens, num_heads * head_size] """ # xw32: kv_cache[0].shape=torch.Size([2, 57599, 16, 128]) - print(f'xw32 PallasAttentionBackendImpl.forward begins {query.shape=}, {key.shape=}, {len(kv_cache)=}, {kv_cache[0].shape=}') + # print(f'xw32 PallasAttentionBackendImpl.forward begins {query.shape=}, {key.shape=}, {len(kv_cache)=}, {kv_cache[0].shape=}') if attn_metadata is None: if output is None: @@ -166,13 +166,13 @@ def forward( value = value.view(num_tokens, self.num_kv_heads, self.head_size) if kv_cache[0].numel() > 0: - print('xw32 write to kv cache') + # print('xw32 write to kv cache') slot_mapping = attn_metadata.slot_mapping key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping, attn_metadata.total_num_scheduled_tokens) query = query * self.scale - print(f'xw32 xw32 PallasAttentionBackendImpl.forward: {query.shape=}, {key_cache.shape=}, {value_cache.shape=}, {attn_metadata.context_lens.shape=}, {attn_metadata.block_tables.shape=}, {attn_metadata.query_start_loc.shape=}, {attn_metadata.num_seqs=}', flush=True) + # print(f'xw32 xw32 PallasAttentionBackendImpl.forward: {query.shape=}, {key_cache.shape=}, {value_cache.shape=}, {attn_metadata.context_lens.shape=}, {attn_metadata.block_tables.shape=}, {attn_metadata.query_start_loc.shape=}, {attn_metadata.num_seqs=}', flush=True) output = torch.ops.xla.ragged_paged_attention( query, key_cache, @@ -185,7 +185,7 @@ def forward( num_queries_per_block=NUM_QUERIES_PER_BLOCK, use_kernel=True, ) - print(f'xw32 PallasAttentionBackendImpl.forward finished', flush=True) + # print(f'xw32 PallasAttentionBackendImpl.forward finished', flush=True) return output.reshape(num_tokens, hidden_size) @@ -207,7 +207,7 @@ def write_to_kv_cache( v_cache = [num_kv_heads, num_blocks, block_size, head_size] """ - print(f'xw32 write_to_kv_cache {key.shape=}, {key_cache.shape=}, {slot_mapping.shape=}', flush=True) + # print(f'xw32 write_to_kv_cache {key.shape=}, {key_cache.shape=}, {slot_mapping.shape=}', flush=True) torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) @@ -219,4 +219,4 @@ def write_to_kv_cache( value_cache = value_cache.flatten(0, 2) key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) - print(f'xw32 write_to_kv_cache finished', flush=True) + # print(f'xw32 write_to_kv_cache finished', flush=True) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f4904fa94e95..a7b6be909e6d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -481,7 +481,7 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] - print(f'xw32 update_from_output {req_index=}, {len(generated_token_ids)=}') + # print(f'xw32 update_from_output {req_index=}, {len(generated_token_ids)=}') if req_id not in scheduler_output.scheduled_spec_decode_tokens: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. 
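The slot_mapping above is the contract between _prepare_inputs, which computes block_number * block_size + block_offset for every scheduled token, and write_to_kv_cache, which scatters keys and values into the flattened cache with index_copy_. A toy, self-contained illustration with made-up sizes (single KV head, block_size=4; these are not the shapes used by the patch):

    import torch

    block_size = 4
    # block_table[i] lists the KV-cache blocks assigned to request i.
    block_table = torch.tensor([[3, 7], [1, 0]])
    # Two requests with 5 and 2 scheduled tokens respectively.
    req_indices = torch.tensor([0, 0, 0, 0, 0, 1, 1])
    positions = torch.tensor([0, 1, 2, 3, 4, 0, 1])

    block_numbers = block_table[req_indices, positions // block_size]
    block_offsets = positions % block_size
    slot_mapping = block_numbers * block_size + block_offsets
    # slot_mapping == tensor([12, 13, 14, 15, 28, 4, 5])

    # write_to_kv_cache flattens [num_kv_heads, num_blocks, block_size, head_size]
    # over the first three dims and writes one row per token at its slot.
    num_kv_heads, num_blocks, head_size = 1, 8, 2
    key = torch.randn(len(slot_mapping), head_size)
    key_cache = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)
    flat_key_cache = key_cache.flatten(0, 2)  # view over the same storage
    flat_key_cache.index_copy_(0, slot_mapping, key)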
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 13a52161d573..3b61dc6e04b9 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -312,7 +312,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: return kv_cache_spec def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - print(f'xw32 _prepare_inputs begins. {scheduler_output=}') + # print(f'xw32 _prepare_inputs begins. {scheduler_output=}') total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -441,7 +441,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: - logger.info(f"xw32 TPUModelRunner.execute_model. {scheduler_output=}") + # logger.info(f"xw32 TPUModelRunner.execute_model. {scheduler_output=}") # Update cached state self._update_states(scheduler_output) @@ -451,7 +451,7 @@ def execute_model( num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens input_ids = self.input_ids - print(f'xw32 TPUModelRunner.execute_model line459 {input_ids.shape=}, {num_scheduled_tokens=}') + # print(f'xw32 TPUModelRunner.execute_model line459 {input_ids.shape=}, {num_scheduled_tokens=}') # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): @@ -462,7 +462,7 @@ def execute_model( attn_metadata=attn_metadata, logits_indices=logits_indices, ) - print(f'xw32 TPUModelRunner.execute_model line470 {selected_token_ids.shape=}') + # print(f'xw32 TPUModelRunner.execute_model line470 {selected_token_ids.shape=}') # Then, let's update the cache state. num_reqs = self.input_batch.num_reqs @@ -492,7 +492,7 @@ def execute_model( for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None - print(f'xw32 TPUModelRunner.execute_model line496 {selected_token_ids.shape=}') + # print(f'xw32 TPUModelRunner.execute_model line496 {selected_token_ids.shape=}') max_gen_len = selected_token_ids.shape[-1] if max_gen_len == 1: valid_sampled_token_ids = selected_token_ids.tolist() @@ -502,6 +502,7 @@ def execute_model( req_state.output_token_ids.append(token_id) self.input_batch.num_tokens[i] += 1 else: + # print('xw32 TPUModelRunner.execute_model line505 max_gen_len>1 is triggered') valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ @@ -526,7 +527,7 @@ def execute_model( return model_runner_output def load_model(self) -> None: - logger.info("xw32 TPUModelRunner.load_model begins.") + #logger.info("xw32 TPUModelRunner.load_model begins.") self.device = self.device_config.device # NOTE(woosuk): While the executor assigns the TP ranks to the worker @@ -550,12 +551,12 @@ def load_model(self) -> None: model = ModelWrapperV1(model) # TODO(xw32): turn on dynamo. # xw32 turns off dynamo - self.model = model - # self.model = torch.compile(model, - # backend="openxla", - # fullgraph=True, - # dynamic=False) - logger.info("xw32 TPUModelRunner.load_model ends.") + # self.model = model + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) + # logger.info("xw32 TPUModelRunner.load_model ends.") # @torch.inference_mode() fails so I disabled it. # It's also not in the original v1 tpu_model_runner.py @@ -596,14 +597,14 @@ def dummy_run( # TODO(xw32): work with Alex to fix the issue later. with set_forward_context(None, self.vllm_config): assert self.model is not None - logger.info(f"xw32 TPUModelRunner.dummy_run. 
before calling self.model, {input_ids.shape=}, {position_ids.shape=}") + #logger.info(f"xw32 TPUModelRunner.dummy_run. before calling self.model, {input_ids.shape=}, {position_ids.shape=}") self.model(input_ids, position_ids, None, kv_caches, None) - logger.info(f"xw32 TPUModelRunner.dummy_run. after calling self.model") + #logger.info(f"xw32 TPUModelRunner.dummy_run. after calling self.model") def capture_model(self) -> None: """Compile the model.""" - logger.info("xw32 TPUModelRunner.capture_model.") + #logger.info("xw32 TPUModelRunner.capture_model.") logger.info("Compiling the model with different input shapes.") # xw32: may need to compile for num_seqs. @@ -687,7 +688,7 @@ def forward( memory profiling at initialization. """ # Skip this in memory profiling at initialization. - logger.info("xw32 ModelWrapperV1.forward.") + #logger.info("xw32 ModelWrapperV1.forward.") # token_ids=tensor([9707, 11, 847, 829, 374, 0...0] # position_ids=tensor([0, 1, 2, 3, 4, 0, ..., 0] if attn_metadata is not None: @@ -720,12 +721,12 @@ def forward( ) # TODO(xw32): should unconditionally run hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens]. Same for logits_indices if attn_metadata is not None: - print(f'xw32 ModelWrapperV1.forward line724 {attn_metadata.total_num_scheduled_tokens=}, {hidden_states.shape=}') + #print(f'xw32 ModelWrapperV1.forward line724 {attn_metadata.total_num_scheduled_tokens=}, {hidden_states.shape=}') hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens] if logits_indices is not None: logits_indices = logits_indices[:attn_metadata.num_seqs] hidden_states = hidden_states[logits_indices] - print(f'xw32 ModelWrapperV1.forward line728 {logits_indices=}, {hidden_states.shape=}') + #print(f'xw32 ModelWrapperV1.forward line728 {logits_indices=}, {hidden_states.shape=}') # hidden_states = hidden_states.flatten(0, 1) is not needed because previously hidden_states has shape [bs, T, C] and we need to combine the first 2 dimensions. # hidden_states = hidden_states.flatten(0, 1) @@ -733,7 +734,7 @@ def forward( # Greedy sampling. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - print(f'xw32 line728 {argmax_token_ids.shape=}') + #print(f'xw32 line728 {argmax_token_ids.shape=}') # argmax_token_ids = argmax_token_ids.squeeze(dim=-1) return argmax_token_ids From f5d5429c9e309530c0fd44060f2371e2043bc57c Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Wed, 19 Feb 2025 20:54:25 +0000 Subject: [PATCH 05/19] pad block_table 2nd dim to a multiple of 128 to accomodate the kernel. 
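This patch pads the second dimension of the CPU block table so the ragged paged attention kernel always receives a page count that is a multiple of NUM_KV_PAGES_PER_BLOCK (128). A small sketch of the padding, reusing the _get_padded_number helper added earlier in the series; the example sizes assume the max_model_len=512, block_size=16, max_num_seqs=16 setup seen elsewhere in the patch:

    import torch

    NUM_KV_PAGES_PER_BLOCK = 128

    def _get_padded_number(n: int, multiple: int) -> int:
        return ((n + multiple - 1) // multiple) * multiple

    max_num_blocks_per_req = 512 // 16  # 32 pages per request
    padded_cols = _get_padded_number(max_num_blocks_per_req,
                                     NUM_KV_PAGES_PER_BLOCK)  # 128

    # Real block ids occupy the first 32 columns; the padded tail stays zero.
    block_table_cpu = torch.zeros(16, padded_cols, dtype=torch.int32)
    block_table_cpu[:, :max_num_blocks_per_req] = 1  # placeholder block ids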
--- vllm/v1/attention/backends/pallas.py | 3 ++- vllm/v1/worker/tpu_model_runner.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 79e1632efbae..0712e72131b0 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -13,6 +13,7 @@ NUM_QUERIES_PER_BLOCK = 128 +NUM_KV_PAGES_PER_BLOCK = 128 class PallasAttentionBackend(AttentionBackend): @@ -181,7 +182,7 @@ def forward( attn_metadata.block_tables, attn_metadata.query_start_loc, attn_metadata.num_seqs, - num_kv_pages_per_block=16, + num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK, num_queries_per_block=NUM_QUERIES_PER_BLOCK, use_kernel=True, ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 3b61dc6e04b9..46414e0b8773 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -24,7 +24,8 @@ from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, PallasMetadata, - NUM_QUERIES_PER_BLOCK) + NUM_QUERIES_PER_BLOCK, + NUM_KV_PAGES_PER_BLOCK) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput @@ -133,7 +134,8 @@ def __init__( self.slot_mapping_np = self.slot_mapping_cpu.numpy() # self.input_batch.block_table has shape of [max_num_reqs, max_num_blocks_per_req]. # Because I want the block_table.shape[0] to be num_token, so I did this way. - self.block_table_cpu = torch.zeros((self.max_num_tokens, self.input_batch.block_table.get_cpu_tensor().shape[1]), + padded_max_num_blocks_per_req = _get_padded_number(self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) + self.block_table_cpu = torch.zeros((self.max_num_tokens, padded_max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, device="cpu") @@ -410,9 +412,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(self.device) self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID slot_mapping = self.slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(self.device) - block_table = self.block_table_cpu[:padded_total_num_scheduled_tokens] - block_table[:num_reqs] = self.input_batch.block_table.get_cpu_tensor()[:num_reqs] - block_table = block_table.to(self.device) + padded_block_table = self.block_table_cpu[:padded_total_num_scheduled_tokens] + padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor()[:num_reqs] + padded_block_table = padded_block_table.to(self.device) query_start_loc = self.query_start_loc_cpu[:padded_total_num_scheduled_tokens+1].to(self.device) seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(self.device) @@ -421,7 +423,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # slot_mapping=tensor([ 0, 1, 2, 3, 4, 1000000000, ...]) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, - block_tables=block_table, + block_tables=padded_block_table, context_lens=seq_lens, query_start_loc=query_start_loc, num_seqs=num_reqs, From 08cda8f0d9b2802c61a4818dc7c343bd59c28bda Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 20 Feb 2025 19:05:38 +0000 Subject: [PATCH 06/19] Updated the torch_xla pin again: the smem oom is gone. 
Also use the real attn_metadata in dummy_run and basic.py is still working fine. --- requirements-tpu.txt | 8 ++++---- vllm/v1/attention/backends/pallas.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 21 +++++++++++++-------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 400e20923d57..0ebbc9a02c61 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -17,7 +17,7 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.7.0.dev20250212+cpu -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250212+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.7.0.dev20250220+cpu +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250220+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250220+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250220+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 0712e72131b0..fdb85c707632 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -12,7 +12,7 @@ from vllm.attention.backends.utils import CommonAttentionState -NUM_QUERIES_PER_BLOCK = 128 +NUM_QUERIES_PER_BLOCK = 16 NUM_KV_PAGES_PER_BLOCK = 128 diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 46414e0b8773..f9fd088bb0af 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -570,16 +570,16 @@ def dummy_run( seq_len: Optional[int] = None, ) -> None: input_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device=self.device) + dtype=torch.int32, + device=self.device) position_ids = torch.zeros(num_tokens, - dtype=torch.int64, - device=self.device) + dtype=torch.int64, + device=self.device) slot_mapping = torch.zeros(num_tokens, dtype=torch.int64, device=self.device) block_tables = torch.zeros( - (num_tokens, self.max_num_blocks_per_req), + (num_tokens, self.block_table_cpu.shape[1]), dtype=torch.int32, device=self.device) query_start_loc = torch.zeros(num_tokens+1, dtype=torch.int32, device=self.device) @@ -600,7 +600,7 @@ def dummy_run( with set_forward_context(None, self.vllm_config): assert self.model is not None #logger.info(f"xw32 TPUModelRunner.dummy_run. before calling self.model, {input_ids.shape=}, {position_ids.shape=}") - self.model(input_ids, position_ids, None, kv_caches, None) + self.model(input_ids, position_ids, attn_metadata, kv_caches, None) #logger.info(f"xw32 TPUModelRunner.dummy_run. 
after calling self.model") def capture_model(self) -> None: @@ -612,10 +612,14 @@ def capture_model(self) -> None: # xw32: may need to compile for num_seqs. start = time.perf_counter() num_tokens = 16 + # The num_tokens_list below is how GPU precompiles. + # num_tokens_list = [512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] while True: self.dummy_run(self.kv_caches, num_tokens) - xm.wait_device_ops() logger.info(" -- num_tokens: %d", num_tokens) + xm.mark_step() + xm.wait_device_ops() + # print(f'xw32 capture_model line 620 {self.scheduler_config.max_num_batched_tokens=}') if num_tokens >= self.scheduler_config.max_num_batched_tokens: break num_tokens *= 2 @@ -693,13 +697,14 @@ def forward( #logger.info("xw32 ModelWrapperV1.forward.") # token_ids=tensor([9707, 11, 847, 829, 374, 0...0] # position_ids=tensor([0, 1, 2, 3, 4, 0, ..., 0] - if attn_metadata is not None: + if attn_metadata is not None and kv_caches[0][0].numel() > 0: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape # [num_kv_heads, num_blocks, block_size, head_size]. To make it # work, we need to flatten the first three dimensions and modify # the slot_mapping accordingly. # kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] + # print(f'xw32 line705 {kv_caches[0][0].shape=}') num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape slot_mapping = attn_metadata.slot_mapping slot_mapping = slot_mapping.flatten() From 89ea8f1af7a943c8dc4697d71e626798dc46e481 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 21 Feb 2025 00:11:26 +0000 Subject: [PATCH 07/19] remove total_num_scheduled_tokens from attn_metadata. But it didn't help much about the dynamo compilation --- vllm/v1/attention/backends/pallas.py | 4 +--- vllm/v1/worker/tpu_model_runner.py | 33 +++++++++++++++++++++++----- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index fdb85c707632..96f01bc68381 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -83,7 +83,6 @@ class PallasMetadata(): query_start_loc: torch.Tensor num_seqs: int - total_num_scheduled_tokens: int # TODO(xw32): remove it before merging the PR. class PallasAttentionBackendImpl(AttentionImpl): @@ -170,7 +169,7 @@ def forward( # print('xw32 write to kv cache') slot_mapping = attn_metadata.slot_mapping key_cache, value_cache = kv_cache - write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping, attn_metadata.total_num_scheduled_tokens) + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale # print(f'xw32 xw32 PallasAttentionBackendImpl.forward: {query.shape=}, {key_cache.shape=}, {value_cache.shape=}, {attn_metadata.context_lens.shape=}, {attn_metadata.block_tables.shape=}, {attn_metadata.query_start_loc.shape=}, {attn_metadata.num_seqs=}', flush=True) @@ -197,7 +196,6 @@ def write_to_kv_cache( key_cache: torch.Tensor, value_cache: torch.Tensor, slot_mapping: torch.Tensor, - total_num_scheduled_tokens: int, ) -> None: """ Write the key and values to the KV cache. 
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f9fd088bb0af..f88c18f73f7d 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -427,7 +427,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): context_lens=seq_lens, query_start_loc=query_start_loc, num_seqs=num_reqs, - total_num_scheduled_tokens=total_num_scheduled_tokens, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -450,7 +449,7 @@ def execute_model( # Prepare inputs attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens input_ids = self.input_ids # print(f'xw32 TPUModelRunner.execute_model line459 {input_ids.shape=}, {num_scheduled_tokens=}') @@ -463,6 +462,7 @@ def execute_model( kv_caches=self.kv_caches, attn_metadata=attn_metadata, logits_indices=logits_indices, + total_num_scheduled_tokens=total_num_scheduled_tokens, ) # print(f'xw32 TPUModelRunner.execute_model line470 {selected_token_ids.shape=}') @@ -594,13 +594,20 @@ def dummy_run( context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_tokens, # xw32: is it correct? - total_num_scheduled_tokens=num_tokens, ) + + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(slot_mapping, 0) + torch._dynamo.mark_dynamic(block_tables, 0) + torch._dynamo.mark_dynamic(query_start_loc, 0) + torch._dynamo.mark_dynamic(context_lens, 0) + # TODO(xw32): work with Alex to fix the issue later. with set_forward_context(None, self.vllm_config): assert self.model is not None #logger.info(f"xw32 TPUModelRunner.dummy_run. before calling self.model, {input_ids.shape=}, {position_ids.shape=}") - self.model(input_ids, position_ids, attn_metadata, kv_caches, None) + self.model(input_ids, position_ids, attn_metadata, kv_caches, None, total_num_scheduled_tokens=num_tokens) #logger.info(f"xw32 TPUModelRunner.dummy_run. after calling self.model") def capture_model(self) -> None: @@ -613,7 +620,6 @@ def capture_model(self) -> None: start = time.perf_counter() num_tokens = 16 # The num_tokens_list below is how GPU precompiles. - # num_tokens_list = [512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] while True: self.dummy_run(self.kv_caches, num_tokens) logger.info(" -- num_tokens: %d", num_tokens) @@ -627,6 +633,20 @@ def capture_model(self) -> None: logger.info("Compilation finished in in %.2f [secs].", end - start) + # GPU way of warming up. + # start = time.perf_counter() + # num_tokens_list = [512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] + # # The num_tokens_list below is how GPU precompiles. 
+ # for num_tokens in num_tokens_list: + # self.dummy_run(self.kv_caches, num_tokens) + # logger.info(" -- num_tokens: %d", num_tokens) + # xm.mark_step() + # xm.wait_device_ops() + # if num_tokens >= self.scheduler_config.max_num_batched_tokens: + # break + # end = time.perf_counter() + # logger.info("Compilation finished in in %.2f [secs].", end - start) + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -679,6 +699,7 @@ def forward( attn_metadata: AttentionMetadata, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], logits_indices: torch.Tensor, + total_num_scheduled_tokens: int, ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -729,7 +750,7 @@ def forward( # TODO(xw32): should unconditionally run hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens]. Same for logits_indices if attn_metadata is not None: #print(f'xw32 ModelWrapperV1.forward line724 {attn_metadata.total_num_scheduled_tokens=}, {hidden_states.shape=}') - hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens] + hidden_states = hidden_states[:total_num_scheduled_tokens] if logits_indices is not None: logits_indices = logits_indices[:attn_metadata.num_seqs] hidden_states = hidden_states[logits_indices] From 6520319ce696184b32bd02da33f419b6888f55f7 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 21 Feb 2025 00:48:00 +0000 Subject: [PATCH 08/19] pull total_num_scheduled_tokens and logits_indices out of ModelWrapperV1.forward --- vllm/v1/worker/tpu_model_runner.py | 41 +++++++++++++----------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f88c18f73f7d..3d41b2b11e25 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -456,18 +456,21 @@ def execute_model( # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): - selected_token_ids = self.model( + hidden_states = self.model( token_ids=self.input_ids, position_ids=self.position_ids, kv_caches=self.kv_caches, attn_metadata=attn_metadata, - logits_indices=logits_indices, - total_num_scheduled_tokens=total_num_scheduled_tokens, ) # print(f'xw32 TPUModelRunner.execute_model line470 {selected_token_ids.shape=}') + hidden_states = hidden_states[:total_num_scheduled_tokens] + num_reqs = self.input_batch.num_reqs + logits_indices = logits_indices[:num_reqs] + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, None) + selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) # Then, let's update the cache state. - num_reqs = self.input_batch.num_reqs request_seq_lens: List[Tuple[int, CachedRequestState, int]] = [] for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None @@ -607,7 +610,7 @@ def dummy_run( with set_forward_context(None, self.vllm_config): assert self.model is not None #logger.info(f"xw32 TPUModelRunner.dummy_run. before calling self.model, {input_ids.shape=}, {position_ids.shape=}") - self.model(input_ids, position_ids, attn_metadata, kv_caches, None, total_num_scheduled_tokens=num_tokens) + self.model(input_ids, position_ids, attn_metadata, kv_caches) #logger.info(f"xw32 TPUModelRunner.dummy_run. 
after calling self.model") def capture_model(self) -> None: @@ -698,8 +701,6 @@ def forward( position_ids: torch.Tensor, attn_metadata: AttentionMetadata, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - logits_indices: torch.Tensor, - total_num_scheduled_tokens: int, ) -> torch.Tensor: """Executes the forward pass of the model and samples the next token. @@ -747,24 +748,16 @@ def forward( kv_caches, attn_metadata, ) - # TODO(xw32): should unconditionally run hidden_states = hidden_states[:attn_metadata.total_num_scheduled_tokens]. Same for logits_indices - if attn_metadata is not None: - #print(f'xw32 ModelWrapperV1.forward line724 {attn_metadata.total_num_scheduled_tokens=}, {hidden_states.shape=}') - hidden_states = hidden_states[:total_num_scheduled_tokens] - if logits_indices is not None: - logits_indices = logits_indices[:attn_metadata.num_seqs] - hidden_states = hidden_states[logits_indices] - #print(f'xw32 ModelWrapperV1.forward line728 {logits_indices=}, {hidden_states.shape=}') - - # hidden_states = hidden_states.flatten(0, 1) is not needed because previously hidden_states has shape [bs, T, C] and we need to combine the first 2 dimensions. - # hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, None) + + return hidden_states - # Greedy sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - #print(f'xw32 line728 {argmax_token_ids.shape=}') - # argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - return argmax_token_ids + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata, + ) -> Optional[torch.Tensor]: + logits = self.model.compute_logits(hidden_states, sampling_metadata) + return logits def _get_padded_prefill_len(x: int) -> int: From e272741eba287645bd0fc412a575e3245adf27a8 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 21 Feb 2025 03:59:00 +0000 Subject: [PATCH 09/19] change position_ids to use int32 instead of int64 --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 3d41b2b11e25..687fc8952e14 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -576,7 +576,7 @@ def dummy_run( dtype=torch.int32, device=self.device) position_ids = torch.zeros(num_tokens, - dtype=torch.int64, + dtype=torch.int32, device=self.device) slot_mapping = torch.zeros(num_tokens, dtype=torch.int64, From ea03585d3dbf81601cd63ebf7759883f980a45e1 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Fri, 21 Feb 2025 06:32:36 +0000 Subject: [PATCH 10/19] removed my comments --- examples/offline_inference/basic/basic.py | 1 - vllm/model_executor/models/qwen2.py | 2 - vllm/v1/attention/backends/pallas.py | 17 ------- vllm/v1/worker/tpu_model_runner.py | 62 ++++------------------- vllm/v1/worker/tpu_worker.py | 2 - 5 files changed, 10 insertions(+), 74 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 22fe702de9c0..c6ef22c0dc1f 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -2,7 +2,6 @@ from vllm import LLM, SamplingParams -# TODO(xw32): should remove the change in this file before merging the PR. # Sample prompts. 
prompts = [ "Hello, my name is", diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index c171b607eaf2..e3de6b64fbb3 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -258,8 +258,6 @@ def forward( return hidden_states, residual -# TODO(xw32): revert the change before merging the code. -# xw32 turns off dynamo @support_torch_compile( dynamic_arg_dims={ "input_ids": 0, diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 96f01bc68381..d9e21c0e394c 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -51,20 +51,6 @@ def swap_blocks( ) -> None: raise RuntimeError("swap_blocks is not used for the TPU backend.") - @torch.compile(backend="openxla") - @staticmethod - def copy_blocks( - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - src_to_dists: Tuple[torch.Tensor, torch.Tensor], - ) -> None: - assert False, "I assume this PallasAttentionBackend.copy_blocks function should not be used. But I could be wrong." # TODO(xw32): If it turns out all tests passed, remove this method. - src_indices, dst_indices = src_to_dists - for k_cache, v_cache in kv_caches: - torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) - k_cache[:, dst_indices] = k_cache[:, src_indices] - torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) - v_cache[:, dst_indices] = v_cache[:, src_indices] - @dataclass class PallasMetadata(): @@ -151,9 +137,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - # xw32: kv_cache[0].shape=torch.Size([2, 57599, 16, 128]) - # print(f'xw32 PallasAttentionBackendImpl.forward begins {query.shape=}, {key.shape=}, {len(kv_cache)=}, {kv_cache[0].shape=}') - if attn_metadata is None: if output is None: output = torch.ones_like(query) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 1232b9a2b7c4..8ff2ad4e4097 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -95,7 +95,6 @@ def __init__( self.model: Optional[nn.Module] = None # Persistent batch. - # self.max_model_len=512, self.max_num_tokens=8192 self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -114,9 +113,9 @@ def __init__( # KV caches for forward pass self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = [] - # xw32: do the swap thing later. Use the synchronous way now as baseline. # Cached torch/numpy tensor - # The pytorch tensor and numpy array share the same buffer. Sometimes the numpy op is faster. + # The pytorch tensor and numpy array share the same buffer. + # Sometimes the numpy op is faster so we create both. self.input_ids_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu") @@ -127,13 +126,13 @@ def __init__( device="cpu") self.positions_np = self.positions_cpu.numpy() - # xw32: slot_mapping maps a token to its position in the kvcache (=block_numbers * self.block_size+block_offset) self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() - # self.input_batch.block_table has shape of [max_num_reqs, max_num_blocks_per_req]. - # Because I want the block_table.shape[0] to be num_token, so I did this way. + + # self.input_batch.block_table has a shape of [max_num_reqs, max_num_blocks_per_req]. 
+ # Because we want the block_table.shape[0] to be num_tokens nad block_table[1] to be multiple of NUM_KV_PAGES_PER_BLOCK, so we create a separate one. padded_max_num_blocks_per_req = _get_padded_number(self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) self.block_table_cpu = torch.zeros((self.max_num_tokens, padded_max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, @@ -144,6 +143,7 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.query_start_loc_np = self.query_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_tokens, dtype=torch.int32, device="cpu", @@ -314,25 +314,17 @@ def get_kv_cache_spec(self) -> KVCacheSpec: return kv_cache_spec def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - # print(f'xw32 _prepare_inputs begins. {scheduler_output=}') total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - # xw32q: Do we need this? - # TODO(xw32): check if TPU support async copy. Similar to the pined_memory - # self.input_batch.block_table.commit(num_reqs) - # Get the number of scheduled tokens for each request. num_scheduled_tokens_per_req = [] max_num_scheduled_tokens_all_reqs = 0 for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # xw32 TPUModelRunner.prepare_input line148. req_id='0', num_tokens=5 num_scheduled_tokens_per_req.append(num_tokens) max_num_scheduled_tokens_all_reqs = max(max_num_scheduled_tokens_all_reqs, num_tokens) @@ -352,14 +344,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) # Get positions. - # TODO(xw32): add an example of the output positions_np. - # self.positions_np.shape=(8192,) self.positions_np=array([281337856, 0, 32768, ..., 0, 281734336, 0], dtype=int32), the value in self.positions_np because it's initialize as torch.empty. positions_np = self.positions_np[:total_num_scheduled_tokens] - # xw32 TPUModelRunner.prepare_input. total_num_scheduled_tokens=5, self.input_batch.num_computed_tokens_cpu=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32), self.input_batch.num_reqs=1, self.model_config.uses_mrope=False np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - # xw32 TPUModelRunner.prepare_input. line355 self.input_batch.num_computed_tokens_cpu[req_indices]=array([0, 0, 0, 0, 0], dtype=int32), arange=array([0, 1, 2, 3, 4], dtype=int32), positions_np=array([0, 1, 2, 3, 4], dtype=int32) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -367,12 +355,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - # xw32 TPUModelRunner.prepare_input line148. positions_np=array([0, 1, 2, 3, 4]), req_indices=array([0, 0, 0, 0, 0], dtype=int32), self.input_batch.token_ids_cpu.shape=(16, 512), [0 1 2 3 4] # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. 
- # xw32 note, self.input_batch.token_ids_cpu_tensor is a 2d tensor torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices), @@ -384,6 +370,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # where K is the max_num_blocks_per_req and the block size is 2. # NOTE(woosuk): We can't simply use `token_indices // block_size` here # because M (max_model_len) is not necessarily divisible by block_size. + # req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) # NOTE(woosuk): We use torch.index_select instead of np.take here @@ -395,7 +382,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - # xw32 TPUModelRunner.prepare_input line401 . self.slot_mapping_cpu.shape=torch.Size([8192]), self.slot_mapping_cpu=tensor([0, 1, 2, ..., 0, 0, 0]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -406,7 +392,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens_per_req) - # Copy the tensors to the TPU. + # Do the padding and copy the tensors to the TPU. padded_total_num_scheduled_tokens = _get_padded_number(total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(self.device) self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(self.device) @@ -418,9 +404,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): query_start_loc = self.query_start_loc_cpu[:padded_total_num_scheduled_tokens+1].to(self.device) seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(self.device) - # print(f'xw32 TPUModelRunner.prepare_input line421 . {self.input_batch.block_table.get_device_tensor().shape=}') # self.input_batch.block_table.get_device_tensor().shape=torch.Size([16, 32]=(max_num_reqs, max_num_blocks_per_req) - # block_table.get_device_tensor()=tensor([[0, 0, 0, 0, 0...], [0]]) - # slot_mapping=tensor([ 0, 1, 2, 3, 4, 1000000000, ...]) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=padded_block_table, @@ -442,8 +425,6 @@ def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: - # logger.info(f"xw32 TPUModelRunner.execute_model. 
{scheduler_output=}") - # Update cached state self._update_states(scheduler_output) @@ -451,9 +432,6 @@ def execute_model( attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - input_ids = self.input_ids - # print(f'xw32 TPUModelRunner.execute_model line459 {input_ids.shape=}, {num_scheduled_tokens=}') - # Run the decoder with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( @@ -462,7 +440,6 @@ def execute_model( kv_caches=self.kv_caches, attn_metadata=attn_metadata, ) - # print(f'xw32 TPUModelRunner.execute_model line470 {selected_token_ids.shape=}') hidden_states = hidden_states[:total_num_scheduled_tokens] num_reqs = self.input_batch.num_reqs logits_indices = logits_indices[:num_reqs] @@ -497,7 +474,6 @@ def execute_model( for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None - # print(f'xw32 TPUModelRunner.execute_model line496 {selected_token_ids.shape=}') max_gen_len = selected_token_ids.shape[-1] if max_gen_len == 1: valid_sampled_token_ids = selected_token_ids.tolist() @@ -507,7 +483,6 @@ def execute_model( req_state.output_token_ids.append(token_id) self.input_batch.num_tokens[i] += 1 else: - # print('xw32 TPUModelRunner.execute_model line505 max_gen_len>1 is triggered') valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ @@ -521,18 +496,17 @@ def execute_model( i, target_slice] = valid_sampled_token_ids[i] req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - # print(f'xw32 TPUModelRunner.execute_model line518 execute_model, {len(req_ids)=}, {len(self.input_batch.req_id_to_index)=}, {selected_token_ids.shape=}, {selected_token_ids=} {num_reqs=}') model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, + spec_token_ids=None, logprobs=None, prompt_logprobs_dict=prompt_logprobs_dict, ) return model_runner_output def load_model(self) -> None: - #logger.info("xw32 TPUModelRunner.load_model begins.") self.device = self.device_config.device # NOTE(woosuk): While the executor assigns the TP ranks to the worker @@ -554,14 +528,10 @@ def load_model(self) -> None: xm.mark_step() xm.wait_device_ops() model = ModelWrapperV1(model) - # TODO(xw32): turn on dynamo. - # xw32 turns off dynamo - # self.model = model self.model = torch.compile(model, backend="openxla", fullgraph=True, dynamic=False) - # logger.info("xw32 TPUModelRunner.load_model ends.") # @torch.inference_mode() fails so I disabled it. # It's also not in the original v1 tpu_model_runner.py @@ -589,14 +559,12 @@ def dummy_run( context_lens = torch.ones((num_tokens, ), dtype=torch.int32, device=self.device) - # TODO(xw32): how do I set torch._dynamo.mark_dynamic? - # The attn_metadata is used in torch._dynamo.mark_dynamic. attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, context_lens=context_lens, query_start_loc=query_start_loc, - num_seqs=num_tokens, # xw32: is it correct? + num_seqs=num_tokens, ) torch._dynamo.mark_dynamic(input_ids, 0) @@ -606,20 +574,15 @@ def dummy_run( torch._dynamo.mark_dynamic(query_start_loc, 0) torch._dynamo.mark_dynamic(context_lens, 0) - # TODO(xw32): work with Alex to fix the issue later. with set_forward_context(None, self.vllm_config): assert self.model is not None - #logger.info(f"xw32 TPUModelRunner.dummy_run. 
before calling self.model, {input_ids.shape=}, {position_ids.shape=}") self.model(input_ids, position_ids, attn_metadata, kv_caches) - #logger.info(f"xw32 TPUModelRunner.dummy_run. after calling self.model") def capture_model(self) -> None: """Compile the model.""" - #logger.info("xw32 TPUModelRunner.capture_model.") logger.info("Compiling the model with different input shapes.") - # xw32: may need to compile for num_seqs. start = time.perf_counter() num_tokens = 16 # The num_tokens_list below is how GPU precompiles. @@ -628,7 +591,6 @@ def capture_model(self) -> None: logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() xm.wait_device_ops() - # print(f'xw32 capture_model line 620 {self.scheduler_config.max_num_batched_tokens=}') if num_tokens >= self.scheduler_config.max_num_batched_tokens: break num_tokens *= 2 @@ -716,9 +678,6 @@ def forward( memory profiling at initialization. """ # Skip this in memory profiling at initialization. - #logger.info("xw32 ModelWrapperV1.forward.") - # token_ids=tensor([9707, 11, 847, 829, 374, 0...0] - # position_ids=tensor([0, 1, 2, 3, 4, 0, ..., 0] if attn_metadata is not None and kv_caches[0][0].numel() > 0: # index_copy_(slot_mapping) only works when the inserted dimension # is 0. However, the KV cache in the Pallas backend has the shape @@ -726,7 +685,6 @@ def forward( # work, we need to flatten the first three dimensions and modify # the slot_mapping accordingly. # kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] - # print(f'xw32 line705 {kv_caches[0][0].shape=}') num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape slot_mapping = attn_metadata.slot_mapping slot_mapping = slot_mapping.flatten() diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 8ddb245ba108..d5c71a21b4ba 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -102,7 +102,6 @@ def init_device(self): self.model_runner = TPUModelRunner(self.vllm_config, self.device) def determine_available_memory(self) -> int: - # TODO(xw32): may need to follow gpu_worker's determine_available_memory kv_caches: Dict[str, torch.Tensor] = {} kv_cache_spec = self.model_runner.get_kv_cache_spec() for layer_name, layer_spec in kv_cache_spec.items(): @@ -124,7 +123,6 @@ def determine_available_memory(self) -> int: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) - # TODO(xw32): change here. self.model_runner.dummy_run( runner_kv_caches, num_tokens=1, From 84ec0827cccf6253fc323edf07c4abfa565a0fd8 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Sat, 22 Feb 2025 05:33:02 +0000 Subject: [PATCH 11/19] remove some comments. Fix mypy annotation. --- vllm/v1/attention/backends/pallas.py | 12 +++--------- vllm/v1/core/scheduler.py | 1 - vllm/v1/worker/tpu_model_runner.py | 16 ++++++---------- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index d9e21c0e394c..e6edca7953dc 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -137,7 +137,8 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_metadata is None: + # For determine_available_memory case. 
+ if kv_cache[0].numel() == 0: if output is None: output = torch.ones_like(query) return output @@ -148,14 +149,12 @@ def forward( key = key.view(num_tokens, self.num_kv_heads, self.head_size) value = value.view(num_tokens, self.num_kv_heads, self.head_size) + key_cache, value_cache = kv_cache if kv_cache[0].numel() > 0: - # print('xw32 write to kv cache') slot_mapping = attn_metadata.slot_mapping - key_cache, value_cache = kv_cache write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale - # print(f'xw32 xw32 PallasAttentionBackendImpl.forward: {query.shape=}, {key_cache.shape=}, {value_cache.shape=}, {attn_metadata.context_lens.shape=}, {attn_metadata.block_tables.shape=}, {attn_metadata.query_start_loc.shape=}, {attn_metadata.num_seqs=}', flush=True) output = torch.ops.xla.ragged_paged_attention( query, key_cache, @@ -168,7 +167,6 @@ def forward( num_queries_per_block=NUM_QUERIES_PER_BLOCK, use_kernel=True, ) - # print(f'xw32 PallasAttentionBackendImpl.forward finished', flush=True) return output.reshape(num_tokens, hidden_size) @@ -189,16 +187,12 @@ def write_to_kv_cache( v_cache = [num_kv_heads, num_blocks, block_size, head_size] """ - # print(f'xw32 write_to_kv_cache {key.shape=}, {key_cache.shape=}, {slot_mapping.shape=}', flush=True) torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) - # xw32: key = key.flatten(0, 1) or key = key.flatten(0, 2)? - # key = key.flatten(0, 1) because the key.shape has changed from [bs, seq_len, num_kv_heads, head_size] to [num_tokens, num_kv_heads, head_size] key = key.flatten(0, 1) value = value.flatten(0, 1) key_cache = key_cache.flatten(0, 2) value_cache = value_cache.flatten(0, 2) key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) - # print(f'xw32 write_to_kv_cache finished', flush=True) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 0b0b1a74b8c0..535aa644c53c 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -497,7 +497,6 @@ def update_from_output( req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] - # print(f'xw32 update_from_output {req_index=}, {len(generated_token_ids)=}') if req_id not in scheduler_output.scheduled_spec_decode_tokens: # When the request's num_computed_tokens catches up # its num_tokens, the request generates output tokens. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8ff2ad4e4097..1bf6467c980a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -92,8 +92,6 @@ def __init__( self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() - self.model: Optional[nn.Module] = None - # Persistent batch. self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, @@ -470,7 +468,7 @@ def execute_model( self.input_batch.req_ids[:num_reqs]), "req_ids contains None" req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} + prompt_logprobs_dict: Dict[str, LogprobsTensors] = {} for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None @@ -533,8 +531,6 @@ def load_model(self) -> None: fullgraph=True, dynamic=False) - # @torch.inference_mode() fails so I disabled it. 
- # It's also not in the original v1 tpu_model_runner.py # @torch.inference_mode() def dummy_run( self, @@ -569,12 +565,12 @@ def dummy_run( torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(slot_mapping, 0) - torch._dynamo.mark_dynamic(block_tables, 0) - torch._dynamo.mark_dynamic(query_start_loc, 0) - torch._dynamo.mark_dynamic(context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - with set_forward_context(None, self.vllm_config): + with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None self.model(input_ids, position_ids, attn_metadata, kv_caches) From c9096e3f7908c2d63c72a6a370158e1da1a6c3f8 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Mon, 24 Feb 2025 22:22:13 +0000 Subject: [PATCH 12/19] correctly initiate the query_start_loc in dummy_run --- vllm/v1/worker/tpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 1bf6467c980a..64a046250f45 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -526,6 +526,7 @@ def load_model(self) -> None: xm.mark_step() xm.wait_device_ops() model = ModelWrapperV1(model) + # self.model = model self.model = torch.compile(model, backend="openxla", fullgraph=True, @@ -551,7 +552,8 @@ def dummy_run( (num_tokens, self.block_table_cpu.shape[1]), dtype=torch.int32, device=self.device) - query_start_loc = torch.zeros(num_tokens+1, dtype=torch.int32, device=self.device) + query_lens = [1] * num_tokens + query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to(self.device) context_lens = torch.ones((num_tokens, ), dtype=torch.int32, device=self.device) From 1984125e9a4556241dfeaa2a5b37350a744ea7cf Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 27 Feb 2025 19:43:58 +0000 Subject: [PATCH 13/19] after rebase, it couldnt run. I fixed some issues so it runs to completion. --- vllm/v1/attention/backends/pallas.py | 1 + vllm/v1/worker/tpu_model_runner.py | 20 ++++++++------------ vllm/v1/worker/tpu_worker.py | 3 +-- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index e6edca7953dc..5e9f95880894 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -155,6 +155,7 @@ def forward( write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale + # TODO(xw32): change use_kernel=False output = torch.ops.xla.ragged_paged_attention( query, key_cache, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 3a7bcbd84516..3cc2a4996857 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -13,6 +13,7 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr +from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig @@ -129,7 +130,8 @@ def __init__( self.slot_mapping_np = self.slot_mapping_cpu.numpy() # self.input_batch.block_table has a shape of [max_num_reqs, max_num_blocks_per_req]. 
- # Because we want the block_table.shape[0] to be num_tokens nad block_table[1] to be multiple of NUM_KV_PAGES_PER_BLOCK, so we create a separate one. + # To reduce the number of recompilation, we want the block_table.shape[0] to be num_tokens. + # To make the block_table to be compatible with the paged attention kernel, we want the block_table[1] to be multiple of NUM_KV_PAGES_PER_BLOCK. padded_max_num_blocks_per_req = _get_padded_number(self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) self.block_table_cpu = torch.zeros((self.max_num_tokens, padded_max_num_blocks_per_req), dtype=self.input_batch.block_table.get_cpu_tensor().dtype, @@ -435,7 +437,6 @@ def execute_model( token_ids=self.input_ids, position_ids=self.position_ids, kv_caches=self.kv_caches, - attn_metadata=attn_metadata, ) hidden_states = hidden_states[:total_num_scheduled_tokens] num_reqs = self.input_batch.num_reqs @@ -531,12 +532,10 @@ def load_model(self) -> None: fullgraph=True, dynamic=False) - # @torch.inference_mode() def dummy_run( self, kv_caches, num_tokens: int, - seq_len: Optional[int] = None, ) -> None: input_ids = torch.zeros(num_tokens, dtype=torch.int32, @@ -573,7 +572,7 @@ def dummy_run( with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None - self.model(input_ids, position_ids, attn_metadata, kv_caches) + self.model(input_ids, position_ids, kv_caches) def capture_model(self) -> None: """Compile the model.""" @@ -588,6 +587,8 @@ def capture_model(self) -> None: logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() xm.wait_device_ops() + # TODO(xw32): remove the next line. + break # temperarily reduce precompile time. if num_tokens >= self.scheduler_config.max_num_batched_tokens: break num_tokens *= 2 @@ -663,12 +664,8 @@ def forward( """Executes the forward pass of the model and samples the next token. Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - input_lens: The actual input lengths of shape [batch_size]. - t: The sampling temperature of shape [batch_size]. - p: The top-p probability of shape [batch_size]. - num_samples: Number of samples to draw from each logits vector. + token_ids: The input token IDs of shape [num_tokens]. + position_ids: The input position IDs of shape [num_tokens]. kv_caches: The key and value caches. They can be None during the memory profiling at initialization. """ @@ -700,7 +697,6 @@ def forward( token_ids, position_ids, kv_caches, - attn_metadata, ) return hidden_states diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 1fa832262817..405dc628ee1c 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -126,8 +126,7 @@ def determine_available_memory(self) -> int: self.model_runner.dummy_run( runner_kv_caches, - num_tokens=1, - seq_len=self.scheduler_config.max_num_batched_tokens, + num_tokens=self.scheduler_config.max_num_batched_tokens, ) # Synchronize before measuring the memory usage. 
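A minimal standalone sketch (not part of the patches) of what the dummy_run query_start_loc construction introduced in PATCH 12 yields, assuming num_tokens = 4:

    import torch

    num_tokens = 4
    query_lens = [1] * num_tokens
    query_start_loc = torch.cumsum(
        torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32)
    # tensor([0, 1, 2, 3, 4], dtype=torch.int32): cumulative start offsets for a
    # pure-decode dummy batch with one query token per sequence.

Unlike the previous torch.zeros(num_tokens + 1) placeholder, this gives the ragged paged attention kernel non-degenerate start offsets during warm-up compilation.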
From e8a7f9b1714c84575e17c6ec0ec5fabe3c737fb0 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 27 Feb 2025 19:45:47 +0000 Subject: [PATCH 14/19] clean up --- examples/offline_inference/basic/basic.py | 4 ++-- vllm/v1/attention/backends/pallas.py | 3 +-- vllm/v1/worker/tpu_model_runner.py | 16 ---------------- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index c6ef22c0dc1f..792afb9c8b7f 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -10,10 +10,10 @@ "The future of AI is", ] # Create a sampling params object. -sampling_params = SamplingParams() #temperature=0.8, top_p=0.95) +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16) +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 5e9f95880894..451055050e92 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -155,7 +155,6 @@ def forward( write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale - # TODO(xw32): change use_kernel=False output = torch.ops.xla.ragged_paged_attention( query, key_cache, @@ -166,7 +165,7 @@ def forward( attn_metadata.num_seqs, num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK, num_queries_per_block=NUM_QUERIES_PER_BLOCK, - use_kernel=True, + use_kernel=False, ) return output.reshape(num_tokens, hidden_size) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 3cc2a4996857..74c6933103d2 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -587,8 +587,6 @@ def capture_model(self) -> None: logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() xm.wait_device_ops() - # TODO(xw32): remove the next line. - break # temperarily reduce precompile time. if num_tokens >= self.scheduler_config.max_num_batched_tokens: break num_tokens *= 2 @@ -596,20 +594,6 @@ def capture_model(self) -> None: logger.info("Compilation finished in in %.2f [secs].", end - start) - # GPU way of warming up. - # start = time.perf_counter() - # num_tokens_list = [512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1] - # # The num_tokens_list below is how GPU precompiles. - # for num_tokens in num_tokens_list: - # self.dummy_run(self.kv_caches, num_tokens) - # logger.info(" -- num_tokens: %d", num_tokens) - # xm.mark_step() - # xm.wait_device_ops() - # if num_tokens >= self.scheduler_config.max_num_batched_tokens: - # break - # end = time.perf_counter() - # logger.info("Compilation finished in in %.2f [secs].", end - start) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. 
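The warm-up loop that PATCH 14 keeps in capture_model() doubles num_tokens instead of walking the GPU-style descending list that the commented-out block used. A rough sketch of the shapes that end up precompiled, assuming max_num_batched_tokens is 8192 as in the debugging comments removed earlier:

    num_tokens = 16
    precompiled = []
    while True:
        precompiled.append(num_tokens)
        if num_tokens >= 8192:  # scheduler_config.max_num_batched_tokens
            break
        num_tokens *= 2
    # precompiled == [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

Doubling bounds the number of compiled graphs to roughly log2 of the batched-token budget, rather than one graph per bucket in a long fixed list.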
From 1a942d54ccb76378f1e25ea7401bd2fec276f648 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 27 Feb 2025 21:07:41 +0000 Subject: [PATCH 15/19] run linter --- vllm/v1/attention/backends/pallas.py | 10 ++- vllm/v1/worker/tpu_model_runner.py | 103 +++++++++++++++------------ 2 files changed, 61 insertions(+), 52 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 451055050e92..e496ea037884 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -7,11 +7,9 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops. from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) + AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState - NUM_QUERIES_PER_BLOCK = 16 NUM_KV_PAGES_PER_BLOCK = 128 @@ -53,7 +51,7 @@ def swap_blocks( @dataclass -class PallasMetadata(): +class PallasMetadata: # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -70,7 +68,6 @@ class PallasMetadata(): num_seqs: int - class PallasAttentionBackendImpl(AttentionImpl): def __init__( @@ -88,7 +85,8 @@ def __init__( ) -> None: if blocksparse_params is not None: raise ValueError( - "Paged attention Pallas kernel does not support block-sparse attention.") + "Paged attention Pallas kernel does not support block-sparse attention." + ) self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 74c6933103d2..6d574c9745c4 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import enum import time -from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast from unittest.mock import patch @@ -13,7 +11,6 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -from vllm.attention import AttentionMetadata from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig @@ -22,13 +19,13 @@ from vllm.model_executor.model_loader import get_model from vllm.sampling_params import SamplingType from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasMetadata, +from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, NUM_QUERIES_PER_BLOCK, - NUM_KV_PAGES_PER_BLOCK) + PallasAttentionBackend, + PallasMetadata) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput +from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -48,7 +45,7 @@ INVALID_TOKEN_ID = -1 -class TPUModelRunner(): +class TPUModelRunner: def __init__( self, @@ -80,8 +77,8 @@ def __init__( self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.max_num_tokens = scheduler_config.max_num_batched_tokens # 8192 - self.max_num_reqs = scheduler_config.max_num_seqs # 16 + self.max_num_tokens = 
scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( @@ -115,8 +112,8 @@ def __init__( # The pytorch tensor and numpy array share the same buffer. # Sometimes the numpy op is faster so we create both. self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + dtype=torch.int32, + device="cpu") self.input_ids_np = self.input_ids_cpu.numpy() self.positions_cpu = torch.zeros(self.max_num_tokens, @@ -132,10 +129,12 @@ def __init__( # self.input_batch.block_table has a shape of [max_num_reqs, max_num_blocks_per_req]. # To reduce the number of recompilation, we want the block_table.shape[0] to be num_tokens. # To make the block_table to be compatible with the paged attention kernel, we want the block_table[1] to be multiple of NUM_KV_PAGES_PER_BLOCK. - padded_max_num_blocks_per_req = _get_padded_number(self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) - self.block_table_cpu = torch.zeros((self.max_num_tokens, padded_max_num_blocks_per_req), - dtype=self.input_batch.block_table.get_cpu_tensor().dtype, - device="cpu") + padded_max_num_blocks_per_req = _get_padded_number( + self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) + self.block_table_cpu = torch.zeros( + (self.max_num_tokens, padded_max_num_blocks_per_req), + dtype=self.input_batch.block_table.get_cpu_tensor().dtype, + device="cpu") self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, dtype=torch.int32, @@ -325,9 +324,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens_per_req.append(num_tokens) - max_num_scheduled_tokens_all_reqs = max(max_num_scheduled_tokens_all_reqs, - num_tokens) - num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, dtype=np.int32) + max_num_scheduled_tokens_all_reqs = max( + max_num_scheduled_tokens_all_reqs, num_tokens) + num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, + dtype=np.int32) assert max_num_scheduled_tokens_all_reqs > 0 # Get request indices. @@ -341,13 +341,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # For each scheduled token, what is its position in the corresponding req. arange = np.concatenate( [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) - + # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], arange, out=positions_np) - + # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] @@ -362,7 +362,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - + # Calculate the slot mapping. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] @@ -381,27 +381,40 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - + # Prepare the attention metadata. 
self.query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens_per_req, out=self.query_start_loc_np[1:num_reqs + 1]) - + self.seq_lens_np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens_per_req) # Do the padding and copy the tensors to the TPU. - padded_total_num_scheduled_tokens = _get_padded_number(total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) - self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(self.device) - self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(self.device) + padded_total_num_scheduled_tokens = _get_padded_number( + total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) + self.input_ids = self.input_ids_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + self.position_ids = self.positions_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID - slot_mapping = self.slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(self.device) - padded_block_table = self.block_table_cpu[:padded_total_num_scheduled_tokens] - padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor()[:num_reqs] + slot_mapping = self.slot_mapping_cpu[: + padded_total_num_scheduled_tokens].to( + self.device) + padded_block_table = self.block_table_cpu[: + padded_total_num_scheduled_tokens] + padded_block_table[:num_reqs, :self. + max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor( + )[:num_reqs] padded_block_table = padded_block_table.to(self.device) - query_start_loc = self.query_start_loc_cpu[:padded_total_num_scheduled_tokens+1].to(self.device) - seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(self.device) + query_start_loc = self.query_start_loc_cpu[: + padded_total_num_scheduled_tokens + + 1].to(self.device) + seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to( + self.device) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, @@ -418,7 +431,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = query_start_loc[1:] - 1 return attn_metadata, logits_indices - @torch.no_grad() def execute_model( self, @@ -432,7 +444,7 @@ def execute_model( total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens # Run the decoder - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( token_ids=self.input_ids, position_ids=self.position_ids, @@ -461,14 +473,14 @@ def execute_model( if generator is not None: # This relies on cuda-specific torch-internal impl details generator.set_offset(generator.get_offset() - 4) - + # num_reqs entries should be non-None assert all( req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]), "req_ids contains None" req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - prompt_logprobs_dict: Dict[str, LogprobsTensors] = {} + prompt_logprobs_dict = {} for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None @@ -526,7 +538,6 @@ def load_model(self) -> None: xm.mark_step() xm.wait_device_ops() model = ModelWrapperV1(model) - # self.model = model self.model = torch.compile(model, backend="openxla", fullgraph=True, @@ -546,12 +557,14 @@ def dummy_run( slot_mapping = torch.zeros(num_tokens, dtype=torch.int64, device=self.device) - block_tables = torch.zeros( - (num_tokens, self.block_table_cpu.shape[1]), - 
dtype=torch.int32, - device=self.device) + block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]), + dtype=torch.int32, + device=self.device) query_lens = [1] * num_tokens - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to(self.device) + query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, + dtype=torch.int32), + dim=0, + dtype=torch.int32).to(self.device) context_lens = torch.ones((num_tokens, ), dtype=torch.int32, device=self.device) @@ -581,7 +594,6 @@ def capture_model(self) -> None: start = time.perf_counter() num_tokens = 16 - # The num_tokens_list below is how GPU precompiles. while True: self.dummy_run(self.kv_caches, num_tokens) logger.info(" -- num_tokens: %d", num_tokens) @@ -591,8 +603,7 @@ def capture_model(self) -> None: break num_tokens *= 2 end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", - end - start) + logger.info("Compilation finished in in %.2f [secs].", end - start) def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -682,7 +693,7 @@ def forward( position_ids, kv_caches, ) - + return hidden_states def compute_logits( From 059289095a440230f9b9267229e5faf7e8a474d2 Mon Sep 17 00:00:00 2001 From: Xiongfei Wei Date: Thu, 27 Feb 2025 21:51:36 +0000 Subject: [PATCH 16/19] Bump torch_xla version again Signed-off-by: Xiongfei Wei --- requirements-tpu.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 02cad825fcd3..725b1a2e4a58 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -18,7 +18,7 @@ ray[default] --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.7.0.dev20250220+cpu -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250220+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250220+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250220+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.7.0.dev20250226+cpu +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" From b9bac9af5b6c65703053ba25b936e6b282e70b48 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 28 Feb 2025 03:33:24 +0000 Subject: [PATCH 17/19] Fix lint issues Signed-off-by: mgoin --- vllm/v1/attention/backends/pallas.py | 9 ++++----- vllm/v1/outputs.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 27 ++++++++++++--------------- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index e496ea037884..b0727768e5ad 100644 --- 
a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch -import torch_xla.experimental.custom_kernel # Required to register custom ops. from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -84,9 +83,8 @@ def __init__( attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: - raise ValueError( - "Paged attention Pallas kernel does not support block-sparse attention." - ) + raise ValueError("Paged attention Pallas kernel does " + "not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -130,7 +128,8 @@ def forward( query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = ([num_kv_heads, num_blocks, block_size, head_size], [num_kv_heads, num_blocks, block_size, head_size]) + kv_cache = ([num_kv_heads, num_blocks, block_size, head_size], + [num_kv_heads, num_blocks, block_size, head_size]) attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 0c8eca38ade7..f461d52cc984 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -79,4 +79,4 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs] # [prompt_len] - prompt_logprobs_dict: Dict[str, LogprobsTensors] + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c9212d993f2b..320cc2a47bac 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1071,12 +1071,12 @@ def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, scheduler_output: "SchedulerOutput", - ) -> Dict[str, LogprobsTensors]: + ) -> Dict[str, Optional[LogprobsTensors]]: num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs if not num_prompt_logprobs_dict: return {} - prompt_logprobs_dict: Dict[str, LogprobsTensors] = {} + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} # Since prompt logprobs are a rare feature, prioritize simple, # maintainable loop over optimal performance. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 6d574c9745c4..d16a0a4165c7 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -25,7 +25,7 @@ PallasMetadata) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -37,11 +37,6 @@ # Here we utilize the behavior that out-of-bound index is ignored. # FIXME(woosuk): Find a more reliable way to prevent possible bugs. _PAD_SLOT_ID = 1_000_000_000 -# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. -_ENABLE_TOP_P = False -# FIXME(woosuk): A temporary hack to support `n > 1`. -# This can significantly affect the performance if too large. 
-_MAX_NUM_SAMPLES = 128 INVALID_TOKEN_ID = -1 @@ -126,9 +121,12 @@ def __init__( device="cpu") self.slot_mapping_np = self.slot_mapping_cpu.numpy() - # self.input_batch.block_table has a shape of [max_num_reqs, max_num_blocks_per_req]. - # To reduce the number of recompilation, we want the block_table.shape[0] to be num_tokens. - # To make the block_table to be compatible with the paged attention kernel, we want the block_table[1] to be multiple of NUM_KV_PAGES_PER_BLOCK. + # self.input_batch.block_table has a shape of [max_num_reqs, + # max_num_blocks_per_req]. To reduce the number of recompilation, + # we want the block_table.shape[0] to be num_tokens. + # To make the block_table to be compatible with the paged attention + # kernel, we want the block_table[1] to be multiple of + # NUM_KV_PAGES_PER_BLOCK. padded_max_num_blocks_per_req = _get_padded_number( self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK) self.block_table_cpu = torch.zeros( @@ -160,7 +158,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: the input GPU tensors for the model. Returns: - True if there is a new/resumed/paused/finished request in the batch. + True if there is a new/resumed/paused/finished request. If False, we can skip copying SamplingMetadata to the GPU. """ # Remove finished requests from the cached states. @@ -338,7 +336,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # For each scheduled token, what is its position in the corresponding req. + # For each scheduled token, what is its position in corresponding req. arange = np.concatenate( [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) @@ -406,9 +404,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device) padded_block_table = self.block_table_cpu[: padded_total_num_scheduled_tokens] - padded_block_table[:num_reqs, :self. 
- max_num_blocks_per_req] = self.input_batch.block_table.get_cpu_tensor( - )[:num_reqs] + padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = ( + self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) padded_block_table = padded_block_table.to(self.device) query_start_loc = self.query_start_loc_cpu[: padded_total_num_scheduled_tokens @@ -480,7 +477,7 @@ def execute_model( self.input_batch.req_ids[:num_reqs]), "req_ids contains None" req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs]) - prompt_logprobs_dict = {} + prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {} for req_id in self.input_batch.req_ids[:num_reqs]: prompt_logprobs_dict[req_id] = None From 6bf9e68a434eadc59f2940865a8b8dba6b72f82c Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 28 Feb 2025 03:34:13 +0000 Subject: [PATCH 18/19] Revert basic Signed-off-by: mgoin --- examples/offline_inference/basic/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 792afb9c8b7f..a6e96c0bb433 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -21,4 +21,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file From 175224a13e367cff45c744c490fbff70ef194fd5 Mon Sep 17 00:00:00 2001 From: mgoin Date: Fri, 28 Feb 2025 03:37:20 +0000 Subject: [PATCH 19/19] Keep import torch_xla.experimental.custom_kernel Signed-off-by: mgoin --- vllm/v1/attention/backends/pallas.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index b0727768e5ad..a9f7b3fd4471 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -4,6 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch +# Required to register custom ops. +import torch_xla.experimental.custom_kernel # noqa: F401 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType)