diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py
new file mode 100644
index 000000000000..3fc1011d5042
--- /dev/null
+++ b/tests/v1/attention/test_attention_splitting.py
@@ -0,0 +1,157 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from tests.v1.attention.test_attention_backends import BATCH_SPECS
+from tests.v1.attention.utils import create_common_attn_metadata
+from vllm.v1.attention.backends.utils import (UbatchSlice,
+                                              _make_metadata_with_slice,
+                                              slice_query_start_locs,
+                                              split_attn_metadata)
+
+
+@pytest.fixture
+def sample_query_start_loc():
+    """Sample query_start_loc tensor for testing"""
+    return torch.tensor([0, 5, 12, 20, 35, 50])
+
+
+def test_basic_slice_middle(sample_query_start_loc):
+    """Test slicing from the middle of the tensor"""
+    req_slice = slice(1, 3)  # slice from index 1 to 3
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 7, 15])
+    assert torch.equal(result, expected)
+
+
+def test_slice_from_beginning(sample_query_start_loc):
+    """Test slicing from the beginning of the tensor"""
+    req_slice = slice(0, 2)  # slice from index 0 to 2
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 5, 12])
+    assert torch.equal(result, expected)
+
+
+def test_slice_to_end(sample_query_start_loc):
+    """Test slicing to the end of the tensor"""
+    req_slice = slice(3, 5)  # slice from index 3 to 5 (last index)
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 15, 30])
+    assert torch.equal(result, expected)
+
+
+def test_single_element_slice(sample_query_start_loc):
+    """Test a slice that results in a single element"""
+    req_slice = slice(2, 3)  # slice from index 2 to 3
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 8])
+    assert torch.equal(result, expected)
+
+
+def test_full_tensor_slice(sample_query_start_loc):
+    """Test slicing the entire tensor"""
+    req_slice = slice(0, 5)  # slice entire tensor
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 5, 12, 20, 35, 50])
+    assert torch.equal(result, expected)
+
+
+def test_slice_bounds_edge_cases(sample_query_start_loc):
+    # Test a slice that goes exactly to the last element
+    req_slice = slice(4, 5)  # last index
+    result = slice_query_start_locs(sample_query_start_loc, req_slice)
+
+    expected = torch.tensor([0, 15])
+    assert torch.equal(result, expected)
+
+
+@pytest.fixture
+def small_decode_metadata():
+    """Create metadata for a small decode batch"""
+    batch_spec = BATCH_SPECS["small_decode"]
+    device = torch.device("cpu")
+    return create_common_attn_metadata(batch_spec,
+                                       block_size=16,
+                                       device=device)
+
+
+@pytest.fixture
+def large_decode_metadata():
+    """Create metadata for a large decode batch"""
+    batch_spec = BATCH_SPECS["large_decode"]
+    device = torch.device("cpu")
+    return create_common_attn_metadata(batch_spec,
+                                       block_size=16,
+                                       device=device)
+
+
+@pytest.fixture
+def mixed_small_metadata():
+    """Create metadata for a mixed small batch"""
+    batch_spec = BATCH_SPECS["mixed_small"]
+    device = torch.device("cpu")
+    return create_common_attn_metadata(batch_spec,
+                                       block_size=16,
+                                       device=device)
+
+
+# Tests for _make_metadata_with_slice
+def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
+    """Test slicing decode batch metadata"""
+    # Split off the first request only
+    ubatch_slice = UbatchSlice(slice(0, 1), slice(0, 1))
+
+    result = _make_metadata_with_slice(ubatch_slice, small_decode_metadata)
+
+    # Check sliced results
+    assert result.num_reqs == 1  # slice(0, 1) gives 1 request
+    assert result.num_actual_tokens == 1  # slice(0, 1) gives 1 token
+    assert result.max_query_len == 1
+    assert torch.equal(result.query_start_loc, torch.tensor([0, 1]))
+    assert torch.equal(result.seq_lens, torch.tensor([32]))
+
+
+def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
+    """Test slicing mixed batch metadata"""
+    ubatch_slice = UbatchSlice(slice(1, 3),
+                               slice(1, 7))  # requests 1-2, tokens 1-6
+
+    result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
+
+    assert result.num_reqs == 2  # slice(1, 3) gives 2 requests
+    assert result.num_actual_tokens == 6  # slice(1, 7) gives 6 tokens
+    assert result.max_query_len == 5
+    assert torch.equal(result.query_start_loc, torch.tensor([0, 1, 6]))
+    assert torch.equal(result.seq_lens, torch.tensor([40, 48]))
+
+
+def test_split_attn_metadata_decode_batch(large_decode_metadata):
+    """Test splitting a decode batch into two equal halves"""
+    # One token per request in a pure decode batch, so the
+    # request and token slices coincide.
+    num_reqs = large_decode_metadata.num_reqs
+    mid_point = num_reqs // 2
+    ubatch_slices = [
+        UbatchSlice(slice(0, mid_point), slice(0, mid_point)),
+        UbatchSlice(slice(mid_point, num_reqs), slice(mid_point, num_reqs)),
+    ]
+
+    results = split_attn_metadata(ubatch_slices, large_decode_metadata)
+
+    assert len(results) == 2
+
+    # Check the first split
+    assert results[0].num_reqs == mid_point
+    assert results[0].num_actual_tokens == mid_point
+    assert torch.equal(results[0].seq_lens, torch.tensor([2048] * mid_point))
+
+    # Check the second split
+    assert results[1].num_reqs == mid_point
+    assert results[1].num_actual_tokens == mid_point
+    assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 36bacf0cb36f..9333cb53b4ca 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -62,6 +62,89 @@ class CommonAttentionMetadata:
     causal: bool = True
 
 
+@dataclass
+class UbatchSlice:
+    request_slice: slice
+    token_slice: slice
+
+
+def slice_query_start_locs(
+    query_start_loc: torch.Tensor,
+    request_slice: slice,
+) -> torch.Tensor:
+    """
+    Creates a new query_start_loc that corresponds to the requests in
+    request_slice.
+
+    Note: This function creates a new tensor to hold the new query_start_locs.
+    This will break cudagraph compatibility.
+    """
+    return (query_start_loc[request_slice.start:request_slice.stop + 1] -
+            query_start_loc[request_slice.start])
+
+
+def _make_metadata_with_slice(
+        ubatch_slice: UbatchSlice,
+        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
+    """
+    Creates a new CommonAttentionMetadata that corresponds to the requests
+    included in ubatch_slice.
+    """
+
+    request_slice = ubatch_slice.request_slice
+    token_slice = ubatch_slice.token_slice
+
+    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
+                                             request_slice)
+    assert len(query_start_loc) >= 2
+    query_start_loc_cpu = slice_query_start_locs(
+        attn_metadata.query_start_loc_cpu, request_slice)
+
+    seq_lens = attn_metadata.seq_lens[request_slice]
+    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
+    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
+        request_slice]
+
+    num_requests = request_slice.stop - request_slice.start
+    num_actual_tokens = token_slice.stop - token_slice.start
+    max_query_len = int(
+        torch.max(torch.abs(query_start_loc_cpu[1:] -
+                            query_start_loc_cpu[:-1])).item())
+
+    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
+    slot_mapping = attn_metadata.slot_mapping[token_slice]
+
+    return CommonAttentionMetadata(
+        query_start_loc=query_start_loc,
+        query_start_loc_cpu=query_start_loc_cpu,
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens_cpu,
+        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        num_reqs=num_requests,
+        num_actual_tokens=num_actual_tokens,
+        max_query_len=max_query_len,
+        block_table_tensor=block_table_tensor,
+        slot_mapping=slot_mapping,
+    )
+
+
+def split_attn_metadata(
+    ubatch_slices: list[UbatchSlice],
+    common_attn_metadata: CommonAttentionMetadata,
+) -> list[CommonAttentionMetadata]:
+    """
+    Creates a new CommonAttentionMetadata instance for each UbatchSlice in
+    ubatch_slices, covering the requests and tokens that slice selects.
+
+    Note: This function does not modify common_attn_metadata.
+    """
+    results = []
+    for ubatch_slice in ubatch_slices:
+        results.append(
+            _make_metadata_with_slice(ubatch_slice, common_attn_metadata))
+    return results
+
+
 M = TypeVar("M")
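Reviewer note: the sketch below shows the intended end-to-end call pattern for the new helpers, mirroring `test_split_attn_metadata_decode_batch` above. It is illustrative only, not part of the diff, and assumes the PR's test utilities (`BATCH_SPECS`, `create_common_attn_metadata`) are importable, i.e. it is run from the repo root with the test tree on the path:

```python
import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import create_common_attn_metadata
from vllm.v1.attention.backends.utils import UbatchSlice, split_attn_metadata

# Build CPU-side metadata for a uniform decode batch via the test helper.
metadata = create_common_attn_metadata(BATCH_SPECS["large_decode"],
                                       block_size=16,
                                       device=torch.device("cpu"))

# One token per request in a pure decode batch, so the request slice and
# the token slice of each UbatchSlice coincide.
mid = metadata.num_reqs // 2
ubatch_slices = [
    UbatchSlice(slice(0, mid), slice(0, mid)),
    UbatchSlice(slice(mid, metadata.num_reqs),
                slice(mid, metadata.num_reqs)),
]

# Each result covers only its micro-batch; `metadata` itself is unchanged.
first, second = split_attn_metadata(ubatch_slices, metadata)
assert first.num_reqs + second.num_reqs == metadata.num_reqs
assert (first.num_actual_tokens + second.num_actual_tokens
        == metadata.num_actual_tokens)
```

For mixed prefill/decode batches the two slices differ: the request slice indexes per-request tensors such as `query_start_loc` and `seq_lens`, while the token slice indexes per-token tensors such as `slot_mapping`, as exercised by `test_make_metadata_with_slice_mixed_batch`.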