|
5 | 5 | import torch |
6 | 6 |
|
7 | 7 | from tests.v1.attention.test_attention_backends import BATCH_SPECS |
8 | | -from tests.v1.attention.utils import create_common_attn_metadata |
| 8 | +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata |
9 | 9 | from vllm.v1.attention.backends.utils import (UBatchSlice, |
10 | 10 | _make_metadata_with_slice, |
11 | 11 | slice_query_start_locs, |
12 | 12 | split_attn_metadata) |
| 13 | +from vllm.v1.worker.ubatch_utils import create_ubatch_slices |
13 | 14 |
|
14 | 15 |
|
15 | 16 | @pytest.fixture |
@@ -155,3 +156,83 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): |
155 | 156 | assert results[1].num_reqs == mid_point |
156 | 157 | assert results[1].num_actual_tokens == mid_point |
157 | 158 | assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) |
| 159 | + |
| 160 | + |
@pytest.mark.parametrize(
    "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
    [
        # Split in the middle of request 1
        ([32, 40], [8, 8], 12, 2, 1),
        # Split inside the first request
        ([32, 40], [8, 8], 4, 1, 2),
    ],
)
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
                                       expected_first_reqs,
                                       expected_second_reqs):
    """Test splitting a prefill across ubatches"""
    import numpy as np

    device = torch.device("cpu")
    spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
    common = create_common_attn_metadata(spec, block_size=16, device=device)

    scheduled = np.array(query_lens, dtype=np.int32)
    query_starts = common.query_start_loc_cpu.numpy()
    total_tokens = common.num_actual_tokens

    ubatch_slices = create_ubatch_slices(scheduled, split_point)
    assert len(ubatch_slices) == 2

    first = _make_metadata_with_slice(ubatch_slices[0], common)
    second = _make_metadata_with_slice(ubatch_slices[1], common)

    # The token split must be exact across the two ubatches.
    assert first.num_actual_tokens == split_point
    assert second.num_actual_tokens == total_tokens - split_point

    # Request counts per ubatch.
    assert first.num_reqs == expected_first_reqs
    assert second.num_reqs == expected_second_reqs

    # Locate the request the split lands in and the size of its first chunk.
    split_req = int(
        np.searchsorted(query_starts, split_point, side="right")) - 1
    first_chunk = split_point - int(query_starts[split_req])
    full_q_lens = (common.query_start_loc_cpu[1:] -
                   common.query_start_loc_cpu[:-1])

    # Query-length continuity: the last request of the first ubatch and the
    # first request of the second must add up to the original query length.
    last_qlen_first = int(first.query_start_loc_cpu[-1] -
                          first.query_start_loc_cpu[-2])
    first_qlen_second = int(second.query_start_loc_cpu[1] -
                            second.query_start_loc_cpu[0])
    assert last_qlen_first == first_chunk
    assert last_qlen_first + first_qlen_second == int(full_q_lens[split_req])

    # Context length (seq_len minus query_len) of each original request,
    # used to check the seq_lens adjustments below.
    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]

    # First ubatch: the split request only sees context + its first chunk.
    assert int(first.seq_lens[-1]) == context_lens[split_req] + first_chunk

    # Requests fully contained in the first ubatch keep their original
    # seq_lens.
    for idx in range(first.num_reqs - 1):
        assert int(first.seq_lens[idx]) == seq_lens[idx]

    # Second ubatch: the continuation request sees its full original
    # seq_len...
    assert int(second.seq_lens[0]) == seq_lens[split_req]
    # ...and any trailing full requests keep their originals as well.
    for off in range(1, second.num_reqs):
        assert int(second.seq_lens[off]) == seq_lens[split_req + off]
0 commit comments