Skip to content

Commit ebd067a

Browse files
authored
[Feature] Support variable length merge states (#71)
Also, reorder the input dimension of `MergeStates` to be consistent with the variable-length version.
1 parent 88b9496 commit ebd067a

File tree

6 files changed

+331
-93
lines changed

6 files changed

+331
-93
lines changed

include/flashinfer/cascade.cuh

Lines changed: 126 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -137,8 +137,8 @@ __device__ __forceinline__ void threadblock_sync_state(state_t<vec_size>& st, DT
137137
* \param vec_size The vector size used in the kernel.
138138
* \tparam DTypeIn The data type of v.
139139
* \tparam DTypeOut The data type of v_merged.
140-
* \param v The partial v of index sets. (num_index_sets, n, h, d)
141-
* \param s The logsumexp value of index sets. (num_index_sets, n, h)
140+
* \param v The partial v of index sets. (n, num_index_sets, h, d)
141+
* \param s The logsumexp value of index sets. (n, num_index_sets, h)
142142
* \param v_merged The merged v of index sets union. (n, h, d)
143143
* \param s_merged The merged logsumexp value of index sets union. (n, h)
144144
* \param num_heads The number of heads of v.
@@ -150,7 +150,6 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
150150
DTypeOut* __restrict__ v_merged, float* __restrict__ s_merged,
151151
uint32_t num_index_sets, uint32_t num_heads, uint32_t head_dim) {
152152
uint32_t tx = threadIdx.x, ty = threadIdx.y;
153-
uint32_t seq_len = gridDim.x;
154153
uint32_t pos = blockIdx.x;
155154
uint32_t head_idx = ty;
156155
state_t<vec_size> st;
@@ -159,9 +158,10 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
159158
v_merged_vec.fill(0.f);
160159
#pragma unroll 2
161160
for (uint32_t iter = 0; iter < num_index_sets; ++iter) {
162-
float s = S[(iter * seq_len + pos) * num_heads + head_idx];
161+
float s = S[(pos * num_index_sets + iter) * num_heads + head_idx];
163162
vec_t<float, vec_size> v;
164-
v.cast_load(V + ((iter * seq_len + pos) * num_heads + head_idx) * head_dim + tx * vec_size);
163+
v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim +
164+
tx * vec_size);
165165
st.merge(v, s, 1);
166166
}
167167

@@ -175,14 +175,14 @@ __global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S
175175
/*!
176176
* \brief The CUDA kernel that merges self-attention states of a list of index sets,
177177
* accelerated for a large number of index sets.
178-
* \param vec_size The vector size used in the kernel.
179-
* \param bdx The blockDim.x used in the kernel.
180-
* \param bdy The blockDim.y used in the kernel.
181-
* \param num_smem_stages The number of stages of shared memory used in the kernel.
178+
* \tparam vec_size The vector size used in the kernel.
179+
* \tparam bdx The blockDim.x used in the kernel.
180+
* \tparam bdy The blockDim.y used in the kernel.
181+
* \tparam num_smem_stages The number of stages of shared memory used in the kernel.
182182
* \tparam DTypeIn The data type of v.
183183
* \tparam DTypeOut The data type of v_merged.
184-
* \param v The partial v of index sets. (num_index_sets, n, h, d)
185-
* \param s The logsumexp value of index sets. (num_index_sets, n, h)
184+
* \param V The partial v of index sets. (n, num_index_sets, h, d)
185+
* \param S The logsumexp value of index sets. (n, num_index_sets, h)
186186
* \param v_merged The merged v of index sets union. (n, h, d)
187187
* \param s_merged The merged logsumexp value of index sets union. (n, h)
188188
* \param num_heads The number of heads of v.
@@ -196,9 +196,8 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
196196
float* __restrict__ s_merged,
197197
uint32_t num_index_sets, uint32_t num_heads) {
198198
uint32_t tx = threadIdx.x, ty = threadIdx.y;
199-
uint32_t seq_len = gridDim.y;
200-
uint32_t pos = blockIdx.y;
201-
uint32_t head_idx = blockIdx.x;
199+
uint32_t pos = blockIdx.x;
200+
uint32_t head_idx = blockIdx.y;
202201
state_t<vec_size> st;
203202
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
204203
constexpr uint32_t head_dim = vec_size * bdx;
@@ -211,7 +210,8 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
211210
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
212211
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
213212
v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
214-
V + (((iter * bdy + ty) * seq_len + pos) * num_heads + head_idx) * head_dim + tx * vec_size,
213+
V + ((pos * num_index_sets + (iter * bdy + ty)) * num_heads + head_idx) * head_dim +
214+
tx * vec_size,
215215
(iter * bdy + ty) < num_index_sets);
216216
cp_async::commit_group();
217217
}
@@ -220,27 +220,111 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa
220220
if (iter % bdx == 0) {
221221
s_smem[ty * bdx + tx] =
222222
iter * bdy + (ty * bdx + tx) < num_index_sets
223-
? S[((iter * bdy + ty * bdx + tx) * seq_len + pos) * num_heads + head_idx]
223+
? S[(pos * num_index_sets + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
224224
: 0.f;
225225
__syncthreads();
226226
}
227227
cp_async::wait_group<num_smem_stages - 1>();
228228
__syncthreads();
229229
vec_t<float, vec_size> v;
230230
v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
231+
if (iter * bdy + ty < num_index_sets) {
232+
float s = s_smem[(iter % bdx) * bdy + ty];
233+
st.merge(v, s, 1);
234+
}
231235
__syncthreads();
232236
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
233237
v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
234238
V +
235-
((((iter + num_smem_stages) * bdy + ty) * seq_len + pos) * num_heads + head_idx) *
239+
((pos * num_index_sets + ((iter + num_smem_stages) * bdy + ty)) * num_heads +
240+
head_idx) *
236241
head_dim +
237242
tx * vec_size,
238243
(iter + num_smem_stages) * bdy + ty < num_index_sets);
239244
cp_async::commit_group();
245+
}
246+
cp_async::wait_group<0>();
247+
__syncthreads();
248+
249+
st.normalize();
250+
threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
251+
st.normalize();
252+
253+
st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size);
254+
if (s_merged != nullptr) {
255+
s_merged[pos * num_heads + head_idx] = st.get_lse();
256+
}
257+
}
258+
259+
/*!
260+
* \brief The CUDA kernel to merge self-attention states of multiple index sets, the number of index
261+
* sets at each position might vary.
262+
* \tparam vec_size The vector size used in the kernel.
263+
* \tparam bdx The blockDim.x used in the kernel.
264+
* \tparam bdy The blockDim.y used in the kernel.
265+
* \tparam num_smem_stages The number of stages of shared memory used in the kernel.
266+
* \tparam DTypeIn The data type of v.
267+
* \tparam DTypeOut The data type of v_merged.
268+
* \param V The partial v of index sets. (nnz, h, d)
269+
* \param S The logsumexp value of index sets. (nnz, h)
270+
* \param indptr The start offsets of each position in the variable length array.
271+
* \param v_merged The merged v of index sets union. (n, h, d)
272+
* \param s_merged The merged logsumexp value of index sets union. (n, h)
273+
* \param num_heads The number of heads of v.
274+
* \param head_dim The dimension of each head.
275+
* \note s are logsumexp values with base 2.
276+
*/
277+
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t num_smem_stages, typename DTypeIn,
278+
typename DTypeOut, typename IdType>
279+
__global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S,
280+
IdType* indptr, DTypeOut* __restrict__ v_merged,
281+
float* __restrict__ s_merged, uint32_t num_heads) {
282+
uint32_t tx = threadIdx.x, ty = threadIdx.y;
283+
uint32_t pos = blockIdx.x;
284+
uint32_t head_idx = blockIdx.y;
285+
state_t<vec_size> st;
286+
constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
287+
constexpr uint32_t head_dim = vec_size * bdx;
288+
289+
extern __shared__ uint8_t smem[];
290+
DTypeIn* v_smem = (DTypeIn*)smem;
291+
float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
292+
const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
293+
294+
#pragma unroll
295+
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
296+
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
297+
v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
298+
V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size,
299+
(iter * bdy + ty) < num_index_sets);
300+
cp_async::commit_group();
301+
}
302+
#pragma unroll 4
303+
for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
304+
if (iter % bdx == 0) {
305+
s_smem[ty * bdx + tx] =
306+
iter * bdy + (ty * bdx + tx) < num_index_sets
307+
? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx]
308+
: 0.f;
309+
__syncthreads();
310+
}
311+
cp_async::wait_group<num_smem_stages - 1>();
312+
__syncthreads();
313+
vec_t<float, vec_size> v;
314+
v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
240315
if (iter * bdy + ty < num_index_sets) {
241316
float s = s_smem[(iter % bdx) * bdy + ty];
242317
st.merge(v, s, 1);
243318
}
319+
__syncthreads();
320+
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
321+
v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
322+
V +
323+
((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) *
324+
head_dim +
325+
tx * vec_size,
326+
(iter + num_smem_stages) * bdy + ty < num_index_sets);
327+
cp_async::commit_group();
244328
}
245329
cp_async::wait_group<0>();
246330
__syncthreads();
@@ -346,7 +430,7 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge
346430
if (num_index_sets > 2 * (128 / bdx)) {
347431
constexpr uint32_t num_threads = 128;
348432
constexpr uint32_t bdy = num_threads / bdx;
349-
dim3 nblks(num_heads, seq_len);
433+
dim3 nblks(seq_len, num_heads);
350434
dim3 nthrs(bdx, bdy);
351435
constexpr uint32_t num_smem_stages = 4;
352436
auto kernel = MergeStatesLargeNumIndexSetsKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
@@ -369,6 +453,30 @@ cudaError_t MergeStates(DTypeIn* v, float* s, DTypeOut* v_merged, float* s_merge
369453
return cudaSuccess;
370454
}
371455

456+
template <typename DTypeIn, typename DTypeOut, typename IdType>
457+
cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged,
458+
float* s_merged, uint32_t seq_len, uint32_t num_heads,
459+
uint32_t head_dim, cudaStream_t stream = nullptr) {
460+
SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {
461+
constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
462+
constexpr uint32_t bdx = HEAD_DIM / vec_size;
463+
constexpr uint32_t num_threads = 128;
464+
constexpr uint32_t bdy = num_threads / bdx;
465+
dim3 nblks(seq_len, num_heads);
466+
dim3 nthrs(bdx, bdy);
467+
constexpr uint32_t num_smem_stages = 4;
468+
auto kernel = VariableLengthMergeStatesKernel<vec_size, bdx, bdy, num_smem_stages, DTypeIn,
469+
DTypeOut, IdType>;
470+
void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &num_heads};
471+
uint32_t smem_size =
472+
num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float);
473+
FLASHINFER_CUDA_CALL(
474+
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
475+
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
476+
});
477+
return cudaSuccess;
478+
}
479+
372480
} // namespace flashinfer
373481

