From 1f6710600cce240d5f71d07c3ca666343f10421d Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 21 Aug 2024 09:20:18 +0000 Subject: [PATCH 1/3] upd --- include/flashinfer/attention/cascade.cuh | 142 ++++++++++++----------- 1 file changed, 77 insertions(+), 65 deletions(-) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 5445156b..4964d4de 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -301,85 +301,91 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa */ template -__global__ void VariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S, - IdType* indptr, DTypeOut* __restrict__ v_merged, - float* __restrict__ s_merged, uint32_t num_heads) { +__global__ void PersistentVariableLengthMergeStatesKernel( + DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeOut* __restrict__ v_merged, + float* __restrict__ s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t num_sms) { uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t pos = blockIdx.x; - uint32_t head_idx = blockIdx.y; - state_t st; + uint32_t sm_id = blockIdx.x; + uint32_t num_iters = ceil_div(seq_len * num_heads, num_sms); constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; constexpr uint32_t head_dim = vec_size * bdx; - extern __shared__ uint8_t smem[]; DTypeIn* v_smem = (DTypeIn*)smem; float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; - if (num_index_sets == 0) { - vec_t v; - v.fill(DTypeOut(0)); - v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = -5e4; +#pragma unroll 1 + for (uint32_t i = sm_id; i < seq_len * num_heads; i += num_sms) { + uint32_t pos = i / num_heads; + uint32_t head_idx = i % num_heads; + state_t st; + const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; + + if (num_index_sets == 0) { + vec_t v; + v.fill(DTypeOut(0)); + v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = -5e4; + } + continue; } - return; - } - if (num_index_sets == 1) { - vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); - v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx]; + if (num_index_sets == 1) { + vec_t v; + v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx]; + } + continue; } - } #pragma unroll - for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { - cp_async::pred_load( - v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, - (iter * bdy + ty) < num_index_sets); - cp_async::commit_group(); - } + for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { + cp_async::pred_load( + v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, + V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + (iter * bdy + ty) < num_index_sets); + cp_async::commit_group(); + } #pragma unroll 4 - for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { - if (iter % bdx == 0) { - s_smem[ty * bdx + tx] = - iter * bdy + (ty * bdx + tx) < num_index_sets - ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] - : 0.f; + for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { + if (iter % bdx == 0) { + s_smem[ty * bdx + tx] = + iter * bdy + (ty * bdx + tx) < num_index_sets + ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] + : 0.f; + __syncthreads(); + } + cp_async::wait_group(); __syncthreads(); + vec_t v; + v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); + if (iter * bdy + ty < num_index_sets) { + float s = s_smem[(iter % bdx) * bdy + ty]; + st.merge(v, s, 1); + } + __syncthreads(); + cp_async::pred_load( + v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, + V + + ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + head_dim + + tx * vec_size, + (iter + num_smem_stages) * bdy + ty < num_index_sets); + cp_async::commit_group(); } - cp_async::wait_group(); - __syncthreads(); - vec_t v; - v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); - if (iter * bdy + ty < num_index_sets) { - float s = s_smem[(iter % bdx) * bdy + ty]; - st.merge(v, s, 1); - } + cp_async::wait_group<0>(); __syncthreads(); - cp_async::pred_load( - v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, - V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * - head_dim + - tx * vec_size, - (iter + num_smem_stages) * bdy + ty < num_index_sets); - cp_async::commit_group(); - } - cp_async::wait_group<0>(); - __syncthreads(); - st.normalize(); - threadblock_sync_state(st, v_smem, s_smem); - st.normalize(); + st.normalize(); + threadblock_sync_state(st, v_smem, s_smem); + st.normalize(); - st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = st.get_lse(); + st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = st.get_lse(); + } } } @@ -502,17 +508,23 @@ template cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeOut* v_merged, float* s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, cudaStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); constexpr uint32_t bdx = HEAD_DIM / vec_size; constexpr uint32_t num_threads = 128; constexpr uint32_t bdy = num_threads / bdx; - dim3 nblks(seq_len, num_heads); + + dim3 nblks(num_sms); dim3 nthrs(bdx, bdy); constexpr uint32_t num_smem_stages = 4; - auto kernel = VariableLengthMergeStatesKernel; - void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &num_heads}; + auto kernel = PersistentVariableLengthMergeStatesKernel; + void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads, &num_sms}; uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float); FLASHINFER_CUDA_CALL( From 8aa48baeb302c2dd6d21e41b749766cbd438badf Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 21 Aug 2024 09:29:06 +0000 Subject: [PATCH 2/3] upd --- include/flashinfer/attention/cascade.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 4964d4de..2e3e4e3b 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -511,7 +511,7 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp int dev_id = 0; int num_sms = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); From 064abb8433f7560fcaa3029a50b3d431074184fc Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 21 Aug 2024 18:39:38 +0000 Subject: [PATCH 3/3] upd --- include/flashinfer/attention/cascade.cuh | 31 +++++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 2e3e4e3b..2e678ec6 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -301,12 +301,15 @@ __global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, floa */ template -__global__ void PersistentVariableLengthMergeStatesKernel( - DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeOut* __restrict__ v_merged, - float* __restrict__ s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t num_sms) { +__global__ void PersistentVariableLengthMergeStatesKernel(DTypeIn* __restrict__ V, + float* __restrict__ S, IdType* indptr, + DTypeOut* __restrict__ v_merged, + float* __restrict__ s_merged, + uint32_t seq_len, uint32_t num_heads) { uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t sm_id = blockIdx.x; - uint32_t num_iters = ceil_div(seq_len * num_heads, num_sms); + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; + uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; constexpr uint32_t head_dim = vec_size * bdx; extern __shared__ uint8_t smem[]; @@ -314,7 +317,7 @@ __global__ void PersistentVariableLengthMergeStatesKernel( float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); #pragma unroll 1 - for (uint32_t i = sm_id; i < seq_len * num_heads; i += num_sms) { + for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { uint32_t pos = i / num_heads; uint32_t head_idx = i % num_heads; state_t st; @@ -510,6 +513,7 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp uint32_t head_dim, cudaStream_t stream = nullptr) { int dev_id = 0; int num_sms = 0; + int num_blocks_per_sm = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); @@ -518,15 +522,18 @@ cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTyp constexpr uint32_t bdx = HEAD_DIM / vec_size; constexpr uint32_t num_threads = 128; constexpr uint32_t bdy = num_threads / bdx; - - dim3 nblks(num_sms); - dim3 nthrs(bdx, bdy); constexpr uint32_t num_smem_stages = 4; - auto kernel = PersistentVariableLengthMergeStatesKernel; - void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads, &num_sms}; uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float); + auto kernel = PersistentVariableLengthMergeStatesKernel; + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms)); + + dim3 nblks(num_sms * num_blocks_per_sm); + dim3 nthrs(bdx, bdy); + void* args[] = {&v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads}; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));