Merged
44 changes: 24 additions & 20 deletions csrc/moe/grouped_topk_kernels.cu
@@ -21,14 +21,14 @@
 #include <torch/all.h>
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
+#include <cuda/std/limits>
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 namespace cg = cooperative_groups;
 
 namespace vllm {
 namespace moe {
 
-constexpr float kNegInfinity = INFINITY * -1;
 constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t WARP_SIZE = 32;
 constexpr int32_t BLOCK_SIZE = 512;
@@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
   return __bfloat162float(val);
 }
 
+template <typename T>
+__device__ inline T neg_inf() {
+  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
+  // so we need to cast from fp32
+  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
+}
+
Comment on lines +414 to +419
Contributor (review severity: high):

The workaround for the cuda::std::numeric_limits bug is noted. To improve long-term maintainability and ensure the code automatically benefits from future bug fixes in the CUDA toolkit, it's best to scope this workaround to the specific compiler versions where the bug is present using preprocessor directives. This will allow the simpler and more direct cuda::std::numeric_limits<T>::lowest() to be used once the underlying bug is resolved.

Additionally, cuda::std::numeric_limits<float>::lowest() is more idiomatic here than -cuda::std::numeric_limits<float>::infinity(); note, however, that lowest() is the most negative finite float rather than true negative infinity, so it changes the sentinel's value slightly.

#if (defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)
// Workaround for a bug in libcu++ in CUDA 12.8+ where numeric_limits for
// fp16/bf16 types return 0.
template <typename T>
__device__ inline T neg_inf() {
  // cuda::std::numeric_limits<T>::infinity() and lowest() return `0` for [T=bf16 or fp16]
  // so we need to cast from fp32.
  return cuda_cast<T, float>(cuda::std::numeric_limits<float>::lowest());
}
#else
// Default implementation for compilers where the bug is not present.
template <typename T>
__device__ inline T neg_inf() {
  return cuda::std::numeric_limits<T>::lowest();
}
#endif
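
For context, a minimal standalone repro sketch of the libcu++ behavior this thread discusses. It is not part of the PR, and exactly which toolkit versions are affected is per the comment above; on an affected toolkit the fp16/bf16 limits print as 0 rather than the expected values, while the PR's cast-from-fp32 workaround prints -inf as intended:

// repro.cu -- hedged sketch; build with: nvcc repro.cu -o repro
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda/std/limits>

__global__ void probe_limits() {
  // On affected toolkits these evaluate to 0 instead of inf / lowest.
  __half h = cuda::std::numeric_limits<__half>::infinity();
  __nv_bfloat16 b = cuda::std::numeric_limits<__nv_bfloat16>::lowest();
  printf("half infinity() = %f\n", __half2float(h));
  printf("bf16 lowest()   = %f\n", __bfloat162float(b));
  // The PR's workaround: compute the sentinel in fp32, then down-cast.
  __half w = __float2half(-cuda::std::numeric_limits<float>::infinity());
  printf("fp32 -inf cast to half = %f\n", __half2float(w));
}

int main() {
  probe_limits<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}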


 template <typename T>
 __device__ void topk_with_k2(T* output, T const* input,
                              cg::thread_block_tile<32> const& tile,
                              int32_t const lane_id,
                              int const num_experts_per_group) {
   // Get the top2 per thread
-  T largest = -INFINITY;
-  T second_largest = -INFINITY;
+  T largest = neg_inf<T>();
+  T second_largest = neg_inf<T>();
 
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
           warp_id * topk;
   s_topk_idx += warp_id * topk;
 
-  T value = kNegInfinity;
-  T topk_group_value = kNegInfinity;
+  T value = neg_inf<T>();
+  T topk_group_value = neg_inf<T>();
   int32_t num_equalto_topkth_group;
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -525,11 +532,8 @@
   if (case_id < num_tokens) {
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
-    if (lane_id < n_group &&
-        (isfinite(cuda_cast<float, T>(
-            group_scores[lane_id]))))  // The check is necessary to avoid
-                                       // abnormal input
-    {
+    // The check is necessary to avoid abnormal input
+    if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
       value = group_scores[lane_id];
     }
 
@@ -540,23 +544,22 @@
       __syncwarp();  // Ensure all threads have valid data before reduction
       topk_group_value = cg::reduce(tile, value, cg::greater<T>());
       if (value == topk_group_value) {
-        value = kNegInfinity;
+        value = neg_inf<T>();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
-      count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
+      count_equal_to_top_value =
+          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
   __syncthreads();
 
   warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
                         /* is_stable */ true>
-      queue((int32_t)topk, -INFINITY);
+      queue((int32_t)topk, neg_inf<T>());
 
   int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk =
-      (topk_group_value != cuda_cast<T, float>(kNegInfinity));
+  bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
   if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
@@ -566,10 +569,10 @@
         for (int32_t i = lane_id; i < align_num_experts_per_group;
              i += WARP_SIZE) {
           T candidates =
-              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
-                                                 scores_with_bias[offset + i]))
+              (i < num_experts_per_group) &&
+                      cuda::std::isfinite(scores_with_bias[offset + i])
                   ? scores_with_bias[offset + i]
-                  : cuda_cast<T, float>(kNegInfinity);
+                  : neg_inf<T>();
           queue.add(candidates, offset + i);
         }
         if (group_scores[i_group] == topk_group_value) {
@@ -598,7 +601,8 @@
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
+      topk_sum +=
+          cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
     }
   }

Expand Down