diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu
index accbb09858fa..b5321f748e6b 100644
--- a/csrc/moe/grouped_topk_kernels.cu
+++ b/csrc/moe/grouped_topk_kernels.cu
@@ -21,6 +21,7 @@
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 #include <cuda.h>
+#include <cuda/std/limits>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 namespace cg = cooperative_groups;
@@ -28,7 +29,6 @@ namespace cg = cooperative_groups;
 namespace vllm {
 namespace moe {
 
-constexpr float kNegInfinity = INFINITY * -1;
 constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t WARP_SIZE = 32;
 constexpr int32_t BLOCK_SIZE = 512;
@@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
   return __bfloat162float(val);
 }
 
+template <typename T>
+__device__ inline T neg_inf() {
+  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
+  // so we need to cast from fp32
+  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
+}
+
 template <typename T>
 __device__ void topk_with_k2(T* output, T const* input,
                              cg::thread_block_tile<32> const& tile,
                              int32_t const lane_id,
                              int const num_experts_per_group) {
   // Get the top2 per thread
-  T largest = -INFINITY;
-  T second_largest = -INFINITY;
+  T largest = neg_inf<T>();
+  T second_largest = neg_inf<T>();
 
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
       warp_id * topk;
   s_topk_idx += warp_id * topk;
 
-  T value = kNegInfinity;
-  T topk_group_value = kNegInfinity;
+  T value = neg_inf<T>();
+  T topk_group_value = neg_inf<T>();
   int32_t num_equalto_topkth_group;
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -525,11 +532,8 @@
   if (case_id < num_tokens) {
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
-    if (lane_id < n_group &&
-        (isfinite(cuda_cast<float, T>(
-            group_scores[lane_id]))))  // The check is necessary to avoid
-                                       // abnormal input
-    {
+    // The check is necessary to avoid abnormal input
+    if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
       value = group_scores[lane_id];
     }
 
@@ -540,11 +544,11 @@
       __syncwarp();  // Ensure all threads have valid data before reduction
       topk_group_value = cg::reduce(tile, value, cg::greater<T>());
       if (value == topk_group_value) {
-        value = kNegInfinity;
+        value = neg_inf<T>();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
-      count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
+      count_equal_to_top_value =
+          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
@@ -552,11 +556,10 @@
 
   warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
                         /* is_stable */ true>
-      queue((int32_t)topk, -INFINITY);
+      queue((int32_t)topk, neg_inf<T>());
 
   int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk =
-      (topk_group_value != cuda_cast<T, float>(kNegInfinity));
+  bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
   if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
@@ -566,10 +569,10 @@
         for (int32_t i = lane_id; i < align_num_experts_per_group;
              i += WARP_SIZE) {
           T candidates =
-              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
-                                                 scores_with_bias[offset + i]))
+              (i < num_experts_per_group) &&
+                      cuda::std::isfinite(scores_with_bias[offset + i])
                   ? scores_with_bias[offset + i]
-                  : cuda_cast<T, float>(kNegInfinity);
+                  : neg_inf<T>();
           queue.add(candidates, offset + i);
         }
         if (group_scores[i_group] == topk_group_value) {
@@ -598,7 +601,8 @@
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
+      topk_sum +=
+          cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
     }
   }
 
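
Note on the `neg_inf<T>()` helper: the comment in the patch says `cuda::std::numeric_limits<T>::infinity()` yields `0` when `T` is `__nv_bfloat16` or `__half`, which is why every -inf in the kernel now goes through a negate-in-fp32-then-narrow path instead of taking the limit in `T` directly. Below is a minimal standalone sketch of that pitfall, assuming a toolchain whose libcu++ leaves `numeric_limits` unspecialized for bf16 (the situation the patch comment describes); the `probe` kernel is hypothetical and for illustration only, not part of the patch:

    // probe.cu -- illustration only; compile with e.g. `nvcc -arch=sm_80 probe.cu`
    #include <cstdio>
    #include <cuda/std/limits>
    #include <cuda_bf16.h>

    __global__ void probe() {
      // Taking the limit directly in bf16 falls back to the unspecialized
      // template and value-initializes to 0 -- the bug the helper avoids:
      __nv_bfloat16 direct = cuda::std::numeric_limits<__nv_bfloat16>::infinity();
      // The patch's approach: negate the fp32 infinity, then narrow to bf16:
      __nv_bfloat16 casted =
          __float2bfloat16(-cuda::std::numeric_limits<float>::infinity());
      printf("direct=%f casted=%f\n", __bfloat162float(direct),
             __bfloat162float(casted));
    }

    int main() {
      probe<<<1, 1>>>();
      cudaDeviceSynchronize();
      return 0;
    }

On an affected toolchain this prints `direct=0.000000 casted=-inf`, which is why sentinel comparisons such as `value == neg_inf<T>()` in the diff are only sound when the sentinel really is -inf rather than 0.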