Commit c91b64f

[neuron] add reshape_and_cache (#14391)
1 parent d612317 commit c91b64f

File tree

tests/neuron/test_cache.py
vllm/attention/ops/nki_flash_attn.py

2 files changed: +126 -0 lines changed

tests/neuron/test_cache.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.attention.ops.nki_flash_attn import reshape_and_cache


@pytest.mark.parametrize(
    "num_tokens, n_kv_head, d_head, num_blocks, block_size",
    [
        # Small model configuration (e.g., GPT-2 small)
        (32, 12, 64, 4, 128),  # Typical sequence processing
        (1, 12, 64, 4, 128),  # Single token update
        (128, 12, 64, 4, 128),  # Longer sequence

        # Medium model configuration (e.g., GPT-2 medium)
        (64, 16, 96, 8, 256),  # Standard batch
        (256, 16, 96, 8, 256),  # Large batch

        # Large model configuration (e.g., GPT-3 style)
        (48, 32, 128, 16, 512),  # Typical processing window
        (512, 32, 128, 16, 512),  # Full context window

        # Edge cases and stress tests
        (1024, 8, 32, 32, 32),  # Many tokens, small heads
        (16, 64, 256, 4, 64),  # Few tokens, many heads
        (2048, 24, 128, 64, 128),  # Large scale test

        # Minimal configurations for debugging
        (4, 2, 16, 2, 16),  # Tiny test case
        (1, 1, 8, 1, 8),  # Minimal possible
    ])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
                           block_size):
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create CPU tensors for reference implementation
    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Run reference implementation on CPU
    block_indices = torch.div(slot_mapping_cpu,
                              block_size,
                              rounding_mode="floor")
    block_offsets = slot_mapping_cpu % block_size

    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]

    # Create XLA device tensors
    device = torch.device('xla')
    key = key_cpu.to(device)
    value = value_cpu.to(device)
    key_cache = torch.zeros_like(key_cache_cpu, device=device)
    value_cache = torch.zeros_like(value_cache_cpu, device=device)
    slot_mapping = slot_mapping_cpu.to(device)

    # Run vectorized implementation on XLA device
    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)

    # Move results back to CPU for comparison
    key_cache_result = key_cache.cpu()
    value_cache_result = value_cache.cpu()

    # Assert results match
    torch.testing.assert_close(key_cache_result,
                               key_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
    torch.testing.assert_close(value_cache_result,
                               value_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
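
The reference loop above relies on the slot-mapping convention: each entry of slot_mapping is a flat slot index that decomposes into a (block, offset) pair via floor division and modulo. A minimal standalone sketch of that decomposition (not part of the commit; values chosen for illustration):

# Sketch: decomposing flat slot indices into (block, offset) pairs
# with block_size = 128, exactly as the test's reference loop does.
import torch

block_size = 128
slot_mapping = torch.tensor([0, 129, 300])

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

print(block_indices.tolist())  # [0, 1, 2] -> slot 300 falls in block 2
print(block_offsets.tolist())  # [0, 1, 44] -> at offset 44 within block 2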

vllm/attention/ops/nki_flash_attn.py

Lines changed: 43 additions & 0 deletions
@@ -869,3 +869,46 @@ def flash_attn_varlen_nkifunc(

    o = flash_paged_attention[1, n_kv_head](**kwargs)
    return o


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Writes key-value pairs to the KV cache at specified positions.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        key_cache (torch.Tensor): Key cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
        value_cache (torch.Tensor): Value cache tensor with shape
            (num_blocks, n_kv_head, block_size, d_head)
        slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
            with shape (num_tokens)

    Returns:
        None: Updates the key_cache and value_cache tensors in-place
    """
    block_size = key_cache.size(2)

    # Calculate indices with explicit floor division
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size

    # Update caches using index_put_
    key_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(key_cache.size(1),
                      device=key.device), block_offsets.unsqueeze(1)), key)

    value_cache.index_put_(
        (block_indices.unsqueeze(1),
         torch.arange(value_cache.size(1),
                      device=value.device), block_offsets.unsqueeze(1)), value)
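
The index_put_ calls work by index broadcasting: the (num_tokens, 1) block and offset indices broadcast against the (n_kv_head,) arange to address every head of every token in one shot, which is what makes the write a single vectorized op rather than a per-token loop. A minimal CPU-only sketch (not part of the commit; shapes follow the docstring) verifying that equivalence:

# Sketch: index_put_ broadcast write vs. the per-token loop from the test.
import torch

num_tokens, n_kv_head, d_head = 4, 2, 8
num_blocks, block_size = 2, 4

key = torch.randn(num_tokens, n_kv_head, d_head)
key_cache = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
slot_mapping = torch.tensor([5, 0, 3, 6])  # distinct flat slot indices

block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size

# Vectorized write: the three index tensors broadcast to (num_tokens, n_kv_head).
key_cache.index_put_(
    (block_indices.unsqueeze(1),
     torch.arange(n_kv_head),
     block_offsets.unsqueeze(1)), key)

# Per-token reference write, as in the test above.
expected = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
for i in range(num_tokens):
    expected[block_indices[i], :, block_offsets[i], :] = key[i]

torch.testing.assert_close(key_cache, expected)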
