diff --git a/src/test_cascade.cu b/src/test_cascade.cu index 103f9dbe..ae744e27 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads, template void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads, size_t head_dim, bool sparse_s) { - EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1"; std::vector V_host(seq_len * num_index_sets * num_heads * head_dim); std::vector V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim); std::vector S_host(seq_len * num_index_sets * num_heads); @@ -178,6 +177,7 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n thrust::device_vector V_merged_1_device(seq_len * num_heads * head_dim); thrust::device_vector S_merged_1_device(seq_len * num_heads); + if (num_index_sets > 1) { // Method 0: use MergeState MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()), thrust::raw_pointer_cast(S_device_trans.data()), @@ -193,6 +193,10 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len, num_heads, head_dim); } + } else { + V_merged_0_device = V_device; + S_merged_0_device = S_device; + } // Method 1: use MergeStates MergeStates(thrust::raw_pointer_cast(V_device.data()), thrust::raw_pointer_cast(S_device.data()), @@ -478,7 +482,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size, template void TestMergeKernelCorrectness() { - for (size_t num_index_sets : {2, 9, 81, 513}) { + for (size_t num_index_sets : {1, 2, 9, 81, 513}) { for (size_t seq_len : {4, 16, 77}) { for (size_t num_heads : {1, 21, 32}) { for (size_t head_dim : {64, 128, 256}) {