From 264f44e1d7161dbf6e03fb364e128192041a0ffa Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 17 Jul 2025 13:53:35 +0000 Subject: [PATCH 01/15] init Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 70 +++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index db6eaa558642..3f239bc764df 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -65,6 +65,76 @@ def __post_init__(self): self.slot_mapping[self.num_actual_tokens:].fill_(-1) +def slice_query_start_locs( + query_start_loc: torch.Tensor, + req_slice: slice, +) -> torch.Tensor: + return query_start_loc[req_slice.start: req_slice.stop + 1] -\ + query_start_loc[req_slice.start] + + +def make_metadata_with_slice(ubatch_slice, query_start_loc, + query_start_loc_cpu, seq_lens, seq_lens_cpu, + num_computed_tokens_cpu, num_reqs, + num_actual_tokens, max_query_len, + block_table_tensor, + slot_mapping) -> CommonAttentionMetadata: + + req_slice = ubatch_slice[0] + token_slice = ubatch_slice[1] + + query_start_loc = slice_query_start_locs(query_start_loc, req_slice) + + # TODO (Sage) Make sure that this is correct + query_start_loc_cpu = slice_query_start_locs(query_start_loc_cpu, + req_slice) + + seq_lens = seq_lens[req_slice] + seq_lens_cpu = seq_lens_cpu[req_slice] + num_computed_tokens_cpu = num_computed_tokens_cpu[req_slice] + + num_requests = req_slice.stop - req_slice.start + num_actual_tokens = token_slice.stop - token_slice.start + max_query_len = 1 + + block_table_tensor = block_table_tensor[token_slice] + slot_mapping = slot_mapping[token_slice] + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=num_requests, + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + ) + + +def split_attn_metadata( + token_slices, + common_attn_metadata, +) -> list[CommonAttentionMetadata]: + results = [] + for token_slice in token_slices: + results.append( + make_metadata_with_slice( + token_slice, common_attn_metadata.query_start_loc, + common_attn_metadata.query_start_loc_cpu, + common_attn_metadata.seq_lens, + common_attn_metadata.seq_lens_cpu, + common_attn_metadata.num_computed_tokens_cpu, + common_attn_metadata.num_reqs, + common_attn_metadata.num_actual_tokens, + common_attn_metadata.max_query_len, + common_attn_metadata.block_table_tensor, + common_attn_metadata.slot_mapping)) + return results + + M = TypeVar("M") From 83200acd6271b2744424d039c8fb152e6ab56a5f Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 17 Jul 2025 13:55:33 +0000 Subject: [PATCH 02/15] rename to ubatch slice Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3f239bc764df..a15d295afd32 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -115,14 +115,14 @@ def make_metadata_with_slice(ubatch_slice, query_start_loc, def split_attn_metadata( - token_slices, + ubatch_slices, common_attn_metadata, ) -> list[CommonAttentionMetadata]: results = [] - for token_slice in token_slices: + for ubatch_slice in ubatch_slices: results.append( make_metadata_with_slice( - token_slice, common_attn_metadata.query_start_loc, + ubatch_slice, common_attn_metadata.query_start_loc, common_attn_metadata.query_start_loc_cpu, common_attn_metadata.seq_lens, common_attn_metadata.seq_lens_cpu, From 7bde512add1ca441207aa6be0a2ec2438073f0de Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 17 Jul 2025 13:59:11 +0000 Subject: [PATCH 03/15] add comments Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index a15d295afd32..e3b67a0e0bf2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -69,6 +69,11 @@ def slice_query_start_locs( query_start_loc: torch.Tensor, req_slice: slice, ) -> torch.Tensor: + """ + Creates a new query_start_loc that corresponds to the requests in req_slice. + Note: This function creates a new tensor to hold the new query_start_locs. + This will break cudagraph compatibility. + """ return query_start_loc[req_slice.start: req_slice.stop + 1] -\ query_start_loc[req_slice.start] From 98aeec7d9a8cdb9cb7b65de1920de525f79f1089 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 17 Jul 2025 14:05:12 +0000 Subject: [PATCH 04/15] add types Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index e3b67a0e0bf2..bd963699ed25 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,7 +4,8 @@ import functools from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar +from typing import (TYPE_CHECKING, ClassVar, Generic, Optional, TypeAlias, + TypeVar) import numpy as np import torch @@ -65,6 +66,10 @@ def __post_init__(self): self.slot_mapping[self.num_actual_tokens:].fill_(-1) +UbatchSlice: TypeAlias = tuple[slice, slice] +UBatchSlices: TypeAlias = list[UbatchSlice] + + def slice_query_start_locs( query_start_loc: torch.Tensor, req_slice: slice, From ce69ff7056639c3637cd2874036aed356eaedae2 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 17 Jul 2025 15:36:35 +0000 Subject: [PATCH 05/15] add unit test Signed-off-by: Sage Moore --- .../v1/attention/test_attention_splitting.py | 30 +++++++++++++++++++ vllm/v1/attention/backends/utils.py | 20 ++++++------- 2 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 tests/v1/attention/test_attention_splitting.py diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py new file mode 100644 index 000000000000..23138dadfc16 --- /dev/null +++ b/tests/v1/attention/test_attention_splitting.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 attention backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.test_attention_backends import BATCH_SPECS +from tests.v1.attention.utils import (create_common_attn_metadata, + create_vllm_config) + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium" +]) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_attention_splitting_correctness(batch_spec_name: str, model: str): + """ + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model) + device = torch.device("cuda:0") + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # assert backend_output.shape == sdpa_output.shape, ( + # f"[{backend_name}] shape {backend_output.shape} != " + # f"SDPA shape {sdpa_output.shape}") diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index bd963699ed25..57be0c89f894 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -83,12 +83,13 @@ def slice_query_start_locs( query_start_loc[req_slice.start] -def make_metadata_with_slice(ubatch_slice, query_start_loc, - query_start_loc_cpu, seq_lens, seq_lens_cpu, - num_computed_tokens_cpu, num_reqs, - num_actual_tokens, max_query_len, - block_table_tensor, - slot_mapping) -> CommonAttentionMetadata: +def _make_metadata_with_slice( + ubatch_slice: UbatchSlice, query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, num_computed_tokens_cpu: torch.Tensor, + num_actual_tokens: int, max_query_len: int, + block_table_tensor: torch.Tensor, + slot_mapping: torch.Tensor) -> CommonAttentionMetadata: req_slice = ubatch_slice[0] token_slice = ubatch_slice[1] @@ -125,19 +126,18 @@ def make_metadata_with_slice(ubatch_slice, query_start_loc, def split_attn_metadata( - ubatch_slices, - common_attn_metadata, + ubatch_slices: UBatchSlices, + common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: results = [] for ubatch_slice in ubatch_slices: results.append( - make_metadata_with_slice( + _make_metadata_with_slice( ubatch_slice, common_attn_metadata.query_start_loc, common_attn_metadata.query_start_loc_cpu, common_attn_metadata.seq_lens, common_attn_metadata.seq_lens_cpu, common_attn_metadata.num_computed_tokens_cpu, - common_attn_metadata.num_reqs, common_attn_metadata.num_actual_tokens, common_attn_metadata.max_query_len, common_attn_metadata.block_table_tensor, From a68555895e3cc53eef22d22aef1ee703594eea92 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 00:06:38 +0000 Subject: [PATCH 06/15] add unit test Signed-off-by: Sage Moore --- .../v1/attention/test_attention_splitting.py | 191 ++++++++++++++++-- vllm/v1/attention/backends/utils.py | 5 +- 2 files changed, 172 insertions(+), 24 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 23138dadfc16..68b693e8e834 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -6,25 +6,172 @@ import torch from tests.v1.attention.test_attention_backends import BATCH_SPECS -from tests.v1.attention.utils import (create_common_attn_metadata, - create_vllm_config) - - -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium" -]) -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -def test_attention_splitting_correctness(batch_spec_name: str, model: str): - """ - """ - batch_spec = BATCH_SPECS[batch_spec_name] - vllm_config = create_vllm_config(model_name=model) - device = torch.device("cuda:0") - - common_attn_metadata = create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) - - # assert backend_output.shape == sdpa_output.shape, ( - # f"[{backend_name}] shape {backend_output.shape} != " - # f"SDPA shape {sdpa_output.shape}") +from tests.v1.attention.utils import create_common_attn_metadata +from vllm.v1.attention.backends.utils import (_make_metadata_with_slice, + slice_query_start_locs, + split_attn_metadata) + + +@pytest.fixture +def sample_query_start_loc(): + """Sample query_start_loc tensor for testing""" + return torch.tensor([0, 5, 12, 20, 35, 50]) + + +def test_basic_slice_middle(sample_query_start_loc): + """Test slicing from middle of tensor""" + req_slice = slice(1, 3) # slice from index 1 to 3 + result = slice_query_start_locs(sample_query_start_loc, req_slice) + + expected = torch.tensor([0, 7, 15]) # [5, 12, 20] - 5 + assert torch.equal(result, expected) + + +def test_slice_from_beginning(sample_query_start_loc): + """Test slicing from the beginning of tensor""" + req_slice = slice(0, 2) # slice from index 0 to 2 + result = slice_query_start_locs(sample_query_start_loc, req_slice) + + expected = torch.tensor([0, 5, 12]) # [0, 5, 12] - 0 + assert torch.equal(result, expected) + + +def test_slice_to_end(sample_query_start_loc): + """Test slicing to the end of tensor""" + req_slice = slice(3, 5) # slice from index 3 to 5 (last index) + result = slice_query_start_locs(sample_query_start_loc, req_slice) + + expected = torch.tensor([0, 15, 30]) # [20, 35, 50] - 20 + assert torch.equal(result, expected) + + +def test_single_element_slice(sample_query_start_loc): + """Test slice that results in single element""" + req_slice = slice(2, 2) # slice from index 2 to 2 + result = slice_query_start_locs(sample_query_start_loc, req_slice) + + expected = torch.tensor([0]) # [12] - 12 + assert torch.equal(result, expected) + + +def test_full_tensor_slice(sample_query_start_loc): + """Test slicing the entire tensor""" + req_slice = slice(0, 5) # slice entire tensor + result = slice_query_start_locs(sample_query_start_loc, req_slice) + + expected = torch.tensor([0, 5, 12, 20, 35, 50]) # original - 0 + assert torch.equal(result, expected) + + +def test_slice_bounds_edge_cases(sample_query_start_loc): + # Test slice that goes exactly to the last element + req_slice = slice(4, 4) # Last index + result = slice_query_start_locs(sample_query_start_loc, req_slice) + + expected = torch.tensor([0]) # [50] - 50 + assert torch.equal(result, expected) + + +@pytest.fixture +def small_decode_metadata(): + """Create metadata for small decode batch""" + batch_spec = BATCH_SPECS["small_decode"] + device = torch.device("cpu") + return create_common_attn_metadata(batch_spec, + block_size=16, + device=device) + + +@pytest.fixture +def large_decode_metadata(): + """Create metadata for small decode batch""" + batch_spec = BATCH_SPECS["large_decode"] + device = torch.device("cpu") + return create_common_attn_metadata(batch_spec, + block_size=16, + device=device) + + +@pytest.fixture +def mixed_small_metadata(): + """Create metadata for mixed small batch""" + batch_spec = BATCH_SPECS["mixed_small"] + device = torch.device("cpu") + return create_common_attn_metadata(batch_spec, + block_size=16, + device=device) + + +# Tests for _make_metadata_with_slice +def test_make_metadata_with_slice_decode_batch(small_decode_metadata): + """Test slicing decode batch metadata""" + # Split first request only + ubatch_slice = (slice(0, 1), slice(0, 1)) # First request, first token + + result = _make_metadata_with_slice( + ubatch_slice, small_decode_metadata.query_start_loc, + small_decode_metadata.query_start_loc_cpu, + small_decode_metadata.seq_lens, small_decode_metadata.seq_lens_cpu, + small_decode_metadata.num_computed_tokens_cpu, + small_decode_metadata.num_actual_tokens, + small_decode_metadata.max_query_len, + small_decode_metadata.block_table_tensor, + small_decode_metadata.slot_mapping) + + # Check sliced results + assert result.num_reqs == 1 # slice(0, 0) gives 0 requests + assert result.num_actual_tokens == 1 # slice(0, 1) gives 1 token + assert result.max_query_len == 1 # Always set to 1 + assert torch.equal(result.query_start_loc, torch.tensor([0, 1])) + assert torch.equal(result.seq_lens, torch.tensor([32])) + + +def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): + """Test slicing mixed batch metadata""" + # Split middle requests + ubatch_slice = (slice(1, 3), slice(1, 7)) # Requests 1-2, tokens 1-7 + + result = _make_metadata_with_slice( + ubatch_slice, mixed_small_metadata.query_start_loc, + mixed_small_metadata.query_start_loc_cpu, + mixed_small_metadata.seq_lens, mixed_small_metadata.seq_lens_cpu, + mixed_small_metadata.num_computed_tokens_cpu, + mixed_small_metadata.num_actual_tokens, + mixed_small_metadata.max_query_len, + mixed_small_metadata.block_table_tensor, + mixed_small_metadata.slot_mapping) + + # Check sliced results + assert result.num_reqs == 2 # slice(1, 3) gives 2 requests + assert result.num_actual_tokens == 6 # slice(1, 7) gives 5 tokens + assert result.max_query_len == 5 + # Query start should be offset: [1, 2] -> [0, 1] + assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6])) + # Should get second sequence length + assert torch.equal(result.seq_lens, torch.tensor([40, 48])) + + +# # Tests for split_attn_metadata +def test_split_attn_metadata_decode_batch(large_decode_metadata): + """Test splitting decode batch into two parts""" + num_tokens = large_decode_metadata.num_reqs + mid_point = num_tokens // 2 + ubatch_slices = [ + (slice(0, mid_point), slice(0, mid_point)), # First request + (slice(mid_point, num_tokens), slice(mid_point, + num_tokens)), # Second request + ] + + results = split_attn_metadata(ubatch_slices, large_decode_metadata) + + assert len(results) == 2 + + # Check first split + assert results[0].num_reqs == mid_point + assert results[0].num_actual_tokens == mid_point + assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point)) + + # Check second split + assert results[1].num_reqs == mid_point + assert results[1].num_actual_tokens == mid_point + assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 57be0c89f894..b304ce83bc57 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -96,7 +96,6 @@ def _make_metadata_with_slice( query_start_loc = slice_query_start_locs(query_start_loc, req_slice) - # TODO (Sage) Make sure that this is correct query_start_loc_cpu = slice_query_start_locs(query_start_loc_cpu, req_slice) @@ -106,7 +105,9 @@ def _make_metadata_with_slice( num_requests = req_slice.stop - req_slice.start num_actual_tokens = token_slice.stop - token_slice.start - max_query_len = 1 + max_query_len = int( + torch.max(torch.abs(query_start_loc[1:] - + query_start_loc[:-1])).item()) block_table_tensor = block_table_tensor[token_slice] slot_mapping = slot_mapping[token_slice] From 065d420520f0f0479587c2865e40c1f4d57623da Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 00:22:11 +0000 Subject: [PATCH 07/15] minor comment update Signed-off-by: Sage Moore --- tests/v1/attention/test_attention_splitting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 68b693e8e834..7a7c587684d1 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for v1 attention backends without GPUModelRunner dependency.""" import pytest import torch From bdcbaa490b9506822575aec02d7ac80dbcf148c8 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 00:24:46 +0000 Subject: [PATCH 08/15] assert len(query_start_locs) >= 2 Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b304ce83bc57..5cc239c82193 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -95,7 +95,7 @@ def _make_metadata_with_slice( token_slice = ubatch_slice[1] query_start_loc = slice_query_start_locs(query_start_loc, req_slice) - + assert len(query_start_loc >= 2) query_start_loc_cpu = slice_query_start_locs(query_start_loc_cpu, req_slice) From 0b2e62609dc32620ebe5c94a95b5aeb10668c7c4 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 14:17:24 +0000 Subject: [PATCH 09/15] refactor _make_metadata_with_slice Signed-off-by: Sage Moore --- .../v1/attention/test_attention_splitting.py | 20 ++--------- vllm/v1/attention/backends/utils.py | 36 +++++++------------ 2 files changed, 14 insertions(+), 42 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 7a7c587684d1..398ce4eac804 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -107,15 +107,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): # Split first request only ubatch_slice = (slice(0, 1), slice(0, 1)) # First request, first token - result = _make_metadata_with_slice( - ubatch_slice, small_decode_metadata.query_start_loc, - small_decode_metadata.query_start_loc_cpu, - small_decode_metadata.seq_lens, small_decode_metadata.seq_lens_cpu, - small_decode_metadata.num_computed_tokens_cpu, - small_decode_metadata.num_actual_tokens, - small_decode_metadata.max_query_len, - small_decode_metadata.block_table_tensor, - small_decode_metadata.slot_mapping) + result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) # Check sliced results assert result.num_reqs == 1 # slice(0, 0) gives 0 requests @@ -130,15 +122,7 @@ def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): # Split middle requests ubatch_slice = (slice(1, 3), slice(1, 7)) # Requests 1-2, tokens 1-7 - result = _make_metadata_with_slice( - ubatch_slice, mixed_small_metadata.query_start_loc, - mixed_small_metadata.query_start_loc_cpu, - mixed_small_metadata.seq_lens, mixed_small_metadata.seq_lens_cpu, - mixed_small_metadata.num_computed_tokens_cpu, - mixed_small_metadata.num_actual_tokens, - mixed_small_metadata.max_query_len, - mixed_small_metadata.block_table_tensor, - mixed_small_metadata.slot_mapping) + result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) # Check sliced results assert result.num_reqs == 2 # slice(1, 3) gives 2 requests diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 5cc239c82193..dedc0d0f5880 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -84,24 +84,21 @@ def slice_query_start_locs( def _make_metadata_with_slice( - ubatch_slice: UbatchSlice, query_start_loc: torch.Tensor, - query_start_loc_cpu: torch.Tensor, seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor, num_computed_tokens_cpu: torch.Tensor, - num_actual_tokens: int, max_query_len: int, - block_table_tensor: torch.Tensor, - slot_mapping: torch.Tensor) -> CommonAttentionMetadata: + ubatch_slice: UbatchSlice, + attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: req_slice = ubatch_slice[0] token_slice = ubatch_slice[1] - query_start_loc = slice_query_start_locs(query_start_loc, req_slice) + query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, + req_slice) assert len(query_start_loc >= 2) - query_start_loc_cpu = slice_query_start_locs(query_start_loc_cpu, - req_slice) + query_start_loc_cpu = slice_query_start_locs( + attn_metadata.query_start_loc_cpu, req_slice) - seq_lens = seq_lens[req_slice] - seq_lens_cpu = seq_lens_cpu[req_slice] - num_computed_tokens_cpu = num_computed_tokens_cpu[req_slice] + seq_lens = attn_metadata.seq_lens[req_slice] + seq_lens_cpu = attn_metadata.seq_lens_cpu[req_slice] + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[req_slice] num_requests = req_slice.stop - req_slice.start num_actual_tokens = token_slice.stop - token_slice.start @@ -109,8 +106,8 @@ def _make_metadata_with_slice( torch.max(torch.abs(query_start_loc[1:] - query_start_loc[:-1])).item()) - block_table_tensor = block_table_tensor[token_slice] - slot_mapping = slot_mapping[token_slice] + block_table_tensor = attn_metadata.block_table_tensor[token_slice] + slot_mapping = attn_metadata.slot_mapping[token_slice] return CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -133,16 +130,7 @@ def split_attn_metadata( results = [] for ubatch_slice in ubatch_slices: results.append( - _make_metadata_with_slice( - ubatch_slice, common_attn_metadata.query_start_loc, - common_attn_metadata.query_start_loc_cpu, - common_attn_metadata.seq_lens, - common_attn_metadata.seq_lens_cpu, - common_attn_metadata.num_computed_tokens_cpu, - common_attn_metadata.num_actual_tokens, - common_attn_metadata.max_query_len, - common_attn_metadata.block_table_tensor, - common_attn_metadata.slot_mapping)) + _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) return results From 4be36406c02f75e77540b606a930081488fe90e5 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 14:30:08 +0000 Subject: [PATCH 10/15] misc comments Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index dedc0d0f5880..1afd99e30d3b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -66,6 +66,7 @@ def __post_init__(self): self.slot_mapping[self.num_actual_tokens:].fill_(-1) +# The first slice is for requests. The second is for tokens. UbatchSlice: TypeAlias = tuple[slice, slice] UBatchSlices: TypeAlias = list[UbatchSlice] @@ -86,6 +87,10 @@ def slice_query_start_locs( def _make_metadata_with_slice( ubatch_slice: UbatchSlice, attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: + """ + This function creates a new CommonAttentionMetadata that covers the + requests included in ubatch_slice + """ req_slice = ubatch_slice[0] token_slice = ubatch_slice[1] @@ -127,6 +132,12 @@ def split_attn_metadata( ubatch_slices: UBatchSlices, common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: + """ + Creates a new CommonAttentionMetadata instance that covers the + requests for each UbatchSlice in ubatch_slices. + + Note: This function does not modify common_attn_metadata + """ results = [] for ubatch_slice in ubatch_slices: results.append( From 12d0e7e2806201cb3f3fab82b6ebf93bb1e50477 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 14:36:29 +0000 Subject: [PATCH 11/15] convert UbatchSlice to a dataclass Signed-off-by: Sage Moore --- .../v1/attention/test_attention_splitting.py | 15 ++++--- vllm/v1/attention/backends/utils.py | 41 +++++++++++-------- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 398ce4eac804..086bd8a7d328 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -6,7 +6,8 @@ from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.utils import create_common_attn_metadata -from vllm.v1.attention.backends.utils import (_make_metadata_with_slice, +from vllm.v1.attention.backends.utils import (UbatchSlice, + _make_metadata_with_slice, slice_query_start_locs, split_attn_metadata) @@ -105,7 +106,8 @@ def mixed_small_metadata(): def test_make_metadata_with_slice_decode_batch(small_decode_metadata): """Test slicing decode batch metadata""" # Split first request only - ubatch_slice = (slice(0, 1), slice(0, 1)) # First request, first token + ubatch_slice = UbatchSlice(slice(0, 1), + slice(0, 1)) # First request, first token result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) @@ -120,7 +122,8 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): """Test slicing mixed batch metadata""" # Split middle requests - ubatch_slice = (slice(1, 3), slice(1, 7)) # Requests 1-2, tokens 1-7 + ubatch_slice = UbatchSlice(slice(1, 3), + slice(1, 7)) # Requests 1-2, tokens 1-7 result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) @@ -140,9 +143,9 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): num_tokens = large_decode_metadata.num_reqs mid_point = num_tokens // 2 ubatch_slices = [ - (slice(0, mid_point), slice(0, mid_point)), # First request - (slice(mid_point, num_tokens), slice(mid_point, - num_tokens)), # Second request + UbatchSlice(slice(0, mid_point), slice(0, mid_point)), # First request + UbatchSlice(slice(mid_point, num_tokens), + slice(mid_point, num_tokens)), # Second request ] results = split_attn_metadata(ubatch_slices, large_decode_metadata) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1afd99e30d3b..26ba8ceac6d8 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -66,46 +66,53 @@ def __post_init__(self): self.slot_mapping[self.num_actual_tokens:].fill_(-1) -# The first slice is for requests. The second is for tokens. -UbatchSlice: TypeAlias = tuple[slice, slice] +@dataclass +class UbatchSlice: + request_slice: slice + token_slice: slice + + UBatchSlices: TypeAlias = list[UbatchSlice] def slice_query_start_locs( query_start_loc: torch.Tensor, - req_slice: slice, + request_slice: slice, ) -> torch.Tensor: """ - Creates a new query_start_loc that corresponds to the requests in req_slice. + Creates a new query_start_loc that corresponds to the requests in + request_slice. + Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility. """ - return query_start_loc[req_slice.start: req_slice.stop + 1] -\ - query_start_loc[req_slice.start] + return query_start_loc[request_slice.start: request_slice.stop + 1] -\ + query_start_loc[request_slice.start] def _make_metadata_with_slice( ubatch_slice: UbatchSlice, attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: """ - This function creates a new CommonAttentionMetadata that covers the - requests included in ubatch_slice + This function creates a new CommonAttentionMetadata that corresponds to + the requests included in ubatch_slice """ - req_slice = ubatch_slice[0] - token_slice = ubatch_slice[1] + request_slice = ubatch_slice.request_slice + token_slice = ubatch_slice.token_slice query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, - req_slice) + request_slice) assert len(query_start_loc >= 2) query_start_loc_cpu = slice_query_start_locs( - attn_metadata.query_start_loc_cpu, req_slice) + attn_metadata.query_start_loc_cpu, request_slice) - seq_lens = attn_metadata.seq_lens[req_slice] - seq_lens_cpu = attn_metadata.seq_lens_cpu[req_slice] - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[req_slice] + seq_lens = attn_metadata.seq_lens[request_slice] + seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice] + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ + request_slice] - num_requests = req_slice.stop - req_slice.start + num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start max_query_len = int( torch.max(torch.abs(query_start_loc[1:] - @@ -133,7 +140,7 @@ def split_attn_metadata( common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ - Creates a new CommonAttentionMetadata instance that covers the + Creates a new CommonAttentionMetadata instance that corresponds to the requests for each UbatchSlice in ubatch_slices. Note: This function does not modify common_attn_metadata From 3a8ab3828b908ef4775f810a511cd55cc7b9edac Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 15:08:17 +0000 Subject: [PATCH 12/15] misc test fixes Signed-off-by: Sage Moore --- .../v1/attention/test_attention_splitting.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 086bd8a7d328..5f20d0a06de1 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -23,7 +23,7 @@ def test_basic_slice_middle(sample_query_start_loc): req_slice = slice(1, 3) # slice from index 1 to 3 result = slice_query_start_locs(sample_query_start_loc, req_slice) - expected = torch.tensor([0, 7, 15]) # [5, 12, 20] - 5 + expected = torch.tensor([0, 7, 15]) assert torch.equal(result, expected) @@ -32,7 +32,7 @@ def test_slice_from_beginning(sample_query_start_loc): req_slice = slice(0, 2) # slice from index 0 to 2 result = slice_query_start_locs(sample_query_start_loc, req_slice) - expected = torch.tensor([0, 5, 12]) # [0, 5, 12] - 0 + expected = torch.tensor([0, 5, 12]) assert torch.equal(result, expected) @@ -41,16 +41,16 @@ def test_slice_to_end(sample_query_start_loc): req_slice = slice(3, 5) # slice from index 3 to 5 (last index) result = slice_query_start_locs(sample_query_start_loc, req_slice) - expected = torch.tensor([0, 15, 30]) # [20, 35, 50] - 20 + expected = torch.tensor([0, 15, 30]) assert torch.equal(result, expected) def test_single_element_slice(sample_query_start_loc): """Test slice that results in single element""" - req_slice = slice(2, 2) # slice from index 2 to 2 + req_slice = slice(2, 3) # slice from index 2 to 3 result = slice_query_start_locs(sample_query_start_loc, req_slice) - expected = torch.tensor([0]) # [12] - 12 + expected = torch.tensor([0, 8]) assert torch.equal(result, expected) @@ -59,7 +59,7 @@ def test_full_tensor_slice(sample_query_start_loc): req_slice = slice(0, 5) # slice entire tensor result = slice_query_start_locs(sample_query_start_loc, req_slice) - expected = torch.tensor([0, 5, 12, 20, 35, 50]) # original - 0 + expected = torch.tensor([0, 5, 12, 20, 35, 50]) assert torch.equal(result, expected) @@ -106,46 +106,40 @@ def mixed_small_metadata(): def test_make_metadata_with_slice_decode_batch(small_decode_metadata): """Test slicing decode batch metadata""" # Split first request only - ubatch_slice = UbatchSlice(slice(0, 1), - slice(0, 1)) # First request, first token + ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1)) result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata) # Check sliced results - assert result.num_reqs == 1 # slice(0, 0) gives 0 requests + assert result.num_reqs == 1 # slice(0, 1) gives 1 requests assert result.num_actual_tokens == 1 # slice(0, 1) gives 1 token - assert result.max_query_len == 1 # Always set to 1 + assert result.max_query_len == 1 assert torch.equal(result.query_start_loc, torch.tensor([0, 1])) assert torch.equal(result.seq_lens, torch.tensor([32])) def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): """Test slicing mixed batch metadata""" - # Split middle requests ubatch_slice = UbatchSlice(slice(1, 3), - slice(1, 7)) # Requests 1-2, tokens 1-7 + slice(1, 7)) # Requests 1-3, tokens 1-7 result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) - # Check sliced results assert result.num_reqs == 2 # slice(1, 3) gives 2 requests - assert result.num_actual_tokens == 6 # slice(1, 7) gives 5 tokens + assert result.num_actual_tokens == 6 # slice(1, 7) gives 6 tokens assert result.max_query_len == 5 - # Query start should be offset: [1, 2] -> [0, 1] assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6])) - # Should get second sequence length assert torch.equal(result.seq_lens, torch.tensor([40, 48])) -# # Tests for split_attn_metadata def test_split_attn_metadata_decode_batch(large_decode_metadata): - """Test splitting decode batch into two parts""" + """Test splitting decode batch into two equal parts""" num_tokens = large_decode_metadata.num_reqs mid_point = num_tokens // 2 ubatch_slices = [ - UbatchSlice(slice(0, mid_point), slice(0, mid_point)), # First request - UbatchSlice(slice(mid_point, num_tokens), - slice(mid_point, num_tokens)), # Second request + UbatchSlice(slice(0, mid_point), slice(0, mid_point)), + UbatchSlice(slice(mid_point, num_tokens), slice(mid_point, + num_tokens)), ] results = split_attn_metadata(ubatch_slices, large_decode_metadata) From 4377e9657e2f38e2506e4548032e7078d03f0112 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 15:11:38 +0000 Subject: [PATCH 13/15] misc test fixes Signed-off-by: Sage Moore --- tests/v1/attention/test_attention_splitting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index 5f20d0a06de1..3fc1011d5042 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -65,10 +65,10 @@ def test_full_tensor_slice(sample_query_start_loc): def test_slice_bounds_edge_cases(sample_query_start_loc): # Test slice that goes exactly to the last element - req_slice = slice(4, 4) # Last index + req_slice = slice(4, 5) # Last index result = slice_query_start_locs(sample_query_start_loc, req_slice) - expected = torch.tensor([0]) # [50] - 50 + expected = torch.tensor([0, 15]) assert torch.equal(result, expected) From f3bf8cd70bad437ce42f3bf6f1ecb2faf6090ab0 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 18 Jul 2025 16:08:25 +0000 Subject: [PATCH 14/15] Remove UbatchSlices alias Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 26ba8ceac6d8..7806dfdfef5c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,8 +4,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass -from typing import (TYPE_CHECKING, ClassVar, Generic, Optional, TypeAlias, - TypeVar) +from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar import numpy as np import torch @@ -72,9 +71,6 @@ class UbatchSlice: token_slice: slice -UBatchSlices: TypeAlias = list[UbatchSlice] - - def slice_query_start_locs( query_start_loc: torch.Tensor, request_slice: slice, @@ -136,7 +132,7 @@ def _make_metadata_with_slice( def split_attn_metadata( - ubatch_slices: UBatchSlices, + ubatch_slices: list[UbatchSlice], common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ From b975c75eddd652d15371f6279cab14f131e97100 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Mon, 21 Jul 2025 18:10:27 +0000 Subject: [PATCH 15/15] review comments Signed-off-by: Sage Moore --- vllm/v1/attention/backends/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 7806dfdfef5c..1d1cc4eddad6 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -111,10 +111,10 @@ def _make_metadata_with_slice( num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start max_query_len = int( - torch.max(torch.abs(query_start_loc[1:] - - query_start_loc[:-1])).item()) + torch.max(torch.abs(query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1])).item()) - block_table_tensor = attn_metadata.block_table_tensor[token_slice] + block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] return CommonAttentionMetadata(