Skip to content

Commit

Permalink
你看 又急
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Jul 24, 2024
1 parent 122c048 commit fb16238
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/test_cascade.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads,
template <typename T>
void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads,
size_t head_dim, bool sparse_s) {
EXPECT_GT(num_index_sets, 1) << "num_index_sets must be greater than 1";
std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
std::vector<float> V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim);
std::vector<float> S_host(seq_len * num_index_sets * num_heads);
Expand Down Expand Up @@ -178,6 +177,7 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
thrust::device_vector<T> V_merged_1_device(seq_len * num_heads * head_dim);
thrust::device_vector<float> S_merged_1_device(seq_len * num_heads);

if (num_index_sets > 1) {
// Method 0: use MergeState
MergeState(thrust::raw_pointer_cast(V_device_trans_f32.data()),
thrust::raw_pointer_cast(S_device_trans.data()),
Expand All @@ -193,6 +193,10 @@ void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t n
thrust::raw_pointer_cast(S_device_trans.data() + i * seq_len * num_heads), seq_len,
num_heads, head_dim);
}
} else {
V_merged_0_device = V_device;
S_merged_0_device = S_device;
}

// Method 1: use MergeStates
MergeStates(thrust::raw_pointer_cast(V_device.data()), thrust::raw_pointer_cast(S_device.data()),
Expand Down Expand Up @@ -479,7 +483,7 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,

template <typename T>
void TestMergeKernelCorrectness() {
for (size_t num_index_sets : {2, 9, 81, 513}) {
for (size_t num_index_sets : {1, 2, 9, 81, 513}) {
for (size_t seq_len : {4, 16, 77}) {
for (size_t num_heads : {1, 21, 32}) {
for (size_t head_dim : {64, 128, 256}) {
Expand Down

0 comments on commit fb16238

Please sign in to comment.