374482
#endif // FLASHINFER_CASCADE_CUH_

include/flashinfer/prefill.cuh

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -803,8 +803,8 @@ __global__ void SinglePrefillWithKVCacheKernel(
803803
DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size,
804804
(tx % 8) * num_elems_per_128b<DTypeIn>());
805805
DTypeOut* o_ptr_base =
806-
split_kv ? ((DTypeOut*)tmp) + chunk_idx * qo_len * qkv_info.get_num_qo_heads() * head_dim +
807-
qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size,
806+
split_kv ? ((DTypeOut*)tmp) + chunk_idx * qkv_info.get_num_qo_heads() * head_dim +
807+
qkv_info.get_qo_elem_offset(qo_idx_base * num_chunks, kv_head_idx * group_size,
808808
(tx % 8) * num_elems_per_128b<DTypeOut>())
809809
: o + qkv_info.get_qo_elem_offset(qo_idx_base, kv_head_idx * group_size,
810810
(tx % 8) * num_elems_per_128b<DTypeOut>());
@@ -910,8 +910,9 @@ __global__ void SinglePrefillWithKVCacheKernel(
910910
normalize_d<num_frags_x, num_frags_y>(o_frag, d);
911911

912912
// write back
913-
write_o_reg_gmem<group_size, num_frags_x, num_frags_y>(o_frag, &qo_smem, o_ptr_base, qo_idx_base,
914-
qo_len, qo_n_stride, qo_h_stride);
913+
write_o_reg_gmem<group_size, num_frags_x, num_frags_y>(
914+
o_frag, &qo_smem, o_ptr_base, qo_idx_base, qo_len,
915+
split_kv ? qo_n_stride * num_chunks : qo_n_stride, qo_h_stride);
915916

916917
// write lse
917918
if (lse != nullptr || split_kv) {
@@ -926,8 +927,8 @@ __global__ void SinglePrefillWithKVCacheKernel(
926927
if (qo_idx < qo_len) {
927928
if constexpr (split_kv) {
928929
float* tmp_lse =
929-
(float*)(((DTypeOut*)tmp) + num_chunks * qo_len * num_qo_heads * head_dim);
930-
tmp_lse[(chunk_idx * qo_len + qo_idx) * num_qo_heads + qo_head_idx] =
930+
(float*)(((DTypeOut*)tmp) + qo_len * num_chunks * num_qo_heads * head_dim);
931+
tmp_lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] =
931932
math::ptx_log2(d[fx][j]) + float(m[fx][j]);
932933
} else {
933934
lse[qo_idx * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]);

python/csrc/cascade.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,8 +65,8 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
6565
CHECK_EQ(v.size(0), s.size(0));
6666
CHECK_EQ(v.size(1), s.size(1));
6767
CHECK_EQ(v.size(2), s.size(2));
68-
unsigned int num_index_sets = v.size(0);
69-
unsigned int seq_len = v.size(1);
68+
unsigned int seq_len = v.size(0);
69+
unsigned int num_index_sets = v.size(1);
7070
unsigned int num_heads = v.size(2);
7171
unsigned int head_dim = v.size(3);
7272
s = s.to(torch::kFloat32);

python/flashinfer/ops/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -265,10 +265,10 @@ def merge_states(v: torch.Tensor, s: torch.Tensor):
265265
----------
266266
v : torch.Tensor
267267
The attention output from the KV segments.
268-
Shape: [num_kv_segments, seq_len, num_heads, head_dim]
268+
Shape: [seq_len, num_kv_segments, num_heads, head_dim]
269269
s : torch.Tensor
270270
The logsumexp value from the KV segments.
271-
Shape: [num_kv_segments, seq_len, num_heads]
271+
Shape: [seq_len, num_kv_segments, num_heads]
272272
273273
Returns
274274
-------

src/bench_cascade.cu

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,20 +28,20 @@ using namespace flashinfer;
2828
template <typename T>
2929
void bench_merge_states(nvbench::state& state) {
3030
const auto num_index_sets = state.get_int64("num_index_sets");
31-
const auto batch_size = state.get_int64("batch_size");
31+
const auto seq_len = state.get_int64("seq_len");
3232
const auto num_heads = state.get_int64("num_heads");
3333
const auto head_dim = state.get_int64("head_dim");
3434

35-
std::vector<T> V_host(num_index_sets * batch_size * num_heads * head_dim);
36-
std::vector<float> S_host(num_index_sets * batch_size * num_heads);
35+
std::vector<T> V_host(seq_len * num_index_sets * num_heads * head_dim);
36+
std::vector<float> S_host(seq_len * num_index_sets * num_heads);
3737

3838
utils::vec_normal_(V_host);
3939
utils::vec_uniform_(S_host, 5, 10);
4040

4141
thrust::device_vector<T> V_device(V_host);
4242
thrust::device_vector<float> S_device(S_host);
43-
thrust::device_vector<T> V_merged(batch_size * num_heads * head_dim);
44-
thrust::device_vector<float> S_merged(batch_size * num_heads);
43+
thrust::device_vector<T> V_merged(seq_len * num_heads * head_dim);
44+
thrust::device_vector<float> S_merged(seq_len * num_heads);
4545

4646
state.add_global_memory_reads<T>(V_host.size(), "Read");
4747
state.add_global_memory_writes<T>(V_merged.size(), "Write");
@@ -51,7 +51,7 @@ void bench_merge_states(nvbench::state& state) {
5151
cudaError_t status = MergeStates(
5252
thrust::raw_pointer_cast(V_device.data()), thrust::raw_pointer_cast(S_device.data()),
5353
thrust::raw_pointer_cast(V_merged.data()), thrust::raw_pointer_cast(S_merged.data()),
54-
num_index_sets, batch_size, num_heads, head_dim);
54+
num_index_sets, seq_len, num_heads, head_dim);
5555
timer.stop();
5656
});
5757
}

0 commit comments

Comments (0)