Skip to content

Commit 0081c69

Browse files
minosfuture, rahul-tuli, claude, yewentao256, houseroad
committed
Use macro guard CUDA functions for back compatibility in grouped_topk_kernel.cu (#25346)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
1 parent 6462fee commit 0081c69

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

csrc/moe/grouped_topk_kernels.cu

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,15 @@ __device__ inline T neg_inf() {
418418
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
419419
}
420420

421+
// Device-side finiteness test for a value of type T.
// Selects one of two implementations at compile time based on the compiler
// version reported by __CUDACC_VER_MAJOR__ / __CUDACC_VER_MINOR__.
// Returns the bool result of the chosen isfinite call.
template <typename T>
422+
__device__ inline bool is_finite(const T val) {
423+
// Version code is major * 10000 + minor * 100, so 120800 corresponds to 12.8.
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
424+
// 12.8 and newer: call cuda::std::isfinite directly on the T value.
return cuda::std::isfinite(val);
425+
#else
426+
// Older compilers: convert to float with cuda_cast first, then call the
// unqualified isfinite on the float.
// NOTE(review): presumably cuda::std::isfinite is not usable for every T on
// older toolkits -- confirm against the libcu++ release notes.
return isfinite(cuda_cast<float, T>(val));
427+
#endif
428+
}
429+
421430
template <typename T>
422431
__device__ void topk_with_k2(T* output, T const* input,
423432
cg::thread_block_tile<32> const& tile,
@@ -533,7 +542,7 @@ __global__ void group_idx_and_topk_idx_kernel(
533542
// calculate group_idx
534543
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
535544
// The check is necessary to avoid abnormal input
536-
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
545+
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
537546
value = group_scores[lane_id];
538547
}
539548

@@ -568,11 +577,10 @@ __global__ void group_idx_and_topk_idx_kernel(
568577
int32_t offset = i_group * num_experts_per_group;
569578
for (int32_t i = lane_id; i < align_num_experts_per_group;
570579
i += WARP_SIZE) {
571-
T candidates =
572-
(i < num_experts_per_group) &&
573-
cuda::std::isfinite(scores_with_bias[offset + i])
574-
? scores_with_bias[offset + i]
575-
: neg_inf<T>();
580+
T candidates = (i < num_experts_per_group) &&
581+
is_finite(scores_with_bias[offset + i])
582+
? scores_with_bias[offset + i]
583+
: neg_inf<T>();
576584
queue.add(candidates, offset + i);
577585
}
578586
if (group_scores[i_group] == topk_group_value) {

0 commit comments

Comments
 (0)