Skip to content

Commit

Permalink
perf: slight optimization on merge states (#313)
Browse files Browse the repository at this point in the history
When cudagraph is enabled, we will still call merge states kernels for
short sequence length, which incurs some unnecessary overhead.

This PR accelerates merge states kernel when there is nothing to merge
(`num_index_sets=1`).

We can actually write through to the target buffer for small sequence
length, but I'm always lazy evaluated and I'll leave it for a future PR
(if necessary).
  • Loading branch information
yzh119 authored Jul 24, 2024
1 parent 2ab2bca commit 701c813
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 17 deletions.
41 changes: 40 additions & 1 deletion include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,29 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t pos = blockIdx.x;
uint32_t head_idx = ty;
state_t<vec_size> st;

if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
}
return;
}

vec_t<float, vec_size> v_merged_vec;
state_t<vec_size> st;
v_merged_vec.fill(0.f);
#pragma unroll 2
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
Expand Down Expand Up @@ -296,6 +316,25 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];

if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
}
}

#pragma unroll
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
Expand Down
36 changes: 20 additions & 16 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
std::vector<float> V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
std::vector<float> S_host(seq_len * num_index_sets * num_heads);
Expand Down Expand Up @@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
thrust::device_vector<T> V_merged_1_device(seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);

// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
if (num_index_sets > 1) {
// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
}
} else {
V_merged_0_device = V_device;
S_merged_0_device = S_device;
}

// Method 1: use MergeStates
Expand Down Expand Up @@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,

template <typename T>
void TestMergeKernelCorrectness() {
for (size_t num_index_sets : {2, 9, 81, 513}) {
for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
for (size_t seq_len : {4, 16, 77}) {
for (size_t num_heads : {1, 21, 32}) {
for (size_t head_dim : {64, 128, 256}) {
Expand Down

0 comments on commit 701c813

Please sign in to comment.