# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.attention.ops.nki_flash_attn import reshape_and_cache


 | 9 | +@pytest.mark.parametrize(  | 
 | 10 | +    "num_tokens, n_kv_head, d_head, num_blocks, block_size",  | 
 | 11 | +    [  | 
 | 12 | +        # Small model configuration (e.g., GPT-2 small)  | 
 | 13 | +        (32, 12, 64, 4, 128),  # Typical sequence processing  | 
 | 14 | +        (1, 12, 64, 4, 128),  # Single token update  | 
 | 15 | +        (128, 12, 64, 4, 128),  # Longer sequence  | 
 | 16 | +
  | 
 | 17 | +        # Medium model configuration (e.g., GPT-2 medium)  | 
 | 18 | +        (64, 16, 96, 8, 256),  # Standard batch  | 
 | 19 | +        (256, 16, 96, 8, 256),  # Large batch  | 
 | 20 | +
  | 
 | 21 | +        # Large model configuration (e.g., GPT-3 style)  | 
 | 22 | +        (48, 32, 128, 16, 512),  # Typical processing window  | 
 | 23 | +        (512, 32, 128, 16, 512),  # Full context window  | 
 | 24 | +
  | 
 | 25 | +        # Edge cases and stress tests  | 
 | 26 | +        (1024, 8, 32, 32, 32),  # Many tokens, small heads  | 
 | 27 | +        (16, 64, 256, 4, 64),  # Few tokens, many heads  | 
 | 28 | +        (2048, 24, 128, 64, 128),  # Large scale test  | 
 | 29 | +
  | 
 | 30 | +        # Minimal configurations for debugging  | 
 | 31 | +        (4, 2, 16, 2, 16),  # Tiny test case  | 
 | 32 | +        (1, 1, 8, 1, 8),  # Minimal possible  | 
 | 33 | +    ])  | 
 | 34 | +def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,  | 
 | 35 | +                           block_size):  | 
 | 36 | +    # Set random seed for reproducibility  | 
 | 37 | +    torch.manual_seed(42)  | 
 | 38 | + | 
 | 39 | +    # Create CPU tensors for reference implementation  | 
 | 40 | +    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(  | 
 | 41 | +        torch.tensor(d_head))  | 
 | 42 | +    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(  | 
 | 43 | +        torch.tensor(d_head))  | 
 | 44 | +    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)  | 
 | 45 | +    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)  | 
 | 46 | +    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]  | 
 | 47 | + | 
 | 48 | +    # Run reference implementation on CPU  | 
 | 49 | +    block_indices = torch.div(slot_mapping_cpu,  | 
 | 50 | +                              block_size,  | 
 | 51 | +                              rounding_mode="floor")  | 
 | 52 | +    block_offsets = slot_mapping_cpu % block_size  | 
 | 53 | + | 
 | 54 | +    for i in range(num_tokens):  | 
 | 55 | +        block_idx = block_indices[i]  | 
 | 56 | +        block_offset = block_offsets[i]  | 
 | 57 | +        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]  | 
 | 58 | +        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]  | 
 | 59 | + | 
 | 60 | +    # Create XLA device tensors  | 
 | 61 | +    device = torch.device('xla')  | 
 | 62 | +    key = key_cpu.to(device)  | 
 | 63 | +    value = value_cpu.to(device)  | 
 | 64 | +    key_cache = torch.zeros_like(key_cache_cpu, device=device)  | 
 | 65 | +    value_cache = torch.zeros_like(value_cache_cpu, device=device)  | 
 | 66 | +    slot_mapping = slot_mapping_cpu.to(device)  | 
 | 67 | + | 
 | 68 | +    # Run vectorized implementation on XLA device  | 
 | 69 | +    reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)  | 
 | 70 | + | 
 | 71 | +    # Move results back to CPU for comparison  | 
 | 72 | +    key_cache_result = key_cache.cpu()  | 
 | 73 | +    value_cache_result = value_cache.cpu()  | 
 | 74 | + | 
 | 75 | +    # Assert results match  | 
 | 76 | +    torch.testing.assert_close(key_cache_result,  | 
 | 77 | +                               key_cache_cpu,  | 
 | 78 | +                               rtol=1e-5,  | 
 | 79 | +                               atol=1e-5)  | 
 | 80 | +    torch.testing.assert_close(value_cache_result,  | 
 | 81 | +                               value_cache_cpu,  | 
 | 82 | +                               rtol=1e-5,  | 
 | 83 | +                               atol=1e-5)  | 
0 commit comments