From 9a58e8c168cfa11a3253f0fbe2df8e05b195cbb5 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Sat, 23 Aug 2025 22:28:01 +0800 Subject: [PATCH 01/11] feat: Overlap down gemm and combine send --- csrc/deep_ep.cpp | 21 +++-- csrc/deep_ep.hpp | 2 + csrc/kernels/api.cuh | 7 +- csrc/kernels/internode_ll.cu | 157 ++++++++++++++++++++++++++++------- deep_ep/buffer.py | 65 +++++++++++++++ tests/test_low_latency.py | 49 +++++++---- tests/utils.py | 4 + 7 files changed, 248 insertions(+), 57 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 24eb8169..c1d29e2c 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1139,7 +1139,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i // Allocate packed tensors auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)); - auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); @@ -1167,7 +1167,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, - packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), + packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), packed_recv_count.data_ptr(), cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr() : nullptr, dispatch_wait_recv_cost_stats.has_value() ? 
dispatch_wait_recv_cost_stats->data_ptr() : nullptr, @@ -1209,12 +1209,15 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i std::tuple, std::optional>> Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, + bool overlap, const std::optional& packed_recv_count, + const std::optional& comp_signal, int block_m, int threshold, int num_sms, const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, const std::optional& out) { #ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); + EP_HOST_ASSERT((!overlap || return_recv_hook) and "Overlap mode requires return_recv_hook=True"); // Tensor checks EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); @@ -1228,11 +1231,17 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); - EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); + EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt64 and x.size(0) == src_info.size(0)); EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); + if (comp_signal.has_value()) { + EP_HOST_ASSERT(comp_signal->dim() == 1 and comp_signal->is_contiguous()); + EP_HOST_ASSERT(comp_signal->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(comp_signal->size(0) == num_experts / num_ranks * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, 64)); + } + if (combine_wait_recv_cost_stats.has_value()) { EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64); EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous()); @@ -1275,13 +1284,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), - src_info.data_ptr(), layout_range.data_ptr(), + src_info.data_ptr(), layout_range.data_ptr(), + overlap, packed_recv_count.has_value() ? packed_recv_count->data_ptr() : nullptr, + comp_signal.has_value() ? comp_signal->data_ptr() : nullptr, block_m, threshold, combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr() : nullptr, next_clean_meta.first, next_clean_meta.second, num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, num_topk, num_experts, rank, num_ranks, use_logfmt, - workspace, num_device_sms, + workspace, num_device_sms, num_sms, launch_stream, phases, zero_copy); }; launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index aa62ccb0..a27757d2 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -154,6 +154,8 @@ struct Buffer { std::tuple, std::optional>> low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, const torch::Tensor& src_info, const torch::Tensor& layout_range, + bool overlap, const std::optional& packed_recv_count, + const std::optional& comp_signal, int block_m, int threshold, int num_sms, const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index d34775fd..401eb6f6 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -140,7 +140,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, cudaStream_t stream); void dispatch(void* packed_recv_x, void* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, @@ -156,13 +156,14 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, + const int64_t* src_info, const int64_t* layout_range, + bool overlap, int* packed_recv_count, int* comp_signal, int block_m, int threshold, int64_t* combine_wait_recv_cost_stats, int* next_clean, int num_next_clean_int, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, - void* workspace, int num_device_sms, + void* workspace, int num_device_sms, int num_sms, cudaStream_t stream, int phases, bool zero_copy); } // namespace internode_ll diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 4d5ee07e..7f3a9611 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -39,7 +39,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, template __global__ __launch_bounds__(1024, 1) void dispatch(void* packed_recv_x, void* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, @@ -299,7 +299,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, // Copy source info const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); if (lane_id == 0) - recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + recv_src_info[recv_token_begin_idx + i] = pack2(ld_nc_global(src_src_idx), src_rank); __syncwarp(); // Copy data @@ -335,7 +335,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales, } void dispatch(void* packed_recv_x, void* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int64_t* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count, int* cumulative_local_expert_recv_stats, int64_t* dispatch_wait_recv_cost_stats, @@ -556,7 +556,8 @@ __global__ 
__launch_bounds__(1024, 1) void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, + const int64_t* src_info, const int64_t* layout_range, + bool overlap, int* packed_recv_count, int* comp_signal, int block_m, int threshold, int64_t* combine_wait_recv_cost_stats, int* next_clean, int num_next_clean_int, int* atomic_clean_flag, @@ -567,6 +568,7 @@ combine(void* combined_x, int phases, bool zero_copy) { const auto sm_id = __shfl_sync(0xffffffff, static_cast(blockIdx.x), 0); const auto num_sms = __shfl_sync(0xffffffff, static_cast(gridDim.x), 0); + const auto num_warps = num_warp_groups * num_warps_per_group; const auto thread_id = static_cast(threadIdx.x); const auto num_threads = __shfl_sync(0xffffffff, static_cast(blockDim.x), 0); const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id(); @@ -596,6 +598,9 @@ combine(void* combined_x, constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162); constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes; EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); + + // Parameters for IBGDA sends outer loop, declared upfront to bypass goto initialization restrictions. + int initial_idx, loop_bound, step_size; // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) @@ -613,10 +618,41 @@ combine(void* combined_x, atomic_add_release_global(atomic_clean_flag, num_experts); } - // Issue IBGDA sends - if (responsible_expert_idx < num_experts) { - const auto dst_rank = responsible_expert_idx / num_local_experts; - const auto local_expert_idx = responsible_expert_idx % num_local_experts; + // Shared between warps in sms for overlap mode, where each sm only has one warp group + __shared__ int shared_vaild_signal_prefix_sum[288]; + __shared__ int shared_vaild_signal_sum, shared_local_expert_idx; + + // Compute prefix sums of valid signal counts per local expert + if (overlap) { + if (sub_warp_id == 0 and lane_id == 0) { + shared_vaild_signal_prefix_sum[0] = ceil_div(packed_recv_count[0], block_m); + shared_local_expert_idx = 0; + #pragma unroll + for (int i = 1; i < num_local_experts; i++) { + shared_vaild_signal_prefix_sum[i] = shared_vaild_signal_prefix_sum[i-1] + ceil_div(packed_recv_count[i], block_m); + } + shared_vaild_signal_sum = shared_vaild_signal_prefix_sum[num_local_experts-1]; + } + __syncthreads(); + } + + // Issue IBGDA sends, non-overlap mode only loops once + initial_idx = overlap ? sm_id : responsible_expert_idx; + loop_bound = overlap ? shared_vaild_signal_sum : num_experts; + step_size = overlap ? num_sms : num_experts; + for (int vaild_signal_idx = initial_idx; vaild_signal_idx < loop_bound; vaild_signal_idx += step_size) { + + // Find the owning local_expert_idx by scanning the prefix-sum array + if (overlap) { + if (sub_warp_id == 0 and lane_id == 0) { + while (vaild_signal_idx >= shared_vaild_signal_prefix_sum[shared_local_expert_idx]) + shared_local_expert_idx++; + } + __syncthreads(); + } + + auto dst_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = overlap ? 
shared_local_expert_idx : responsible_expert_idx % num_local_experts; const auto global_expert_idx = rank * num_local_experts + local_expert_idx; const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); const auto local_x = static_cast(x) + @@ -629,6 +665,22 @@ combine(void* combined_x, int offset, num_tokens_to_send; unpack2(layout, num_tokens_to_send, offset); + // Wait the corresponding comp_signal to reach the threshold + int num_tokens_per_expert, num_signal_per_expert, local_expert_signal_idx; + const int* gemm_comp_signal; + if (overlap) { + num_tokens_per_expert = packed_recv_count[local_expert_idx]; + num_signal_per_expert = ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m); + local_expert_signal_idx = (local_expert_idx == 0) ? vaild_signal_idx : + vaild_signal_idx - shared_vaild_signal_prefix_sum[local_expert_idx-1]; + gemm_comp_signal = comp_signal + num_signal_per_expert * local_expert_idx + local_expert_signal_idx; + + if (sub_warp_id == 0 and lane_id == 0 and packed_recv_count[local_expert_idx] != 0) { + while (ld_acquire_global(gemm_comp_signal) != threshold); + } + __syncthreads(); + } + // TMA stuffs constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls; constexpr int kNumStages = 3; @@ -660,13 +712,17 @@ combine(void* combined_x, }; // Issue IBGDA send - for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) { + auto token_start_idx = overlap ? local_expert_signal_idx * block_m : offset; + auto token_end_idx = overlap ? min((local_expert_signal_idx + 1) * block_m, num_tokens_per_expert) : (offset + num_tokens_to_send); + for (int token_idx = sub_warp_id + token_start_idx; token_idx < token_end_idx; token_idx += num_warps_per_group) { const auto x_int4 = local_x + token_idx * hidden_bf16_int4; const auto rdma_send_type_row = reinterpret_cast(rdma_send_x_vec + token_idx * num_bytes_per_slot); const auto rdma_send_x_vec_row = reinterpret_cast(rdma_send_type_row); // Copy directly to local rank, or copy to buffer and issue RDMA - const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0); + if (overlap) + dst_rank = __shfl_sync(0xffffffff, static_cast(__ldg(local_src_info + token_idx) >> 32), 0); + const auto src_idx = __shfl_sync(0xffffffff, static_cast(__ldg(local_src_info + token_idx) & 0xffffffff), 0); const auto buf_ptr = reinterpret_cast(rdma_send_x_vec_row); const auto dst_ptr = reinterpret_cast(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); @@ -732,24 +788,26 @@ combine(void* combined_x, // Issue RDMA // NOTES: for zero-copy mode, we assume the data is already in the send buffer if (dst_p2p_ptr == 0) - nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset); + nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx); } - - // Put the finishing flag - EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16); asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32)); - if (sub_warp_id == 1 and lane_id == 0) { - while (ld_acquire_global(atomic_clean_flag) == 0); - auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); - auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); - if (dst_p2p_ptr == 0) { - 
nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); - } else { - st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); + + // Put the finishing flag for non-overlap mode + if (overlap == false) { + EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16); + if (sub_warp_id == 1 and lane_id == 0) { + while (ld_acquire_global(atomic_clean_flag) == 0); + auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); + } + atomic_add_release_global(atomic_clean_flag, -1); } - atomic_add_release_global(atomic_clean_flag, -1); + __syncwarp(); } - __syncwarp(); // Destroy m-barriers if (lane_id < kNumStages) { @@ -760,6 +818,27 @@ combine(void* combined_x, __syncwarp(); } + // Put the finishing flag for overlap mode + if (overlap) { + cg::this_grid().sync(); + if (sm_id == 0) { + for (int local_expert_idx = warp_id; local_expert_idx < num_local_experts; local_expert_idx += num_warps) { + auto global_expert_idx = rank * num_local_experts + local_expert_idx; + for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { + while (ld_acquire_global(atomic_clean_flag) == 0); + auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); + } + atomic_add_release_global(atomic_clean_flag, -1); + } + } + } + } + // Receiving phase LOW_LATENCY_COMBINE_RECV: if ((phases & LOW_LATENCY_RECV_PHASE) == 0) @@ -920,22 +999,35 @@ combine(void* combined_x, void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, + const int64_t* src_info, const int64_t* layout_range, + bool overlap, int* packed_recv_count, int* comp_signal, int block_m, int threshold, int64_t* combine_wait_recv_cost_stats, int* next_clean, int num_next_clean_int, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, - void* workspace, int num_device_sms, + void* workspace, int num_device_sms, int num_sms, cudaStream_t stream, int phases, bool zero_copy) { constexpr int kNumMaxTopk = 9; - const int num_warp_groups = ceil_div(num_experts, num_device_sms); - const int num_warps_per_group = 32 / num_warp_groups; - const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); - EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm > 0); + int num_warp_groups, num_warps_per_group, num_recv_per_sm, num_warps; - const auto num_warps = num_warp_groups * num_warps_per_group; - const auto num_sms = max(ceil_div(num_experts, num_warp_groups), ceil_div(num_combined_tokens, num_recv_per_sm)); + if (overlap == true and phases == LOW_LATENCY_SEND_PHASE) { + num_warp_groups = 1; + num_warps_per_group = 32; + num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); + EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm > 0); + + num_warps = num_warp_groups * 
num_warps_per_group; + } + else { + num_warp_groups = ceil_div(num_experts, num_device_sms); + num_warps_per_group = 32 / num_warp_groups; + num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); + EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and num_recv_per_sm > 0); + + num_warps = num_warp_groups * num_warps_per_group; + num_sms = max(ceil_div(num_experts, num_warp_groups), ceil_div(num_combined_tokens, num_recv_per_sm)); + } // Check workspace auto atomic_clean_flag = static_cast(workspace); @@ -970,6 +1062,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \ combined_x, \ rdma_recv_x, rdma_recv_flag, rdma_send_x, \ x, topk_idx, topk_weights, src_info, layout_range, \ + overlap, packed_recv_count, comp_signal, block_m, threshold, \ combine_wait_recv_cost_stats, \ next_clean, num_next_clean_int, \ atomic_clean_flag, \ diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index eab8055c..7ce97e8d 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -633,7 +633,72 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig hook: the receiving hook function (valid only if `return_recv_hook` is set). """ src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle + overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms = False, None, None, 64, 0, 3 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, + overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms, + combine_wait_recv_cost_stats, + num_max_dispatch_tokens_per_rank, num_experts, + use_logfmt, zero_copy, async_finish, return_recv_hook, + out) + tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) + return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook + +# noinspection PyTypeChecker + def ll_overlap_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple, + overlap: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None, + block_m: int = 64, threshold: int = 0, num_sms: int = 3, + use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, + return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, + combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ + Tuple[torch.Tensor, EventOverlap, Callable]: + """ + A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. + It overlaps the down gemm computation with combine send phase, coordinated via a signaling mechanism. + This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA + (specifically, IBGDA must be enabled). + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. + + Arguments: + x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, + the local calculated tokens to be sent to this original rank and reduced. + topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched + tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals + to the number of dispatched tokens. + topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched + tokens. 
The received tokens will be reduced with the weights in this tensor. + handle: the communication handle given by the `dispatch` function. + packed_recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each + expert receives. + comp_signal: `[num_local_experts * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m)]` with `torch.int32`, + each element indicates the processing progress of `block_m` tokens in DeepGEMM. + Note that, the fixed-length tensor is used to support cuda graph, + only the first `ceil_div(num_tokens * num_ranks, block_m)` elements within its corresponding segment are valid. + block_m: set by DeepGEMM. + threshold: set by DeepGEMM. When a valid element in comp_signal reaches this threshold, it means that all the tokens + corresponding to this element have been computed by DeepGEMM and can be sent. + overlap: whether to overlap the down gemm with the combine send phase. + num_sms: the number of sms used by low_latency_combine send, only needs to be set when overlap is `True`. + use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits). + zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative + with `get_next_low_latency_combine_buffer`. + async_finish: the current stream will not wait for the communication kernels to be finished if set. + return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, + but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. + out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. + combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and precisely localizing slow anomalies. + + Returns: + combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. + event: the event after executing the kernel (valid only if `async_finish` is set). + hook: the receiving hook function (valid only if `return_recv_hook` is set). 
+ """ + src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle + combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, + overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms, combine_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index aa928aab..c0b599ec 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -9,7 +9,7 @@ from typing import Optional import deep_ep -from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back +from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back, ceil_div def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, @@ -83,12 +83,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, if current_x is x: recv_x = recv_x[:num_valid_tokens] recv_x_amin = recv_x[:, :-128].amin(dim=-1) - recv_src_info = recv_src_info[:num_valid_tokens] + src_token_idx = recv_src_info[:num_valid_tokens] & int_mask assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) if round_scale: - assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 + assert calc_diff(recv_x[:, -1], src_token_idx.view(-1)) < 0.007 else: - assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 + assert (recv_x[:, -128:] - src_token_idx.view(-1, 1) % num_tokens).sum().item() == 0 for j in range(num_ranks): begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() if not round_scale: @@ -102,19 +102,34 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, # Check combine correctness for zero_copy in (False, ) if use_logfmt else (False, True): - if zero_copy: - buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x - out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - use_logfmt=use_logfmt, - async_finish=not return_recv_hook, zero_copy=zero_copy, - return_recv_hook=return_recv_hook, out=out) - hook() if return_recv_hook else event.current_stream_wait() - if do_check: - diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) - assert torch.isnan(combined_x).sum().item() == 0 - assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' - hash_value ^= hash_tensor(combined_x) + for overlap in (False, True) if return_recv_hook else (False, ): + if zero_copy: + buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x + out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + if overlap: + block_m, threshold, num_sms = 64, 10, 3 + total_num_per_expert = ceil_div(num_tokens * num_ranks, block_m) + comp_signal = torch.zeros(num_local_experts * total_num_per_expert, dtype=torch.int32, device='cuda') + for i in range(num_local_experts): + vaild_num = ceil_div(packed_recv_count[i], block_m) + comp_signal[i * total_num_per_expert : i * total_num_per_expert + vaild_num] = threshold + combined_x, event, hook = buffer.ll_overlap_combine(simulated_gemm_x, topk_idx, topk_weights, handle, + overlap = overlap, packed_recv_count = packed_recv_count, + 
comp_signal = comp_signal, block_m = block_m, threshold = threshold, num_sms = num_sms, + use_logfmt=use_logfmt, + async_finish=not return_recv_hook, zero_copy=zero_copy, + return_recv_hook=return_recv_hook, out=out) + else: + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, + use_logfmt=use_logfmt, + async_finish=not return_recv_hook, zero_copy=zero_copy, + return_recv_hook=return_recv_hook, out=out) + hook() if return_recv_hook else event.current_stream_wait() + if do_check: + diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) + assert torch.isnan(combined_x).sum().item() == 0 + assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' + hash_value ^= hash_tensor(combined_x) # noinspection PyShadowingNames def large_gemm_with_hook(hook): diff --git a/tests/utils.py b/tests/utils.py index a64cc0ae..ad3bafc3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -223,3 +223,7 @@ def bench_kineto(fn, kernel_names: Union[str, tuple], num_tests: int = 30, suppr def hash_tensor(t: torch.Tensor): return t.view(torch.int).sum().item() + + +def ceil_div(a, b): + return (a + b - 1) // b From 42ed2712538f047ba6d0664d6e1c48c20543a594 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Mon, 25 Aug 2025 15:41:02 +0800 Subject: [PATCH 02/11] optimize put finishing flag for overlap mode --- csrc/kernels/internode_ll.cu | 65 ++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 951a6bbc..3099ba95 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -560,7 +560,7 @@ combine(void* combined_x, bool overlap, int* packed_recv_count, int* comp_signal, int block_m, int threshold, int64_t* combine_wait_recv_cost_stats, int* next_clean, int num_next_clean_int, - int* atomic_clean_flag, + int* atomic_clean_flag, int* atomic_finish_counter_per_expert, int num_combined_tokens, int hidden, int num_topk, int num_max_dispatch_tokens_per_rank, int num_experts, int rank, int num_ranks, @@ -568,7 +568,6 @@ combine(void* combined_x, int phases, bool zero_copy) { const auto sm_id = __shfl_sync(0xffffffff, static_cast(blockIdx.x), 0); const auto num_sms = __shfl_sync(0xffffffff, static_cast(gridDim.x), 0); - const auto num_warps = num_warp_groups * num_warps_per_group; const auto thread_id = static_cast(threadIdx.x); const auto num_threads = __shfl_sync(0xffffffff, static_cast(blockDim.x), 0); const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id(); @@ -625,11 +624,12 @@ combine(void* combined_x, // Compute prefix sums of valid signal counts per local expert if (overlap) { if (sub_warp_id == 0 and lane_id == 0) { - shared_vaild_signal_prefix_sum[0] = ceil_div(packed_recv_count[0], block_m); + shared_vaild_signal_prefix_sum[0] = (packed_recv_count[0] == 0 ? 1 : ceil_div(packed_recv_count[0], block_m)); shared_local_expert_idx = 0; #pragma unroll for (int i = 1; i < num_local_experts; i++) { - shared_vaild_signal_prefix_sum[i] = shared_vaild_signal_prefix_sum[i-1] + ceil_div(packed_recv_count[i], block_m); + shared_vaild_signal_prefix_sum[i] = shared_vaild_signal_prefix_sum[i-1] + + (packed_recv_count[i] == 0 ? 
1 : ceil_div(packed_recv_count[i], block_m)); } shared_vaild_signal_sum = shared_vaild_signal_prefix_sum[num_local_experts-1]; } @@ -792,8 +792,35 @@ combine(void* combined_x, } asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32)); - // Put the finishing flag for non-overlap mode - if (overlap == false) { + if (overlap) { + // Put the finishing flag for overlap mode + if (warp_id == 0 and lane_id == 0) + atomicAdd(atomic_finish_counter_per_expert + shared_local_expert_idx, 1); + __syncthreads(); + + if ((local_expert_signal_idx + 1) * block_m >= num_tokens_per_expert) { + if (warp_id == 0) { + auto global_expert_idx = rank * num_local_experts + shared_local_expert_idx; + for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { + if (packed_recv_count[shared_local_expert_idx] != 0) + while (ld_acquire_global(atomic_finish_counter_per_expert + shared_local_expert_idx) != ceil_div(num_tokens_per_expert, block_m)); + auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, shared_local_expert_idx); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); + } + atomic_add_release_global(atomic_clean_flag, -1); + } + if (lane_id == 0) + atomic_finish_counter_per_expert[shared_local_expert_idx] = 0; + } + } + __syncthreads(); + } + else { + // Put the finishing flag for non-overlap mode EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16); if (sub_warp_id == 1 and lane_id == 0) { while (ld_acquire_global(atomic_clean_flag) == 0); @@ -818,27 +845,6 @@ combine(void* combined_x, __syncwarp(); } - // Put the finishing flag for overlap mode - if (overlap) { - cg::this_grid().sync(); - if (sm_id == 0) { - for (int local_expert_idx = warp_id; local_expert_idx < num_local_experts; local_expert_idx += num_warps) { - auto global_expert_idx = rank * num_local_experts + local_expert_idx; - for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { - while (ld_acquire_global(atomic_clean_flag) == 0); - auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); - auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); - if (dst_p2p_ptr == 0) { - nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); - } else { - st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); - } - atomic_add_release_global(atomic_clean_flag, -1); - } - } - } - } - // Receiving phase LOW_LATENCY_COMBINE_RECV: if ((phases & LOW_LATENCY_RECV_PHASE) == 0) @@ -1031,7 +1037,8 @@ void combine(void* combined_x, // Check workspace auto atomic_clean_flag = static_cast(workspace); - EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); + auto atomic_finish_counter_per_expert = atomic_clean_flag + 1; + EP_HOST_ASSERT((1 + num_experts) * sizeof(int) <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_topk <= kNumMaxTopk); // Online cast cannot use zero-copy @@ -1065,7 +1072,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \ overlap, packed_recv_count, comp_signal, block_m, threshold, \ combine_wait_recv_cost_stats, \ next_clean, num_next_clean_int, \ - atomic_clean_flag, \ + atomic_clean_flag, atomic_finish_counter_per_expert, \ num_combined_tokens, hidden, num_topk, \ num_max_dispatch_tokens_per_rank, \ num_experts, rank, num_ranks, \ From 389e94b55a8b271f4a71581ecee290e168f3626c Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Tue, 26 
Aug 2025 16:09:58 +0800 Subject: [PATCH 03/11] optimize put finishing flag for overlap mode --- csrc/kernels/internode_ll.cu | 42 ++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 3099ba95..d1f2947d 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -675,7 +675,7 @@ combine(void* combined_x, vaild_signal_idx - shared_vaild_signal_prefix_sum[local_expert_idx-1]; gemm_comp_signal = comp_signal + num_signal_per_expert * local_expert_idx + local_expert_signal_idx; - if (sub_warp_id == 0 and lane_id == 0 and packed_recv_count[local_expert_idx] != 0) { + if (sub_warp_id == 0 and lane_id == 0 and num_tokens_per_expert != 0) { while (ld_acquire_global(gemm_comp_signal) != threshold); } __syncthreads(); @@ -794,28 +794,32 @@ combine(void* combined_x, if (overlap) { // Put the finishing flag for overlap mode - if (warp_id == 0 and lane_id == 0) - atomicAdd(atomic_finish_counter_per_expert + shared_local_expert_idx, 1); + bool put_finish_flag = false; + if (sub_warp_id == 0) { + if (lane_id == 0) { + const auto finish_counter = (num_tokens_per_expert == 0 ? 1 : ceil_div(num_tokens_per_expert, block_m)); + if ((atomicAdd(atomic_finish_counter_per_expert + local_expert_idx, 1) + 1) == finish_counter) + put_finish_flag = true; + } + put_finish_flag = __shfl_sync(0xffffffff, put_finish_flag, 0); + } __syncthreads(); - if ((local_expert_signal_idx + 1) * block_m >= num_tokens_per_expert) { - if (warp_id == 0) { - auto global_expert_idx = rank * num_local_experts + shared_local_expert_idx; - for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { - if (packed_recv_count[shared_local_expert_idx] != 0) - while (ld_acquire_global(atomic_finish_counter_per_expert + shared_local_expert_idx) != ceil_div(num_tokens_per_expert, block_m)); - auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); - auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); - if (dst_p2p_ptr == 0) { - nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, shared_local_expert_idx); - } else { - st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); - } - atomic_add_release_global(atomic_clean_flag, -1); + if (sub_warp_id == 0 and put_finish_flag) { + auto global_expert_idx = rank * num_local_experts + local_expert_idx; + for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { + while (ld_acquire_global(atomic_clean_flag) == 0); + auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); } - if (lane_id == 0) - atomic_finish_counter_per_expert[shared_local_expert_idx] = 0; + atomic_add_release_global(atomic_clean_flag, -1); } + if (lane_id == 0) + atomic_finish_counter_per_expert[local_expert_idx] = 0; } __syncthreads(); } From 988bad893e320cad2ddad0221b4e36b85ffd1d5e Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Thu, 28 Aug 2025 11:32:45 +0800 Subject: [PATCH 04/11] maintain: delete redundant code --- csrc/kernels/internode_ll.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index d1f2947d..7c23a211 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -806,7 
+806,6 @@ combine(void* combined_x, __syncthreads(); if (sub_warp_id == 0 and put_finish_flag) { - auto global_expert_idx = rank * num_local_experts + local_expert_idx; for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { while (ld_acquire_global(atomic_clean_flag) == 0); auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); From 11c5c6646a80c043b7541f988792f0b78b95dbf4 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Thu, 28 Aug 2025 14:14:32 +0800 Subject: [PATCH 05/11] maintain: add lambda send_finish_flag --- csrc/kernels/internode_ll.cu | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 7c23a211..83fdce26 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -792,6 +792,18 @@ combine(void* combined_x, } asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32)); + auto send_finish_flag = [&](int dst_rank) { + while (ld_acquire_global(atomic_clean_flag) == 0); + auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); + } + atomic_add_release_global(atomic_clean_flag, -1); + }; + if (overlap) { // Put the finishing flag for overlap mode bool put_finish_flag = false; @@ -807,15 +819,7 @@ combine(void* combined_x, if (sub_warp_id == 0 and put_finish_flag) { for (int dst_rank = lane_id; dst_rank < num_ranks; dst_rank += 32) { - while (ld_acquire_global(atomic_clean_flag) == 0); - auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); - auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); - if (dst_p2p_ptr == 0) { - nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); - } else { - st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); - } - atomic_add_release_global(atomic_clean_flag, -1); + send_finish_flag(dst_rank); } if (lane_id == 0) atomic_finish_counter_per_expert[local_expert_idx] = 0; @@ -826,15 +830,7 @@ combine(void* combined_x, // Put the finishing flag for non-overlap mode EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16); if (sub_warp_id == 1 and lane_id == 0) { - while (ld_acquire_global(atomic_clean_flag) == 0); - auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx); - auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); - if (dst_p2p_ptr == 0) { - nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); - } else { - st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 1); - } - atomic_add_release_global(atomic_clean_flag, -1); + send_finish_flag(dst_rank); } __syncwarp(); } From b6f500c1ad4db3d6478d8c8acbb009c6f4fb4ab7 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Tue, 2 Sep 2025 15:52:37 +0800 Subject: [PATCH 06/11] maintain: add kNumMaxExperts --- csrc/kernels/internode_ll.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 3fe23fc2..30afa3ee 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -348,6 +348,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* workspace, int num_device_sms, 
cudaStream_t stream, int phases) { constexpr int kNumMaxTopK = 9; + constexpr int kNumMaxExperts = 288; const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warps_per_group = 32 / num_warp_groups; EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); @@ -356,6 +357,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales, const auto num_warps = num_warp_groups * num_warps_per_group; const auto num_sms = ceil_div(num_experts, num_warp_groups); EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + EP_HOST_ASSERT(num_experts <= kNumMaxExperts); // Workspace checks auto atomic_counter_per_expert = static_cast(workspace); @@ -551,7 +553,7 @@ __forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float } } -template +template __global__ __launch_bounds__(1024, 1) void combine(void* combined_x, void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, @@ -618,7 +620,7 @@ combine(void* combined_x, } // Shared between warps in sms for overlap mode, where each sm only has one warp group - __shared__ int shared_vaild_signal_prefix_sum[288]; + __shared__ int shared_vaild_signal_prefix_sum[kNumMaxExperts]; __shared__ int shared_vaild_signal_sum, shared_local_expert_idx; // Compute prefix sums of valid signal counts per local expert @@ -1014,6 +1016,7 @@ void combine(void* combined_x, void* workspace, int num_device_sms, int num_sms, cudaStream_t stream, int phases, bool zero_copy) { constexpr int kNumMaxTopk = 9; + constexpr int kNumMaxExperts = 288; int num_warp_groups, num_warps_per_group, num_recv_per_sm, num_warps; if (overlap == true and phases == LOW_LATENCY_SEND_PHASE) { @@ -1040,6 +1043,7 @@ void combine(void* combined_x, auto atomic_finish_counter_per_expert = atomic_clean_flag + 1; EP_HOST_ASSERT((1 + num_experts) * sizeof(int) <= NUM_WORKSPACE_BYTES); EP_HOST_ASSERT(num_topk <= kNumMaxTopk); + EP_HOST_ASSERT(num_experts <= kNumMaxExperts); // Online cast cannot use zero-copy EP_HOST_ASSERT(not (zero_copy and use_logfmt)); From f065dd4c1bda18468170659bf00af35aa626ac54 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Tue, 2 Sep 2025 15:54:35 +0800 Subject: [PATCH 07/11] maintain: add kNumMaxExperts --- csrc/kernels/internode_ll.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 30afa3ee..765bb7c4 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -1066,8 +1066,8 @@ void combine(void* combined_x, #define COMBINE_LAUNCH_CASE(hidden) { \ auto combine_func = use_logfmt ? 
\ - combine : \ - combine; \ + combine : \ + combine; \ SET_SHARED_MEMORY_FOR_TMA(combine_func); \ LAUNCH_KERNEL(&cfg, combine_func, \ combined_x, \ From 4694c00ea1c765d853ae40bec6a16ea033b65dd2 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Fri, 5 Sep 2025 17:37:43 +0800 Subject: [PATCH 08/11] maintain: delete ll_overlap_combine --- deep_ep/buffer.py | 62 ++++----------------------------------- tests/test_low_latency.py | 12 ++++---- 2 files changed, 12 insertions(+), 62 deletions(-) diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 67633c51..a649928f 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -594,8 +594,10 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, EventOverlap(event, tensors_to_record if async_finish else None), hook # noinspection PyTypeChecker - def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, + def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple, + overlap: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None, + block_m: int = 64, threshold: int = 0, num_sms: int = 3, + use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ Tuple[torch.Tensor, EventOverlap, Callable]: @@ -615,59 +617,7 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched tokens. The received tokens will be reduced with the weights in this tensor. handle: the communication handle given by the `dispatch` function. - use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits). - zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative - with `get_next_low_latency_combine_buffer`. - async_finish: the current stream will not wait for the communication kernels to be finished if set. - return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, - but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. - If you do not set this flag, the kernel will ensure the data's arrival. - out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. - combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, - which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. - This is useful for detecting and pre-cisely localizing slow anomalies. - - Returns: - combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. - event: the event after executing the kernel (valid only if `async_finish` is set). - hook: the receiving hook function (valid only if `return_recv_hook` is set). 
- """ - src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle - overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms = False, None, None, 64, 0, 3 - combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, - overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms, - combine_wait_recv_cost_stats, - num_max_dispatch_tokens_per_rank, num_experts, - use_logfmt, zero_copy, async_finish, return_recv_hook, - out) - tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) - return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook - -# noinspection PyTypeChecker - def ll_overlap_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple, - overlap: bool = False, packed_recv_count: torch.Tensor = None, comp_signal: torch.Tensor = None, - block_m: int = 64, threshold: int = 0, num_sms: int = 3, - use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, - return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, - combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ - Tuple[torch.Tensor, EventOverlap, Callable]: - """ - A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. - It overlaps the down gemm computation with combine send phase, coordinated via a signaling mechanism. - This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA - (specifically, IBGDA must be enabled). - Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 - low-latency kernels' result tensors at a single moment. - - Arguments: - x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, - the local calculated tokens to be sent to this original rank and reduced. - topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched - tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals - to the number of dispatched tokens. - topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched - tokens. The received tokens will be reduced with the weights in this tensor. - handle: the communication handle given by the `dispatch` function. + overlap: whether to overlap the down gemm with the combine send phase. packed_recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each expert receive. comp_signal: `[num_local_experts * ceil_div(num_tokens * num_max_dispatch_tokens_per_rank, block_m)]` with `torch.int32`, @@ -677,7 +627,6 @@ def ll_overlap_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weigh block_m: set by DeepGEMM. threshold: set by DeepGEMM. When a valid element in comp_signal reaches this threshold, it means that all the tokens corresponding to this element have been computed by DeepGEMM and can be sent. - overlap: whether to overlap the down gemm with the combine send phase. num_sms: the number of sms used by low_latency_combine send, only needs to be set when overlap is `True`. use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits). 
zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative @@ -697,6 +646,7 @@ def ll_overlap_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weigh hook: the receiving hook function (valid only if `return_recv_hook` is set). """ src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle + overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms = False, None, None, 64, 0, 3 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms, combine_wait_recv_cost_stats, diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index c0b599ec..b6db5e92 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -113,12 +113,12 @@ def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, for i in range(num_local_experts): vaild_num = ceil_div(packed_recv_count[i], block_m) comp_signal[i * total_num_per_expert : i * total_num_per_expert + vaild_num] = threshold - combined_x, event, hook = buffer.ll_overlap_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - overlap = overlap, packed_recv_count = packed_recv_count, - comp_signal = comp_signal, block_m = block_m, threshold = threshold, num_sms = num_sms, - use_logfmt=use_logfmt, - async_finish=not return_recv_hook, zero_copy=zero_copy, - return_recv_hook=return_recv_hook, out=out) + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, + overlap = overlap, packed_recv_count = packed_recv_count, + comp_signal = comp_signal, block_m = block_m, threshold = threshold, num_sms = num_sms, + use_logfmt=use_logfmt, + async_finish=not return_recv_hook, zero_copy=zero_copy, + return_recv_hook=return_recv_hook, out=out) else: combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, use_logfmt=use_logfmt, From 0282e2e030bd01b7fd14df00107f76538af2c97e Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Fri, 5 Sep 2025 17:50:31 +0800 Subject: [PATCH 09/11] maintain: optimize code --- csrc/kernels/internode_ll.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 765bb7c4..0f4ba644 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -722,8 +722,7 @@ combine(void* combined_x, const auto rdma_send_x_vec_row = reinterpret_cast(rdma_send_type_row); // Copy directly to local rank, or copy to buffer and issue RDMA - if (overlap) - dst_rank = __shfl_sync(0xffffffff, static_cast(__ldg(local_src_info + token_idx) >> 32), 0); + overlap ? 
(dst_rank = __shfl_sync(0xffffffff, static_cast(__ldg(local_src_info + token_idx) >> 32), 0)) : 0; const auto src_idx = __shfl_sync(0xffffffff, static_cast(__ldg(local_src_info + token_idx) & 0xffffffff), 0); const auto buf_ptr = reinterpret_cast(rdma_send_x_vec_row); const auto dst_ptr = reinterpret_cast(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; From e64636422ab2b97a2a3576ef82978c401d0dde21 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Fri, 5 Sep 2025 17:59:41 +0800 Subject: [PATCH 10/11] bugfix Co-authored-by: Sulfur6 Co-authored-by: wangfakang <1031379296@qq.com> Co-authored-by: alpha-baby Co-authored-by: AniZpZ --- deep_ep/buffer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index a649928f..c3838ac2 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -646,7 +646,6 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig hook: the receiving hook function (valid only if `return_recv_hook` is set). """ src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle - overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms = False, None, None, 64, 0, 3 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, overlap, packed_recv_count, comp_signal, block_m, threshold, num_sms, combine_wait_recv_cost_stats, From 93e3d53d76877cc936e6e5a943390f34d11a3862 Mon Sep 17 00:00:00 2001 From: Zqy11 <841971412@qq.com> Date: Thu, 11 Sep 2025 10:48:25 +0800 Subject: [PATCH 11/11] maintain: modify comments --- deep_ep/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index de47a211..c39837e2 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -625,7 +625,7 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig comp_signal: `[num_local_experts * ceil_div(num_ranks * num_max_dispatch_tokens_per_rank, block_m)]` with `torch.int32`, each element indicates the processing progress of `block_m` tokens in DeepGEMM. Note that, the fixed-length tensor is used to support cuda graph, - only the first `ceil_div(num_tokens * num_ranks, block_m)` elements within its corresponding segment are valid. + only the first `ceil_div(num_tokens * num_ranks, block_m)` elements of each local_expert are valid. block_m: set by DeepGEMM. threshold: set by DeepGEMM. When a valid element in comp_signal reaches this threshold, it means that all the tokens corresponding to this element have been computed by DeepGEMM and can be sent.
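
For reference, below is a minimal Python sketch of how the overlapped combine added by this series might be driven, mirroring the overlap branch of tests/test_low_latency.py above. It is a sketch only: the comp_signal is pre-filled to the threshold synthetically (as the test does) instead of being incremented by a real DeepGEMM kernel, and `buffer`, `handle`, `packed_recv_count`, `simulated_gemm_x`, `topk_idx`, `topk_weights`, `num_tokens`, `num_ranks` and `num_local_experts` are assumed to come from a prior `low_latency_dispatch` call and the surrounding test setup.

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

block_m, threshold, num_sms = 64, 10, 3
signals_per_expert = ceil_div(num_tokens * num_ranks, block_m)

# DeepGEMM is expected to raise one int32 counter per block_m-token tile up to `threshold`;
# here the valid prefix of every local expert's segment is pre-filled instead.
comp_signal = torch.zeros(num_local_experts * signals_per_expert, dtype=torch.int32, device='cuda')
for i in range(num_local_experts):
    valid = ceil_div(int(packed_recv_count[i].item()), block_m)
    comp_signal[i * signals_per_expert: i * signals_per_expert + valid] = threshold

# Overlap mode only issues the sends and requires return_recv_hook=True;
# the returned hook must be called to wait for the combined tokens to arrive.
combined_x, event, hook = buffer.low_latency_combine(
    simulated_gemm_x, topk_idx, topk_weights, handle,
    overlap=True, packed_recv_count=packed_recv_count, comp_signal=comp_signal,
    block_m=block_m, threshold=threshold, num_sms=num_sms,
    return_recv_hook=True)
hook()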