diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh
index d9b9af6e37..93f5657620 100644
--- a/include/flashinfer/attention/pod.cuh
+++ b/include/flashinfer/attention/pod.cuh
@@ -346,8 +346,11 @@ cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
   int num_sm = 0;
   FLASHINFER_CUDA_CALL(
       cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
-  FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-      &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
+  // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+  //     &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
+  // Above function returns 0 for some reason, so we use a workaround
+  num_blocks_per_sm = std::max(
+      1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p)));
   uint32_t max_num_kv_chunks =
       (num_blocks_per_sm * num_sm) /
       (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
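
For context, below is a minimal standalone sketch of the fallback occupancy estimate this hunk introduces: blocks per SM are bounded by a shared-memory limit (max_smem_per_sm / smem_size_p) and by a fixed 256-thread budget (256 / num_threads_p), taking the minimum and clamping to at least 1 so the subsequent max_num_kv_chunks computation never sees zero resident blocks. The helper name EstimateBlocksPerSm and the concrete numbers in main are illustrative assumptions, not part of the patch or of flashinfer's API.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Sketch of the patch's manual occupancy workaround (hypothetical helper,
// not flashinfer code). Mirrors:
//   num_blocks_per_sm = std::max(
//       1, std::min((int)(max_smem_per_sm / smem_size),
//                   (int)(256 / num_threads)));
static int EstimateBlocksPerSm(uint32_t max_smem_per_sm, uint32_t smem_size,
                               uint32_t num_threads) {
  // Blocks that fit in the SM's shared-memory budget.
  int smem_limited = static_cast<int>(max_smem_per_sm / smem_size);
  // Blocks allowed by the fixed 256-thread budget the patch assumes;
  // integer division yields 0 when num_threads exceeds 256.
  int thread_limited = static_cast<int>(256 / num_threads);
  // Clamp to 1, as the diff's std::max(1, ...) does, so the estimate never
  // degenerates to zero blocks per SM.
  return std::max(1, std::min(smem_limited, thread_limited));
}

int main() {
  // Assumed example values: 164 KiB of opt-in shared memory per SM (A100),
  // a kernel using 96 KiB of shared memory, and 128 threads per CTA.
  printf("blocks per SM: %d\n", EstimateBlocksPerSm(164 * 1024, 96 * 1024, 128));
  return 0;
}

Note that unlike cudaOccupancyMaxActiveBlocksPerMultiprocessor, this estimate ignores register pressure and other per-SM limits; the std::max(1, ...) clamp is what keeps the launch viable even when smem_size_p alone would forbid a block.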