
Commit 951d69c

add swap_blocks unit tests vllm-project#2583
1 parent 5265631 commit 951d69c

File tree

1 file changed: +68 -0 lines changed


tests/kernels/test_cache.py

Lines changed: 68 additions & 0 deletions
@@ -3,8 +3,11 @@
 import pytest
 import torch
 
+from typing import Tuple
+
 from vllm._C import cache_ops
 
+COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
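Note: besides COPYING_DIRECTION above, the new test's parametrization pulls in NUM_MAPPINGS, NUM_HEADS, HEAD_SIZES, BLOCK_SIZES, NUM_BLOCKS, SEEDS, and DEVICES (and it calls into the random module), all of which are defined or imported earlier in tests/kernels/test_cache.py and therefore do not appear in this diff. For orientation only, such lists typically look like the sketch below; the values here are placeholders, not the ones in the file:

# Placeholder sketch of the module-level parameter lists the new test uses.
# These live elsewhere in tests/kernels/test_cache.py; values are illustrative.
NUM_MAPPINGS = [256]      # how many src -> dst block pairs to exercise
NUM_HEADS = [8]           # attention heads in each KV cache
HEAD_SIZES = [64, 128]    # per-head dimension
BLOCK_SIZES = [8, 16]     # tokens per cache block
NUM_BLOCKS = [1024]       # blocks per cache tensor
SEEDS = [0]               # RNG seeds
DEVICES = [0]             # CUDA device index, used as f"cuda:{device}"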
@@ -149,3 +152,68 @@ def test_reshape_and_cache(
 
     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks(
+    kv_cache_factory,
+    direction: Tuple[str, str],
+    num_mappings: int,
+    num_heads: int,
+    head_size: int,
+    block_size: int,
+    num_blocks: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    src_device = (f"{direction[0]}:{device}"
+                  if direction[0] == "cuda" else direction[0])
+    dst_device = (f"{direction[1]}:{device}"
+                  if direction[1] == "cuda" else direction[1])
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    # For same-device swaps, the src and dst blocks must not overlap.
+    if src_device == dst_device:
+        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining_blocks, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+
+    block_mapping = dict(zip(src_blocks, dst_blocks))
+
+    # Create the KV caches on the first device.
+    src_key_caches, src_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        src_device)
+
+    # Create the KV caches on the second device.
+    dist_key_caches, dist_value_caches = kv_cache_factory(
+        num_blocks, block_size, 1, num_heads, head_size, dtype, seed,
+        dst_device)
+
+    src_key_caches_clone = src_key_caches[0].clone()
+    src_value_caches_clone = src_value_caches[0].clone()
+
+    # Call the swap_blocks kernel.
+    cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+    cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                          block_mapping)
+
+    for src, dst in block_mapping.items():
+        assert torch.allclose(src_key_caches_clone[src].cpu(),
+                              dist_key_caches[0][dst].cpu())
+        assert torch.allclose(src_value_caches_clone[src].cpu(),
+                              dist_value_caches[0][dst].cpu())
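Read off the final assertions, the contract this test pins down is: for every (src, dst) pair in block_mapping, block dst of the destination cache must match block src of the source cache (to allclose tolerance), even when the two caches live on different devices. A minimal PyTorch sketch of that expected behavior follows; it is a reference model for what the test checks, not vLLM's CUDA kernel, and the name swap_blocks_reference is hypothetical:

import torch
from typing import Dict

def swap_blocks_reference(src_cache: torch.Tensor,
                          dst_cache: torch.Tensor,
                          block_mapping: Dict[int, int]) -> None:
    """Model of what the test expects from cache_ops.swap_blocks."""
    for src, dst in block_mapping.items():
        # Tensor.copy_() handles cross-device transfers, covering the
        # ('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda') directions.
        dst_cache[dst].copy_(src_cache[src])

This one-directional copy view also explains the non-overlap guard in the test: if a same-device block could appear as both a source and a destination, the result would depend on the order in which the copies were applied.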

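To run only the new test, standard pytest selection should work from the repository root (assuming a CUDA-capable machine, since every parametrized direction involves "cuda"):

pytest tests/kernels/test_cache.py -k test_swap_blocks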