# SPDX-License-Identifier: Apache-2.0
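"""Tests for the Triton unified attention kernel.

The kernel output is compared against a pure-PyTorch paged-attention
reference over varlen batches, GQA head layouts, sliding windows, logit
soft caps, and fp8 KV-cache quantization.
"""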

from typing import Optional

import pytest
import torch

from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]

DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# One value large enough to test overflow in index calculation,
# one value small enough to test the schema op check.
NUM_BLOCKS = [32768, 2048]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
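    """Quadratic-time, pure-PyTorch reference for paged attention.

    Gathers each sequence's KV pages via block_tables, applies causal
    masking (optionally with a sliding window and a logit soft cap), and
    returns the attention outputs concatenated over all sequences.
    """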
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        # Scale out of place so the caller's query tensor is not mutated.
        q = query[start_idx:start_idx + query_len] * scale

        # Gather this sequence's KV pages and trim the tail of the last page.
        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        # Expand KV heads to match the query heads (GQA/MQA).
        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        # Causal mask: query i may attend to kv positions
        # <= kv_len - query_len + i.
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            # Additionally mask keys older than the sliding window.
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None and soft_cap > 0:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("seq_lens",
                         [[(1, 1328), (5, 18),
                           (129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_triton_unified_attn(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    q_dtype: Optional[torch.dtype],
) -> None:
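    """Compare the Triton unified attention kernel with ref_paged_attn.

    Builds a random paged KV cache and a varlen query batch, runs the
    kernel (optionally with fp8-quantized Q/KV and per-head descale
    factors), and checks its output against the unquantized reference
    within dtype-dependent tolerances.
    """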
    torch.set_default_device("cuda")

    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    # Keep kv_lens as a Python list for ref_paged_attn; the kernel takes
    # a device tensor.
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = torch.empty_like(query)

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    q_descale = None
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for an fp8 scaling factor.
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        q_descale = None  # Not yet supported
        k_descale = torch.rand(scale_shape, dtype=torch.float32)
        v_descale = torch.rand(scale_shape, dtype=torch.float32)

    unified_attention(
        q=maybe_quantized_query,
        k=maybe_quantized_key_cache,
        v=maybe_quantized_value_cache,
        out=output,
        cu_seqlens_q=cu_query_lens,
        seqused_k=kv_lens_tensor,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
        block_table=block_tables,
        softcap=soft_cap if soft_cap is not None else 0,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    atol, rtol = 1.5e-2, 1e-2
    if q_dtype is not None:
        # fp8 quantization is lossy; compare with looser tolerances.
        atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(
        output, ref_output, atol=atol, rtol=rtol,
        msg=lambda default_msg: f"{default_msg}\nmax abs diff: "
        f"{torch.max(torch.abs(output - ref_output))}")