feat: add use_tensor_cores option to decode kernels to accelerate GQA (#317)

The tensor-core accelerated GQA described in our [blog post](https://flashinfer.ai/2024/02/02/introduce-flashinfer.html) was
not enabled by default (users had to use the prefill kernels/wrappers for decode to get that acceleration).

In this PR we add a `use_tensor_cores` option to the decode operators/wrappers, so users can
select whether to use tensor cores for acceleration depending on their use case.

Note that our prefill kernels are compiled for all possible group sizes (#301), but our decode
kernels are not. If you need a group size outside the pre-compiled set, it is encouraged to set
`use_tensor_cores=True`.
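
A minimal usage sketch, assuming the new flag is exposed as a `use_tensor_cores` keyword argument on the Python decode wrapper and single-decode operator (the exact signatures are an assumption here; the 128MB workspace size follows the docstring examples below):

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 4, 128  # GQA with group size 8
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")

# Ask the batch decode wrapper to dispatch to the tensor-core (prefill-based)
# kernels, which are compiled for all group sizes.
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, "NHD", use_tensor_cores=True
)
# begin_forward(...) / forward(...) are then called as usual.

# The single-request decode operator takes the same flag (assumed here):
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
k = torch.randn(4096, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
v = torch.randn(4096, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
o = flashinfer.single_decode_with_kv_cache(q, k, v, use_tensor_cores=True)
```

With the flag left unset, the existing CUDA-core decode kernels are expected to be used as before.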
yzh119 authored Jun 20, 2024
1 parent 2ef20c1 commit 3b50dd5
Showing 6 changed files with 506 additions and 137 deletions.
15 changes: 0 additions & 15 deletions include/flashinfer/utils.cuh
@@ -91,24 +91,9 @@
if (group_size == 1) { \
constexpr size_t GROUP_SIZE = 1; \
__VA_ARGS__ \
} else if (group_size == 2) { \
constexpr size_t GROUP_SIZE = 2; \
__VA_ARGS__ \
} else if (group_size == 3) { \
constexpr size_t GROUP_SIZE = 3; \
__VA_ARGS__ \
} else if (group_size == 4) { \
constexpr size_t GROUP_SIZE = 4; \
__VA_ARGS__ \
} else if (group_size == 5) { \
constexpr size_t GROUP_SIZE = 5; \
__VA_ARGS__ \
} else if (group_size == 6) { \
constexpr size_t GROUP_SIZE = 6; \
__VA_ARGS__ \
} else if (group_size == 7) { \
constexpr size_t GROUP_SIZE = 7; \
__VA_ARGS__ \
} else if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \
__VA_ARGS__ \
12 changes: 4 additions & 8 deletions python/csrc/single_prefill.cu
@@ -38,16 +38,14 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
unsigned int head_dim = q.size(2);
unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
qo_len = q.size(0);
num_qo_heads = q.size(1);
if (kv_layout == QKVLayout::kNHD) {
kv_len = k.size(0);
qo_len = q.size(0);
num_kv_heads = k.size(1);
num_qo_heads = q.size(1);
} else {
kv_len = k.size(1);
qo_len = q.size(1);
num_kv_heads = k.size(0);
num_qo_heads = q.size(0);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
@@ -122,16 +120,14 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
unsigned int head_dim = q.size(2);
unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads;
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
qo_len = q.size(0);
num_qo_heads = q.size(1);
if (kv_layout == QKVLayout::kNHD) {
kv_len = k.size(0);
qo_len = q.size(0);
num_kv_heads = k.size(1);
num_qo_heads = q.size(1);
} else {
kv_len = k.size(1);
qo_len = q.size(1);
num_kv_heads = k.size(0);
num_qo_heads = q.size(0);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
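
As a reading of the two hunks above: `qo_len` and `num_qo_heads` are now taken from `q.size(0)` and `q.size(1)` unconditionally, so `kv_layout` only governs the k/v tensors while the query is always `[qo_len, num_qo_heads, head_dim]`. A small sketch of the resulting shape convention (sizes are made up for illustration):

```python
import torch

qo_len, kv_len = 128, 4096
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128

# Query: always [qo_len, num_qo_heads, head_dim], independent of kv_layout.
q = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")

# Key/value: layout controlled by kv_layout.
k_nhd = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")  # "NHD"
k_hnd = torch.randn(num_kv_heads, kv_len, head_dim, dtype=torch.float16, device="cuda:0")  # "HND"
```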
10 changes: 5 additions & 5 deletions python/flashinfer/cascade.py
@@ -307,8 +307,8 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 16MB workspace buffer
>>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper(
... workspace_buffer, "NHD"
... )
@@ -540,8 +540,8 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
>>> head_dim = 128
>>> max_num_pages = 128
>>> page_size = 16
>>> # allocate 16MB workspace buffer
>>> workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> # allocate 128MB workspace buffer
>>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
>>> prefill_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper(
... workspace_buffer, "NHD"
... )
@@ -617,7 +617,7 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"):
----------
workspace_buffer : torch.Tensor
The user reserved workspace buffer used to store auxiliary data structures,
recommended size is 16MB, the device of the workspace buffer should be the
recommended size is 128MB, the device of the workspace buffer should be the
same as the device of the input tensors.
kv_layout : str
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.