 import pytest
 import torch
 
+from typing import Tuple
+
 from vllm._C import cache_ops
 
+COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
@@ -149,3 +152,68 @@ def test_reshape_and_cache(
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    src_device = (f"{direction[0]}:{device}"
+                  if direction[0] == "cuda" else direction[0])
+    dst_device = (f"{direction[1]}:{device}"
+                  if direction[1] == "cuda" else direction[1])
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # When swapping within the same device, the source and destination
+    # blocks must not overlap.
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+
+    block_mapping = dict(zip(src_blocks, dst_blocks))
+
+    # Create the KV caches on the source device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        src_device)
+
+    # Create the KV caches on the destination device.
+    dist_key_caches, dist_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        dst_device)
+
+    # Keep a copy of the source caches to compare against after the swap.
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
+                          block_mapping)
+    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                          block_mapping)
+
+    # Each mapped source block should now appear at the destination index.
+    for src, dst in block_mapping.items():
+        assert torch.allclose(src_key_caches_clone[src].cpu(),
+                              dist_key_caches[0][dst].cpu())
+        assert torch.allclose(src_value_caches_clone[src].cpu(),
+                              dist_value_caches[0][dst].cpu())
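
Note: the test depends on a `kv_cache_factory` pytest fixture defined outside this diff (in the test suite's conftest). The sketch below only illustrates the call signature the test assumes, namely per-layer lists of key and value cache tensors allocated on the requested device; the function body, the tensor layout, and the shapes are assumptions for illustration, not the actual fixture.

```python
# Hypothetical sketch of the kv_cache_factory fixture this test assumes.
# The real fixture lives in the test suite's conftest; the cache layout
# below is illustrative only and may differ from the kernels' layout.
from typing import List, Tuple

import torch


def kv_cache_factory(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    # Seeding makes the source and destination caches reproducible, so the
    # post-swap comparison is deterministic across parametrized runs.
    torch.random.manual_seed(seed)

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        # Assumed layout: (num_blocks, num_heads, head_size, block_size).
        # Generate in float32 on CPU, then cast/move, so half/bfloat16
        # CPU caches also work on older PyTorch versions.
        shape = (num_blocks, num_heads, head_size, block_size)
        key_caches.append(torch.randn(shape).to(dtype=dtype, device=device))
        value_caches.append(torch.randn(shape).to(dtype=dtype, device=device))
    return key_caches, value_caches
```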