diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh
index a5ddd8d4..5445156b 100644
--- a/include/flashinfer/attention/cascade.cuh
+++ b/include/flashinfer/attention/cascade.cuh
@@ -157,9 +157,31 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
   uint32_t tx = threadIdx.x, ty = threadIdx.y;
   uint32_t pos = blockIdx.x;
   uint32_t head_idx = ty;
-  state_t st;
+
+  // Nothing to merge: write zeros to v_merged and, when s_merged is requested, -5e4.
+  if (num_index_sets == 0) {
+    vec_t v;
+    v.fill(DTypeOut(0));
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = -5e4;
+    }
+    return;
+  }
+
+  // Exactly one state: pass V (through cast_load) and S straight to the outputs.
+  if (num_index_sets == 1) {
+    vec_t v;
+    v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx];
+    }
+    return;
+  }
 
   vec_t v_merged_vec;
+  state_t st;
   v_merged_vec.fill(0.f);
 #pragma unroll 2
   for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
@@ -296,6 +318,29 @@ __global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float*
   float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
   const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
 
+  // Nothing to merge: write zeros to v_merged and, when s_merged is requested, -5e4.
+  if (num_index_sets == 0) {
+    vec_t v;
+    v.fill(DTypeOut(0));
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = -5e4;
+    }
+    return;
+  }
+
+  // Exactly one state: pass V (through cast_load) and S straight to the outputs, then
+  // return so the cp_async staging loop below is skipped.
+  if (num_index_sets == 1) {
+    vec_t v;
+    v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size);
+    v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
+    if (s_merged != nullptr) {
+      s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx];
+    }
+    return;
+  }
+
 #pragma unroll
   for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
     cp_async::pred_load(
diff --git a/src/test_cascade.cu b/src/test_cascade.cu
index 4bbce95a..5d7d28cd 100644
--- a/src/test_cascade.cu
+++ b/src/test_cascade.cu
@@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
 template
 void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
                                  size_t head_dim, bool sparse_s) {
-  EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
   std::vector V_host(seq_len * num_index_sets * num_heads * head_dim);
   std::vector V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
   std::vector S_host(seq_len * num_index_sets * num_heads);
@@ -178,20 +177,25 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
   thrust::device_vector V_merged_1_device(seq_len * num_heads * head_dim);
   thrust::device_vector S_merged_1_device(seq_len * num_heads);
 
-  // Method 0: use MergeState
-  MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
-             thrust::raw_pointer_cast(S_device_trans.data()),
-             thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
-             thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
-             thrust::raw_pointer_cast(V_merged_0_device.data()),
-             thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
-  for (uint i = 2; i < num_index_sets; ++i) {
-    MergeStateInPlace(
-        thrust::raw_pointer_cast(V_merged_0_device.data()),
-        thrust::raw_pointer_cast(S_merged_0_device.data()),
-        thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
-        thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
-        num_heads, head_dim);
+  if (num_index_sets > 1) {
+    // Method 0: use MergeState
+    MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
+               thrust::raw_pointer_cast(S_device_trans.data()),
+               thrust::raw_pointer_cast(V_device_trans_f32.data() + seq_len * num_heads * head_dim),
+               thrust::raw_pointer_cast(S_device_trans.data() + seq_len * num_heads),
+               thrust::raw_pointer_cast(V_merged_0_device.data()),
+               thrust::raw_pointer_cast(S_merged_0_device.data()), seq_len, num_heads, head_dim);
+    for (uint i = 2; i < num_index_sets; ++i) {
+      MergeStateInPlace(
+          thrust::raw_pointer_cast(V_merged_0_device.data()),
+          thrust::raw_pointer_cast(S_merged_0_device.data()),
+          thrust::raw_pointer_cast(V_device_trans_f32.data() + i * seq_len * num_heads * head_dim),
+          thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
+          num_heads, head_dim);
+    }
+  } else {
+    V_merged_0_device = V_device;
+    S_merged_0_device = S_device;
   }
 
   // Method 1: use MergeStates
@@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
 
 template
 void TestMergeKernelCorrectness() {
-  for (size_t num_index_sets : {2, 9, 81, 513}) {
+  for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
     for (size_t seq_len : {4, 16, 77}) {
       for (size_t num_heads : {1, 21, 32}) {
         for (size_t head_dim : {64, 128, 256}) {