diff --git a/docs/api/python/sparse.rst b/docs/api/python/sparse.rst
new file mode 100644
index 00000000..f82191ed
--- /dev/null
+++ b/docs/api/python/sparse.rst
@@ -0,0 +1,13 @@
+.. _apisparse:
+
+flashinfer.sparse
+=================
+
+Kernels for block-sparse FlashAttention.
+
+.. currentmodule:: flashinfer.sparse
+
+.. autoclass:: BlockSparseAttentionWrapper
+    :members:
+
+    .. automethod:: __init__
\ No newline at end of file
diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 79ecd83e..1072af10 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -751,15 +751,16 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, uint32_t* v_smem_o
   *v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_in;
 }
 
-template <uint32_t num_frags_x, uint32_t num_frags_y>
-__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], float (*d)[2]) {
+template <uint32_t num_frags_x, uint32_t num_frags_y, typename DTypeQKAccum>
+__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], DTypeQKAccum (*m)[2],
+                                            float (*d)[2]) {
   float d_rcp[num_frags_x][2];
   // compute reciprocal of d
 #pragma unroll
   for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
-      d_rcp[fx][j] = math::ptx_rcp(d[fx][j]);
+      d_rcp[fx][j] = (m[fx][j] != DTypeQKAccum(-5e4)) ? math::ptx_rcp(d[fx][j]) : 0.f;
     }
   }
 
@@ -1161,7 +1162,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
       o_frag, (float*)smem, m, d, warp_idx, lane_idx);
 
   // normalize d
-  normalize_d(o_frag, d);
+  normalize_d(o_frag, m, d);
 
   // write back
   write_o_reg_gmem(
@@ -1428,7 +1429,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
       o_frag, (float*)smem, m, d, warp_idx, lane_idx);
 
   // normalize d
-  normalize_d(o_frag, d);
+  normalize_d(o_frag, m, d);
 
   const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);
 
@@ -1719,7 +1720,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
       o_frag, (float*)smem, m, d, warp_idx, lane_idx);
 
   // normalize d
-  normalize_d(o_frag, d);
+  normalize_d(o_frag, m, d);
 
   const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);
 
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index d93097a9..9472afae 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -25,6 +25,7 @@
     BatchPrefillWithRaggedKVCacheWrapper,
     BatchPrefillWithPagedKVCacheWrapper,
 )
