Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: slight optimization on merge states #313

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion include/flashinfer/attention/cascade.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,29 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t pos = blockIdx.x;
uint32_t head_idx = ty;
state_t<vec_size> st;

if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
}
return;
}

vec_t<float, vec_size> v_merged_vec;
state_t<vec_size> st;
v_merged_vec.fill(0.f);
#pragma unroll 2
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
Expand Down Expand Up @@ -296,6 +316,25 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];

if (num_index_sets == 0) {
vec_t<DTypeOut, vec_size> v;
v.fill(DTypeOut(0));
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = -5e4;
}
return;
}

if (num_index_sets == 1) {
vec_t<DTypeOut, vec_size> v;
v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
}
}

#pragma unroll
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
Expand Down
36 changes: 20 additions & 16 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
std::vector<float> V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
std::vector<float> S_host(seq_len * num_index_sets * num_heads);
Expand Down Expand Up @@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
thrust::device_vector<T> V_merged_1_device(seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);

// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
if (num_index_sets > 1) {
// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
for (uint i = 2; i < num_index_sets; ++i) {
MergeStateInPlace(
thrust::raw_pointer_cast(V_merged_0_device.data()),
thrust::raw_pointer_cast(S_merged_0_device.data()),
thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
}
} else {
V_merged_0_device = V_device;
S_merged_0_device = S_device;
}

// Method 1: use MergeStates
Expand Down Expand Up @@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,

template <typename T>
void TestMergeKernelCorrectness() {
for (size_t num_index_sets : {2, 9, 81, 513}) {
for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
for (size_t seq_len : {4, 16, 77}) {
for (size_t num_heads : {1, 21, 32}) {
for (size_t head_dim : {64, 128, 256}) {
Expand Down