diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py
index 391e584d..f53261c7 100644
--- a/python/flashinfer/sparse.py
+++ b/python/flashinfer/sparse.py
@@ -18,6 +18,7 @@
 from typing import Optional, Union, Tuple
 import logging
 import torch
+from .decode import get_batch_decode_module
 from .prefill import _compute_page_qk_indptr, get_batch_prefill_module
 from .quantization import segment_packbits
 from .utils import (
@@ -299,31 +300,65 @@ def plan(
         kv_indptr_host = indptr.to("cpu", non_blocking=True)
 
-        self._cached_module = get_batch_prefill_module(
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            indptr.dtype,
-            head_dim,
-            PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
-            False,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            allow_fp16_qk_reduction,
-        )
+        # NOTE(Zihao): the CUDA-core implementations do not support masks yet; adding
+        # support should be easy and is left as future work. For now, whenever a mask
+        # is provided we use the tensor-core implementation.
+        if (
+            R * (num_qo_heads // num_kv_heads) < 4
+            and mask_mode == MaskMode.NON_CAUSAL.value
+        ):
+            # the operation is not compute-bound: use the CUDA-core implementation
+            self._use_tensor_cores = False
+            self._cached_module = get_batch_decode_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+            )
 
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            kv_indptr_host,
-            num_blocks_row,
-            num_qo_heads,
-            num_kv_heads,
-            C,
-            False,  # is_cuda_graph_enabled
-        )
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )
+        else:
+            # the operation is compute-bound: use the tensor-core implementation
+            self._use_tensor_cores = True
+            self._cached_module = get_batch_prefill_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                mask_mode,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                allow_fp16_qk_reduction,
+            )
+
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )
 
         self._pos_encoding_mode = pos_encoding_mode
         self._allow_fp16_qk_reduction = allow_fp16_qk_reduction
 
@@ -404,30 +439,57 @@ def run(
         k = k.reshape(-1, self.C, *k.shape[-2:]).contiguous()
         v = v.reshape(-1, self.C, *v.shape[-2:]).contiguous()
 
-        out = self._cached_module.paged_run(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k,
-            v,
-            self._packed_mask_buf,
-            _get_cache_alibi_slopes_buf(q.shape[1], self.device),
-            self._qo_indptr,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len,
-            self._qk_indptr_buf,
-            TensorLayout[self._kv_layout].value,
-            -1,  # window_left
-            logits_soft_cap,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            return_lse,
-        )
+        lse = None
+        if return_lse:
+            lse = torch.empty(
+                (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+            )
+
+        if self._use_tensor_cores:
+            out = self._cached_module.paged_run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._packed_mask_buf,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                self._qo_indptr,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                self._qk_indptr_buf,
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )
+        else:
+            out = self._cached_module.run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )
 
-        return out if return_lse else out[0]
+        return (out, lse) if return_lse else out
 
     def end_forward(self) -> None:
         r"""Warning: This method is deprecated and has no effect."""
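
The dispatch condition in `plan()` can be read as a standalone predicate. The sketch below is illustrative only: `should_use_tensor_cores` is a hypothetical helper (not part of the flashinfer API), and the `MaskMode` values are assumptions mirroring `flashinfer.utils.MaskMode`. `R` is the number of rows per block in the BSR layout, and `num_qo_heads // num_kv_heads` is the grouped-query-attention group size.

```python
from enum import Enum


class MaskMode(Enum):
    # Illustrative values assumed to mirror flashinfer.utils.MaskMode.
    NON_CAUSAL = 0
    CAUSAL = 1
    CUSTOM = 2


def should_use_tensor_cores(
    R: int, num_qo_heads: int, num_kv_heads: int, mask_mode: MaskMode
) -> bool:
    """Hypothetical helper restating the plan()-time heuristic above."""
    # Each block row contributes R query positions, multiplied by the GQA
    # group size; a small product means the kernel is memory-bound rather
    # than compute-bound, so the CUDA-core (decode) kernel suffices.
    memory_bound = R * (num_qo_heads // num_kv_heads) < 4
    # The decode kernel is only eligible when no mask is required.
    return not (memory_bound and mask_mode == MaskMode.NON_CAUSAL)
```

For example, with `R = 1` and equal head counts, unmasked attention dispatches to the decode kernel (1 * 1 < 4), while `R = 16` block rows always select the prefill (tensor-core) kernel.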
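
The `run()` change also alters the return convention: the LSE buffer is now allocated in Python and passed into the kernel, and the wrapper returns a tuple instead of indexing into a kernel-returned list. A caller-side sketch, assuming the `BlockSparseAttentionWrapper` API from the surrounding file (the function name here is illustrative; shapes follow the allocation above):

```python
import torch
import flashinfer


def sparse_attention_with_lse(
    wrapper: "flashinfer.BlockSparseAttentionWrapper",
    q: torch.Tensor,  # [total_q_len, num_qo_heads, head_dim]
    k: torch.Tensor,
    v: torch.Tensor,
):
    # Default path: just the attention output (previously the wrapper
    # returned out[0] from a kernel-returned list).
    out = wrapper.run(q, k, v)

    # With return_lse=True the wrapper pre-allocates a float32 buffer of
    # shape [q.size(0), q.size(1)] on q's device, the kernel fills it,
    # and (out, lse) comes back as a tuple.
    out, lse = wrapper.run(q, k, v, return_lse=True)
    return out, lse
```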