+from .sparse import BlockSparseAttentionWrapper
 from .cascade import (
     merge_state,
     merge_state_in_place,
diff --git a/python/flashinfer/sampling.py b/python/flashinfer/sampling.py
index 76131c4b..c3eda0b3 100644
--- a/python/flashinfer/sampling.py
+++ b/python/flashinfer/sampling.py
@@ -254,7 +254,7 @@ def top_k_top_p_sampling_from_probs(
     >>> samples
     tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
    >>> success
-    tensor([True, True, True, True], device='cuda:0') 
+    tensor([True, True, True, True], device='cuda:0')
 
     Notes
     -----
diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py
new file mode 100644
index 00000000..df92f2ab
--- /dev/null
+++ b/python/flashinfer/sparse.py
@@ -0,0 +1,292 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import math
+from typing import Optional
+import torch
+import logging
+from .prefill import _compute_page_qk_indptr
+from .quantization import segment_packbits
+from .utils import (
+    check_pos_encoding_mode,
+    check_kv_layout,
+    is_float8,
+    expand_5d,
+    PosEncodingMode,
+    TensorLayout,
+)
+
+try:
+    from . import _kernels
+except ImportError as e:
+    import os
+    import logging
+
+    if os.environ.get("BUILD_DOC", "0") == "1":
+        _kernels = None
+        logging.warning("Kernels are not loaded in documentation build mode.")
+    else:
+        raise e
+
+
+class BlockSparseAttentionWrapper:
+    def __init__(
+        self,
+        workspace_buffer: torch.Tensor,
+        kv_layout: str = "NHD",
+    ):
+        r"""Constructor of :class:`BlockSparseAttentionWrapper`.
+
+        Warning(Zihao): this is an experimental API and subject to change.
+
+        Parameters
+        ----------
+        workspace_buffer : torch.Tensor
+            The user reserved workspace buffer used to store auxiliary data structures;
+            recommended size is 128MB, and the device of the workspace buffer should be the
+            same as the device of the input tensors.
+
+        kv_layout : str
+            The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        """
+        check_kv_layout(kv_layout)
+        self._kv_layout = kv_layout
+        self._workspace_buffer = workspace_buffer
+        self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper(
+            TensorLayout[kv_layout].value,
+            False,  # use_cuda_graph
+        )
+
+    def begin_forward(
+        self,
+        indptr: torch.Tensor,
+        indices: torch.Tensor,
+        M: int,
+        N: int,
+        R: int,
+        C: int,
+        num_qo_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        mask: Optional[torch.Tensor] = None,
+        packed_mask: Optional[torch.Tensor] = None,
+        q_data_type: str = "float16",
+    ):
+        r"""Create auxiliary data structures for block sparse attention.
+
+        Parameters
+        ----------
+        indptr : torch.Tensor
+            The indptr of the block-sparse matrix, shape (MB + 1,), where MB is the number of blocks in the row dimension.
+        indices : torch.Tensor
+            The indices of the block-sparse matrix, shape (nnz,), where nnz is the number of non-zero blocks.
+        M : int
+            The number of rows of the block-sparse matrix, MB = ceil_div(M, R).
+        N : int
+            The number of columns of the block-sparse matrix, NB = ceil_div(N, C).
+        R : int
+            The number of rows in each block.
+        C : int
+            The number of columns in each block.
+        num_qo_heads : int
+            The number of heads in the query/output tensor.
+        num_kv_heads : int
+            The number of heads in the key/value tensor.
+        head_dim : int
+            The dimension of each head.
+        mask : torch.Tensor, optional
+            The flattened mask tensor, shape (nnz * R * C,), where nnz is the number of non-zero blocks.
+            If every block is full, then we don't need to provide the mask tensor.
+        packed_mask : torch.Tensor, optional
+            The 1D packed mask tensor; if provided, :attr:`mask` will be ignored.
+            The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`.
+        q_data_type : str, optional
+            The data type of the query tensor.
+
+        The :meth:`begin_forward` method should be called before any :meth:`forward` or
+        :meth:`forward_return_lse` calls; auxiliary data structures will be created
+        during this call and cached for multiple forward calls.
+
+        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
+        is not equal to ``num_kv_heads``, the function will use
+        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
+        """
+        num_rows = len(indptr) - 1
+        qo_indptr_host = R * torch.arange(num_rows + 1, dtype=torch.int32)
+        qo_indptr_host[-1] = M
+        self._qo_indptr = qo_indptr_host.to(indptr.device)
+        row_empty = indptr[1:] == indptr[:-1]
+        if indices.max().item() * C > N:
+            raise ValueError("indices out of bound")
+        last_block_pos = indices[torch.clamp(indptr[1:], min=1) - 1]
+        last_block_pos.masked_fill_(row_empty, 0)
+        last_block_len = torch.clamp(N - last_block_pos * C, max=C)
+
+        if mask is not None or packed_mask is not None:
+            qk_indptr = _compute_page_qk_indptr(
+                self._qo_indptr,
+                indptr,  # paged_kv_indptr
+                last_block_len,  # paged_kv_last_page_len
+                C,  # page_size
+            )
+        if packed_mask is None and mask is not None:
+            # create packed mask from mask
+            packed_mask, qk_indptr = segment_packbits(
+                mask.contiguous().view(-1), qk_indptr, bitorder="little"
+            )
+
+        self._paged_kv_indptr_buf = indptr
+        self._paged_kv_indices_buf = indices
+        self._paged_kv_last_page_len = last_block_len
+        if packed_mask is not None:
+            self._packed_mask_buf = packed_mask
+            self._qk_indptr_buf = qk_indptr
+        else:
+            self._packed_mask_buf = None
+
+        empty_q_data = torch.empty(
+            0,
+            dtype=(
+                getattr(torch, q_data_type)
+                if isinstance(q_data_type, str)
+                else q_data_type
+            ),
+        )
+
+        self._wrapper.begin_forward(
+            self._workspace_buffer,
+            self._qo_indptr,
+            self._paged_kv_indptr_buf,
+            num_rows,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            C,
+            empty_q_data,
+        )
+
+    def end_forward(self):
+        r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
+        self._qo_indptr = None
+        self._paged_kv_indptr_buf = None
+        self._paged_kv_indices_buf = None
+        self._paged_kv_last_page_len = None
+        self._packed_mask_buf = None
+        self._qk_indptr_buf = None
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        kv_data: torch.Tensor,
+        pos_encoding_mode: str = "NONE",
+        allow_fp16_qk_reduction: bool = False,
+        logits_soft_cap: Optional[float] = None,
+        sm_scale: Optional[float] = None,
+        rope_scale: Optional[float] = None,
+        rope_theta: Optional[float] = None,
+    ):
+        r"""Compute block-sparse attention between Q/K/V tensors.
+
+        Warning(Zihao): in the next release, kv_data will be decoupled into standalone k/v tensors, each
+        with shape (N, num_kv_heads, head_dim).
+
+        Parameters
+        ----------
+        q : torch.Tensor
+            The query tensor, shape (M, num_qo_heads, head_dim).
+        kv_data : torch.Tensor
+            The key/value tensor, shape (N // C, 2, C, num_kv_heads, head_dim).
+        pos_encoding_mode : str, optional
+            The position encoding applied inside attention kernels, could be
+            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
+            Default is ``NONE``.
+        allow_fp16_qk_reduction : bool
+            Whether to use f16 for qk reduction (faster at the cost of slight precision
+            loss).
+        logits_soft_cap : Optional[float]
+            The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not
+            provided, will be set to ``0``. If greater than 0, the logits will be capped according to
+            formula:
+            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
+            where :math:`x` is the input logits.
+        sm_scale : Optional[float]
+            The scale used in softmax, if not provided, will be set to
+            ``1.0 / sqrt(head_dim)``.
+        rope_scale : Optional[float]
+            The scale used in RoPE interpolation, if not provided, will be set to
+            ``1.0``.
+        rope_theta : Optional[float]
+            The theta used in RoPE, if not provided, will be set to ``1e4``.
+
+        Returns
+        -------
+        torch.Tensor
+            The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
+        """
+        check_pos_encoding_mode(pos_encoding_mode)
+        if logits_soft_cap is None:
+            logits_soft_cap = 0.0
+        if sm_scale is None:
+            sm_scale = 1.0 / math.sqrt(q.size(-1))
+        if rope_scale is None:
+            rope_scale = 1.0
+        if rope_theta is None:
+            rope_theta = 1e4
+        if is_float8(q):
+            logging.warning(
+                "Our current prefill kernel implementation needs f16 input, the f8 inputs "
+                "are cast to f16, which could result in performance degradation."
+            )
+            q = q.to(torch.float16)
+            kv_data = kv_data.to(torch.float16)
+
+        kv_data = expand_5d(kv_data, self._kv_layout)
+
+        if self._packed_mask_buf is None:
+            return self._wrapper.forward(
+                q,
+                self._qo_indptr,
+                kv_data,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                False,  # causal
+                PosEncodingMode[pos_encoding_mode].value,
+                allow_fp16_qk_reduction,
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                False,  # return LSE
+            )[0]
+        else:
+            return self._wrapper.forward_custom_mask(
+                q,
+                self._qo_indptr,
+                kv_data,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                self._packed_mask_buf,
+                self._qk_indptr_buf,
+                PosEncodingMode[pos_encoding_mode].value,
+                allow_fp16_qk_reduction,
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                False,  # return LSE
+            )[0]
diff --git a/python/tests/test_block_sparse.py b/python/tests/test_block_sparse.py
new file mode 100644
index 00000000..1bb9f919
--- /dev/null
+++ b/python/tests/test_block_sparse.py
@@ -0,0 +1,105 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import pytest +import torch +import numpy as np +import scipy as sp +import flashinfer + + +def bsr_attention_ref( + q, + kv, + indptr, + indices, + mask_data, +): + M = q.shape[0] + NB, _, C, H_KV, D = kv.shape + N = NB * C + bsr = sp.sparse.bsr_matrix( + (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()), + shape=(M, N), + ) + dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device) + k = kv[:, 0].reshape(-1, H_KV, D).contiguous() + v = kv[:, 1].reshape(-1, H_KV, D).contiguous() + o = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask) + return o + + +@pytest.mark.parametrize("R", [1, 4, 16]) +@pytest.mark.parametrize("C", [1, 4, 16]) +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("N", [64, 128, 256]) +@pytest.mark.parametrize("num_qo_heads", [1, 4, 16]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 16]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("mask_inside_block", [True, False]) +def test_block_sparse_attention( + R, C, M, N, num_qo_heads, num_kv_heads, head_dim, mask_inside_block +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be divisible by num_kv_heads") + rng = np.random.default_rng() + MB = M // R + NB = N // C + S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr() + indptr = torch.from_numpy(S.indptr).to(0) + indices = torch.from_numpy(S.indices).to(0) + nnz = S.nnz + if mask_inside_block: + data_mask = (torch.rand((nnz, R, C)) > 0.5).to(0) + else: + data_mask = torch.full((nnz, R, C), True, dtype=bool, device=0) + q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device=0) + kv_data = torch.randn( + (NB, 2, C, num_kv_heads, head_dim), dtype=torch.float16, device=0 + ) + + o_ref = bsr_attention_ref(q, kv_data, indptr, indices, data_mask) + workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=0) + sparse_attention_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer) + + if mask_inside_block: + mask_flashinfer_layout = torch.full((nnz * R * C,), False, dtype=bool, device=0) + for i in range(MB): + mask_flashinfer_layout[indptr[i] * R * C : indptr[i + 1] * R * C] = ( + data_mask[indptr[i] : indptr[i + 1]].transpose(0, 1).reshape(-1) + ) + + sparse_attention_wrapper.begin_forward( + indptr, + indices, + M, + N, + R, + C, + num_qo_heads, + num_kv_heads, + head_dim, + mask=mask_flashinfer_layout if mask_inside_block else None, + ) + + o = sparse_attention_wrapper.forward(q, kv_data) + sparse_attention_wrapper.end_forward() + print(o_ref, o) + np.testing.assert_allclose(o_ref.cpu(), o.cpu(), atol=1e-2, rtol=1e-3) + + +if __name__ == "__main__": + test_block_sparse_attention(1, 1, 64, 64, 1, 1, 128, True)