-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: expose pytorch api for block sparse attention (#375)
The block sparse attention (for any block size (R, C)) are hidden in flashinfer's codebase but it was never exposed explicitly in python. As requested in #367 , this PR implements the PyTorch APIs for block sparse attention, accordingly to our experiments, it can greatly accelerate attention computation with low density (10x for Tree Attention in Sequoia).
- Loading branch information
Showing
6 changed files
with
419 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
.. _apisparse: | ||
|
||
flashinfer.sparse | ||
================= | ||
|
||
Kernels for block sparse flashattention. | ||
|
||
.. currentmodule:: flashinfer.sparse | ||
|
||
.. autoclass:: BlockSparseAttentionWrapper | ||
:members: | ||
|
||
.. automethod:: __init__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import math | ||
from typing import Optional | ||
import torch | ||
import logging | ||
from .prefill import _compute_page_qk_indptr | ||
from .quantization import segment_packbits | ||
from .utils import ( | ||
check_pos_encoding_mode, | ||
check_kv_layout, | ||
is_float8, | ||
expand_5d, | ||
PosEncodingMode, | ||
TensorLayout, | ||
) | ||
|
||
try: | ||
from . import _kernels | ||
except ImportError as e: | ||
import os | ||
import logging | ||
|
||
if os.environ.get("BUILD_DOC", "0") == "1": | ||
_kernels = None | ||
logging.warning("Kernels are not loaded in documentation build mode.") | ||
else: | ||
raise e | ||
|
||
|
||
class BlockSparseAttentionWrapper: | ||
def __init__( | ||
self, | ||
workspace_buffer: torch.Tensor, | ||
kv_layout: str = "NHD", | ||
): | ||
r"""Constructs of :class:`BlockSparseAttentionWrapper`. | ||
Warning(Zihao): this is an experimental API and subject to change. | ||
Parameters | ||
---------- | ||
workspace_buffer : torch.Tensor | ||
The user reserved workspace buffer used to store auxiliary data structures, | ||
recommended size is 128MB, the device of the workspace buffer should be the | ||
same as the device of the input tensors. | ||
kv_layout : str | ||
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. | ||
""" | ||
check_kv_layout(kv_layout) | ||
self._kv_layout = kv_layout | ||
self._workspace_buffer = workspace_buffer | ||
self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( | ||
TensorLayout[kv_layout].value, | ||
False, # use_cuda_graph | ||
) | ||
|
||
def begin_forward( | ||
self, | ||
indptr: torch.Tensor, | ||
indices: torch.Tensor, | ||
M: int, | ||
N: int, | ||
R: int, | ||
C: int, | ||
num_qo_heads: int, | ||
num_kv_heads: int, | ||
head_dim: int, | ||
mask: Optional[torch.Tensor] = None, | ||
packed_mask: Optional[torch.Tensor] = None, | ||
q_data_type: str = "float16", | ||
): | ||
r"""Create auxiliary data structures for block sparse attention. | ||
Parameters | ||
---------- | ||
indptr : torch.Tensor | ||
The indptr of the block-sparse matrix, shape (MB + 1,), where MB is the number of blocks in the row dimension. | ||
indices: torch.Tensor | ||
The indices of the block-sparse matrix, shape (nnz,), where nnz is the number of non-zero blocks. | ||
M : int | ||
The number of rows of the block-sparse matrix, MB = ceil_div(M, R). | ||
N : int | ||
The number of columns of the block-sparse matrix, NB = ceil_div(N, C). | ||
R : int | ||
The number of rows in each block. | ||
C : int | ||
The number of columns in each block. | ||
num_qo_heads : int | ||
The number of heads in the query/output tensor. | ||
num_kv_heads : int | ||
The number of heads in the key/value tensor. | ||
head_dim : int | ||
The dimension of each head. | ||
mask : torch.Tensor, optional | ||
The flattened mask tensor, shape (nnz * R * C,), where nnz is the number of non-zero blocks. | ||
If every block is full, then we don't need to provide the mask tensor. | ||
packed_mask : torch.Tensor, optional | ||
The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. | ||
The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. | ||
q_data_type : str, optional | ||
The data type of the query tensor. | ||
The :meth:`begin_forward` method should be called before any :meth:`forward` or | ||
:meth:`forward_return_lse` calls, auxiliary data structures will be created | ||
during this call and cached for multiple forward calls. | ||
The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` | ||
is not equal to ``num_kv_heads``, the function will use | ||
`grouped query attention <https://arxiv.org/abs/2305.13245>`_. | ||
""" | ||
num_rows = len(indptr) - 1 | ||
qo_indptr_host = R * torch.arange(num_rows + 1, dtype=torch.int32) | ||
qo_indptr_host[-1] = M | ||
self._qo_indptr = qo_indptr_host.to(indptr.device) | ||
row_empty = indptr[1:] == indptr[:1] | ||
if indices.max().item() * C > N: | ||
raise ValueError("indices out of bound") | ||
last_block_pos = indices[torch.clamp(indptr[1:], min=1) - 1] | ||
last_block_pos.masked_fill_(row_empty, 0) | ||
last_block_len = torch.clamp(N - last_block_pos * C, max=C) | ||
|
||
if mask is not None or packed_mask is not None: | ||
qk_indptr = _compute_page_qk_indptr( | ||
self._qo_indptr, | ||
indptr, # paged_kv_indptr | ||
last_block_len, # paged_kv_last_page_len | ||
C, # page_size | ||
) | ||
if packed_mask is None and mask is not None: | ||
# create packed mask from mask | ||
packed_mask, qk_indptr = segment_packbits( | ||
mask.contiguous().view(-1), qk_indptr, bitorder="little" | ||
) | ||
|
||
self._paged_kv_indptr_buf = indptr | ||
self._paged_kv_indices_buf = indices | ||
self._paged_kv_last_page_len = last_block_len | ||
if packed_mask is not None: | ||
self._packed_mask_buf = packed_mask | ||
self._qk_indptr_buf = qk_indptr | ||
else: | ||
self._packed_mask_buf = None | ||
|
||
empty_q_data = torch.empty( | ||
0, | ||
dtype=( | ||
getattr(torch, q_data_type) | ||
if isinstance(q_data_type, str) | ||
else q_data_type | ||
), | ||
) | ||
|
||
self._wrapper.begin_forward( | ||
self._workspace_buffer, | ||
self._qo_indptr, | ||
self._paged_kv_indptr_buf, | ||
num_rows, | ||
num_qo_heads, | ||
num_kv_heads, | ||
head_dim, | ||
C, | ||
empty_q_data, | ||
) | ||
|
||
def end_forward(self): | ||
r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" | ||
self._qo_indptr = None | ||
self._paged_kv_indptr_buf = None | ||
self._paged_kv_indices_buf = None | ||
self._paged_kv_last_page_len = None | ||
self._packed_mask_buf = None | ||
self._qk_indptr_buf = None | ||
|
||
def forward( | ||
self, | ||
q: torch.Tensor, | ||
kv_data: torch.Tensor, | ||
pos_encoding_mode: str = "NONE", | ||
allow_fp16_qk_reduction: bool = False, | ||
logits_soft_cap: Optional[float] = None, | ||
sm_scale: Optional[float] = None, | ||
rope_scale: Optional[float] = None, | ||
rope_theta: Optional[float] = None, | ||
): | ||
r"""Compute block-sparse attention between Q/K/V tensors. | ||
Warning(Zihao): in the next release, kv_data will be decoupled into standalone k/v tensors, each | ||
with shape (N, num_kv_heads, head_dim). | ||
Parameters | ||
---------- | ||
q : torch.Tensor | ||
The query tensor, shape (M, num_qo_heads, head_dim). | ||
kv_data : torch.Tensor | ||
The key/value tensor, shape (N // C, 2, C, num_kv_heads, head_dim). | ||
pos_encoding_mode : str, optional | ||
The position encoding applied inside attention kernels, could be | ||
``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. | ||
Default is ``NONE``. | ||
allow_fp16_qk_reduction : bool | ||
Whether to use f16 for qk reduction (faster at the cost of slight precision | ||
loss). | ||
logits_soft_cap : Optional[float] | ||
The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not | ||
provided, will be set to ``0``. If greater than 0, the logits will be capped according to | ||
formula: | ||
:math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, | ||
where :math:`x` is the input logits. | ||
sm_scale : Optional[float] | ||
The scale used in softmax, if not provided, will be set to | ||
``1.0 / sqrt(head_dim)``. | ||
rope_scale : Optional[float] | ||
The scale used in RoPE interpolation, if not provided, will be set to | ||
``1.0``. | ||
rope_theta : Optional[float] | ||
The theta used in RoPE, if not provided, will be set to ``1e4``. | ||
Returns | ||
------- | ||
torch.Tensor | ||
The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. | ||
""" | ||
check_pos_encoding_mode(pos_encoding_mode) | ||
if logits_soft_cap is None: | ||
logits_soft_cap = 0.0 | ||
if sm_scale is None: | ||
sm_scale = 1.0 / math.sqrt(q.size(-1)) | ||
if rope_scale is None: | ||
rope_scale = 1.0 | ||
if rope_theta is None: | ||
rope_theta = 1e4 | ||
if is_float8(q): | ||
logging.warning( | ||
"Our current prefill kernel implementation needs f16 input, the f8 inputs " | ||
" are casted to f16, which could result in performance degradation." | ||
) | ||
q = q.to(torch.float16) | ||
kv_data = kv_data.to(torch.float16) | ||
|
||
kv_data = expand_5d(kv_data, self._kv_layout) | ||
|
||
if self._packed_mask_buf is None: | ||
return self._wrapper.forward( | ||
q, | ||
self._qo_indptr, | ||
kv_data, | ||
self._paged_kv_indptr_buf, | ||
self._paged_kv_indices_buf, | ||
self._paged_kv_last_page_len, | ||
False, # causal | ||
PosEncodingMode[pos_encoding_mode].value, | ||
allow_fp16_qk_reduction, | ||
logits_soft_cap, | ||
sm_scale, | ||
rope_scale, | ||
rope_theta, | ||
False, # return LSE | ||
)[0] | ||
else: | ||
return self._wrapper.forward_custom_mask( | ||
q, | ||
self._qo_indptr, | ||
kv_data, | ||
self._paged_kv_indptr_buf, | ||
self._paged_kv_indices_buf, | ||
self._paged_kv_last_page_len, | ||
self._packed_mask_buf, | ||
self._qk_indptr_buf, | ||
PosEncodingMode[pos_encoding_mode].value, | ||
allow_fp16_qk_reduction, | ||
logits_soft_cap, | ||
sm_scale, | ||
rope_scale, | ||
rope_theta, | ||
False, # return LSE | ||
)[0] |
Oops, something went wrong.