Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: Fix sm75 kernel configuration #449

Merged
merged 12 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/release_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ on:
# required: true

env:
TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX"
TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX"

jobs:
build:
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Prerequisites

- Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version.

- Supported GPU architectures: ``sm80``, ``sm86``, ``sm89``, ``sm90`` (``sm75`` / ``sm70`` support is working in progress).
- Supported GPU architectures: ``sm75``, ``sm80``, ``sm86``, ``sm89``, ``sm90``.

Quick Start
^^^^^^^^^^^
Expand Down
252 changes: 128 additions & 124 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstddef>
#ifdef FLASHINFER_ENABLE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>

#include <cstddef>
#include <cuda/pipeline>
#include <iostream>
#include <optional>
Expand Down Expand Up @@ -537,6 +534,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
j] +
tx * vec_size;
}

// load k tiles
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
Expand Down Expand Up @@ -597,11 +595,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo
return 512U;
}
} else {
#ifdef FLASHINFER_ENABLE_BF16
return 128U;
#else
return 64U;
#endif
}
}

Expand Down Expand Up @@ -639,8 +633,8 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
auto compute_capacity = GetCudaComputeCapability();
static_assert(bdx <= 32U);
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
constexpr uint32_t bdy = GROUP_SIZE;
Expand All @@ -649,69 +643,74 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
constexpr uint32_t bdz = num_threads / (bdx * bdy);
tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 8U) : 1U;
const uint32_t smem_size =
2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
auto kernel = SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE,
num_stages_smem, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&seq_len};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256);
uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size);
dim3 nblks = dim3(num_chunks, num_kv_heads);
if (nblks.x == 0 || nblks.y == 0) {
std::ostringstream err_msg;
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
throw std::runtime_error(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&kv_chunk_size};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
const uint32_t smem_size =
2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) +
2U * bdy * bdz * sizeof(float);
auto kernel = SingleDecodeWithKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE,
NUM_STAGES_SMEM, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (seq_len <= 256 || tmp == nullptr) {
// no need to use partition-kv kernel
dim3 nblks = dim3(1, num_kv_heads);
dim3 nthrs = dim3(bdx, bdy, bdz);
float* lse = nullptr;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&seq_len};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, kernel, num_threads, smem_size));
uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256);
uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size);
dim3 nblks = dim3(num_chunks, num_kv_heads);
if (nblks.x == 0 || nblks.y == 0) {
std::ostringstream err_msg;
err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")";
throw std::runtime_error(err_msg.str());
}
dim3 nthrs = dim3(bdx, bdy, bdz);
float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
(void*)&kv_chunk_size};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(
MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
}
});
});
return cudaSuccess;
}
Expand All @@ -730,66 +729,71 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
const uint32_t num_kv_heads = paged_kv.num_heads;

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
auto compute_capacity = GetCudaComputeCapability();
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
const uint32_t smem_size =
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, num_stages_smem,
tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (tmp_v == nullptr) {
// do not use partition-kv kernel
bool partition_kv = false;
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);

void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
bool partition_kv = true;
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse,
kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream));
}
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
const uint32_t smem_size =
2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*),
2 * bdy * bdz * sizeof(float));
auto kernel =
BatchDecodeWithPagedKVCacheKernel<LOGITS_POST_HOOK, POS_ENCODING_MODE, NUM_STAGES_SMEM,
tile_size_per_bdx, vec_size, bdx, bdy, bdz,
page_storage, DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
if (tmp_v == nullptr) {
// do not use partition-kv kernel
bool partition_kv = false;
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);

void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
bool partition_kv = true;
void* args[] = {(void*)&q,
(void*)&q_offset,
(void*)&paged_kv,
(void*)&kv_partition_info,
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
(void*)&partition_kv,
(void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse,
kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream));
}
});
});
return cudaSuccess;
}
Expand Down
Loading