|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import numpy as np |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | + |
| 7 | +from vllm.v1.attention.backends.cp_utils import (cp_get_shard_size, |
| 8 | + prepare_inputs_for_cp) |
| 9 | +from vllm.v1.worker.block_table import MultiGroupBlockTable |
| 10 | +from vllm.v1.worker.gpu_input_batch import CachedRequestState |
| 11 | + |
| 12 | + |
@pytest.fixture(autouse=True)
def patch_parallel_state(monkeypatch):
    """Force every test into a 2-way context-parallel world as rank 0."""

    class _FakeCPGroup:
        # Minimal stand-in for the real CP process group.
        world_size = 2
        rank = 0
        rank_in_group = 0

    # Table of (target, replacement) pairs, applied in order.
    patches = {
        "vllm.distributed.parallel_state.get_context_parallel_world_size":
        lambda: 2,
        "vllm.distributed.parallel_state.get_context_parallel_rank":
        lambda: 0,
        "vllm.distributed.parallel_state.get_cp_group":
        lambda: _FakeCPGroup(),
    }
    for target, replacement in patches.items():
        monkeypatch.setattr(target, replacement)
| 30 | + |
| 31 | + |
def make_cached_request_state(id: int, prefill_len: int, decode_len: int,
                              num_computed_tokens: list[int]):
    """Factory for a minimal CachedRequestState used by CP tests.

    NOTE(review): the sanity check below requires ``sum(num_computed_tokens)``
    to equal the total token count, but the inline fixtures in this module
    pass milestone lists whose *last* element is the total — confirm which
    invariant is intended before reusing this helper.
    """
    assert prefill_len + decode_len == sum(num_computed_tokens)
    fields = dict(
        req_id="req" + str(id),
        prompt_token_ids=list(range(prefill_len)),
        prompt_embeds=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        generator=None,
        block_ids=([0], ),
        num_computed_tokens=num_computed_tokens,
        output_token_ids=list(range(decode_len)),
        lora_request=None,
    )
    return CachedRequestState(**fields)
| 48 | + |
| 49 | + |
def create_block_table():
    """Build the single-group CPU block table shared by these tests."""
    config = dict(
        max_num_reqs=32,
        max_model_len=2048,
        max_num_batched_tokens=512,
        pin_memory=False,
        device=torch.device("cpu"),
        block_sizes=[16],
        num_speculative_tokens=0,
    )
    return MultiGroupBlockTable(**config)
| 58 | + |
| 59 | + |
def test_prepare_inputs_for_cp_prefill(monkeypatch):
    """Pure prefill: all 8 prompt tokens scheduled, none computed yet.

    With CP world size 2 and rank 0 (see the autouse fixture), the 8
    scheduled tokens shard into two slices per rank; rank 0 ends up
    owning positions [0, 1, 6, 7].
    """
    # Setup: one fresh request with an 8-token prompt and no output yet.
    id = 0
    prefill_len = 8
    decode_len = 0
    num_computed_tokens = [0]
    num_scheduled_tokens_ = prefill_len

    req_state = CachedRequestState(
        req_id="req" + str(id),
        prompt_token_ids=list(range(prefill_len)),
        prompt_embeds=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        generator=None,
        block_ids=([0], ),
        num_computed_tokens=num_computed_tokens,
        output_token_ids=list(range(decode_len)),
        lora_request=None,
    )
    num_scheduled_tokens = {req_state.req_id: num_scheduled_tokens_}
    req_ids = [req_state.req_id]
    requests = {req_state.req_id: req_state}

    # Reuse the module-level factory instead of duplicating the ctor call.
    block_table = create_block_table()

    positions_np = np.zeros(64, dtype=np.int64)
    computed_positions_np = np.zeros(64, dtype=np.int64)
    arange_np = np.arange(64, dtype=np.int64)
    padding_loc = -1

    # Run
    num_sched_local, num_comp_local, q_seqlens_sharded = prepare_inputs_for_cp(
        num_scheduled_tokens, requests, req_ids, block_table, positions_np,
        computed_positions_np, arange_np, padding_loc)

    # Check
    cp_shard_size, _ = cp_get_shard_size(num_scheduled_tokens_)
    assert num_sched_local == [2 * cp_shard_size]
    assert num_comp_local == [0]
    assert q_seqlens_sharded == [[cp_shard_size, cp_shard_size]]
    assert np.all(
        positions_np[:sum(num_sched_local)] == np.array([0, 1, 6, 7]))
    # num_comp_local is asserted to be [0] above, so a fresh prefill has
    # no computed positions to verify (the old guarded check was dead code).
