From 75878399a12afd04b1114f026f686c16e3a79d4b Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Tue, 11 Jun 2024 08:39:47 +0000 Subject: [PATCH] upd --- include/flashinfer/sampling.cuh | 303 +++++++++++++++++--------------- 1 file changed, 159 insertions(+), 144 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index f29d47f4..fa62a344 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -77,6 +77,21 @@ struct SamplingTempStorage { } data; }; +template +struct RenormTempStorage { + union { + typename BlockReduce::TempStorage reduce; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; + } block_prim; + struct { + T max_val; + union { + T value; + Pair pair; + } block_aggregate; + } data; +}; + template __device__ __forceinline__ void DeviceSamplingFromProb( @@ -130,6 +145,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb( aggregate += aggregate_local; } +namespace SamplingKernel { + template __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdType* output, @@ -339,112 +356,9 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, } } -template -cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size, - uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - IdType* row_indices_placeholder = nullptr; - void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder, &d}; - const uint32_t smem_size = sizeof(SamplingTempStorage); - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - SamplingFromProbKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* output, - IdType* row_indices, uint32_t batch_size, uint32_t d, - cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d}; - const uint32_t smem_size = sizeof(SamplingTempStorage); - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - SamplingFromProbKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - IdType top_k, uint32_t batch_size, uint32_t d, - uint32_t max_top_k_rounds, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d, &max_top_k_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopKSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, - T top_p, uint32_t batch_size, uint32_t d, - uint32_t max_top_p_rounds, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); +} // namespace SamplingKernel - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - IdType* row_indices_placeholder = nullptr; - T* top_p_arr_placeholder = nullptr; - void* args[] = {&probs, - &uniform_samples, - &output, - &success, - &row_indices_placeholder, - &top_p_arr_placeholder, - &top_p, - &d, - &max_top_p_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopPSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -struct RenormTempStorage { - union { - typename BlockReduce::TempStorage reduce; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; - } block_prim; - struct { - T max_val; - union { - T value; - Pair pair; - } block_aggregate; - } data; -}; +namespace RenormKernel { template @@ -627,43 +541,9 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 } } -template -cudaError_t TopPRenormProb(DType* probs, IdType* renormed_prob, float p, float eps, - uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); +} // namespace RenormKernel - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &p, &eps, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopPRenormProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - -template -cudaError_t TopKRenormProb(DType* probs, IdType* renormed_prob, uint32_t k, float eps, - uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &k, &eps, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} +namespace SpecDecodingKernel { template @@ -764,6 +644,140 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token } } +} // namespace SpecDecodingKernel + +template +cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size, + uint32_t d, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + IdType* row_indices_placeholder = nullptr; + void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder, &d}; + const uint32_t smem_size = sizeof(SamplingTempStorage); + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = SamplingKernel::SamplingFromProbKernel; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* output, + IdType* row_indices, uint32_t batch_size, uint32_t d, + cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d}; + const uint32_t smem_size = sizeof(SamplingTempStorage); + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = SamplingKernel::SamplingFromProbKernel; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, + IdType top_k, uint32_t batch_size, uint32_t d, + uint32_t max_top_k_rounds, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d, &max_top_k_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = SamplingKernel::TopKSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, + T top_p, uint32_t batch_size, uint32_t d, + uint32_t max_top_p_rounds, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + IdType* row_indices_placeholder = nullptr; + T* top_p_arr_placeholder = nullptr; + void* args[] = {&probs, + &uniform_samples, + &output, + &success, + &row_indices_placeholder, + &top_p_arr_placeholder, + &top_p, + &d, + &max_top_p_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = SamplingKernel::TopPSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t TopPRenormProb(DType* probs, IdType* renormed_prob, float p, float eps, + uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { + const uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &p, &eps, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = + RenormKernel::TopPRenormProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + +template +cudaError_t TopKRenormProb(DType* probs, IdType* renormed_prob, uint32_t k, float eps, + uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { + const uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &k, &eps, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = + RenormKernel::TopKRenormProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; +} + template cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, bool* success, IdType* row_indices, T* top_p_arr, @@ -780,8 +794,8 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o &top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopPSamplingFromProbKernel; + auto kernel = SamplingKernel::TopPSamplingFromProbKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); @@ -811,7 +825,8 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = - ChainSpeculativeSampling; + SpecDecodingKernel::ChainSpeculativeSampling; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));