| 113 | + |
| 114 | + |
def test_prepare_inputs_for_cp_decode(monkeypatch):
    """Single decode step: 1 token scheduled on top of 10 computed.

    With CP world size 2 and rank 0, the decode token is not sharded
    (num_sched_local == [1]) while the 10 already-computed tokens split
    evenly, leaving 5 on this rank.
    """
    # Setup: request with an 8-token prompt and 2 generated tokens.
    id = 0
    prefill_len = 8
    decode_len = 2
    num_computed_tokens = [0, 4, 8, 9, 10]
    num_scheduled_tokens_ = 1

    req_state = CachedRequestState(
        req_id="req" + str(id),
        prompt_token_ids=list(range(prefill_len)),
        prompt_embeds=None,
        mm_features=[],
        sampling_params=None,
        pooling_params=None,
        generator=None,
        block_ids=([0], ),
        num_computed_tokens=num_computed_tokens,
        output_token_ids=list(range(decode_len)),
        lora_request=None,
    )
    num_scheduled_tokens = {req_state.req_id: num_scheduled_tokens_}
    req_ids = [req_state.req_id]
    requests = {req_state.req_id: req_state}

    # Reuse the module-level factory instead of duplicating the ctor call.
    block_table = create_block_table()

    positions_np = np.zeros(64, dtype=np.int64)
    computed_positions_np = np.zeros(64, dtype=np.int64)
    arange_np = np.arange(64, dtype=np.int64)
    padding_loc = -1

    # Run
    num_sched_local, num_comp_local, q_seqlens_sharded = prepare_inputs_for_cp(
        num_scheduled_tokens, requests, req_ids, block_table, positions_np,
        computed_positions_np, arange_np, padding_loc)

    # Check
    assert num_sched_local == [1]
    # Half of the 10 computed tokens land on this rank.
    assert num_comp_local == [num_computed_tokens[-1] // 2]
    assert q_seqlens_sharded == [[1]]
    assert np.all(positions_np[:num_sched_local[0]] == np.array([10]))
    if sum(num_comp_local) > 0:
        assert np.all(computed_positions_np[:num_comp_local[0]] == np.array(
            [0, 3, 4, 7, 8]))
| 166 | + |
| 167 | + |
def test_prepare_inputs_for_cp_multiple_requests(monkeypatch):
    """Mixed batch: one decode request plus one chunked-prefill request.

    Request 0 schedules 1 decode token (10 computed); request 1 schedules
    the second 8-token chunk of a 16-token prompt (8 computed).
    """
    # Setup
    prefill_lens = [8, 16]
    decode_lens = [2, 0]
    num_computed_tokens = [[0, 4, 8, 9, 10], [0, 8]]
    num_scheduled_tokens_ = [1, 8]

    num_scheduled_tokens = {}
    requests = {}
    req_ids = []
    for i in range(2):
        req_state = CachedRequestState(
            req_id="req" + str(i),
            prompt_token_ids=list(range(prefill_lens[i])),
            prompt_embeds=None,
            mm_features=[],
            sampling_params=None,
            pooling_params=None,
            generator=None,
            block_ids=([0], ),
            num_computed_tokens=num_computed_tokens[i],
            output_token_ids=list(range(decode_lens[i])),
            lora_request=None,
        )
        num_scheduled_tokens[req_state.req_id] = num_scheduled_tokens_[i]
        req_ids.append(req_state.req_id)
        requests[req_state.req_id] = req_state

    # Reuse the module-level factory instead of duplicating the ctor call.
    block_table = create_block_table()

    positions_np = np.zeros(64, dtype=np.int64)
    computed_positions_np = np.zeros(64, dtype=np.int64)
    arange_np = np.arange(64, dtype=np.int64)
    padding_loc = -1

    # Run
    num_sched_local, num_comp_local, q_seqlens_sharded = prepare_inputs_for_cp(
        num_scheduled_tokens, requests, req_ids, block_table, positions_np,
        computed_positions_np, arange_np, padding_loc)

    # Check
    assert num_sched_local == [1, 4]
    # BUG FIX: the second expected element was spuriously wrapped in a
    # nested list ([num_computed_tokens[1][-1] // 2]); num_comp_local is a
    # flat list of ints (sum() is applied to it below), so compare flat.
    assert num_comp_local == [
        num_computed_tokens[0][-1] // 2, num_computed_tokens[1][-1] // 2
    ]
    assert q_seqlens_sharded == [[1], [2, 2]]
    assert np.all(positions_np[:num_sched_local[0]] == np.array([10]))
    assert np.all(positions_np[num_sched_local[0]:num_sched_local[0] +
                               num_sched_local[1]] == np.array([8, 9, 14, 15]))
    if sum(num_comp_local) > 0:
        assert np.all(computed_positions_np[:num_comp_local[0]] == np.array(
            [0, 3, 4, 7, 8]))
        assert np.all(
            computed_positions_np[num_comp_local[0]:num_comp_local[0] +
                                  num_comp_local[1]] == np.array([0, 1, 6, 7]))
0 commit comments