diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp index 693ba9674f9895..4cfd5e6250d1fd 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp @@ -69,14 +69,15 @@ Buffer::Buffer(int rank, calc_ctx = reinterpret_cast( reinterpret_cast(pg) ->GetDeviceContext(place, true)); - // Task fifo memory - int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; - int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; - int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS; + + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 && - (num_nvl_bytes <= std::numeric_limits::max() || + (num_nvl_bytes <= std::numeric_limits::max() || num_rdma_bytes == 0)); EP_HOST_ASSERT( num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 && @@ -90,9 +91,8 @@ Buffer::Buffer(int rank, EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode); // Get ranks - // CUDA_CHECK(cudaGetDevice(&device_id)); rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS); num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); // Get device info @@ -100,30 +100,26 @@ Buffer::Buffer(int rank, CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); if (num_nvl_bytes > 0) { - // Local IPC: alloc local memory and set local IPC handle - CUDA_CHECK(cudaMalloc( - &buffer_ptrs[nvl_rank], - num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); + // Local IPC: alloc local memory and set local IPC handles + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], + num_nvl_bytes + barrier_signal_bytes + + buffer_ptr_bytes + barrier_signal_ptr_bytes)); CUDA_CHECK( cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); - buffer_ptrs_gpu = reinterpret_cast( - reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + - fifo_bytes); - - // Set task fifo - EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); - task_fifo_ptrs[nvl_rank] = reinterpret_cast( - reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - task_fifo_ptrs_gpu = reinterpret_cast( - reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + - fifo_bytes + buffer_ptr_bytes); + buffer_ptrs_gpu = + reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + + num_nvl_bytes + barrier_signal_bytes); + + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = reinterpret_cast( + static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = reinterpret_cast( + static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` CUDA_CHECK(cudaMemsetAsync( - buffer_ptrs[nvl_rank], - 0, - num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes, - comm_stream)); + barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace @@ -165,8 +161,7 @@ Buffer::~Buffer() noexcept(false) { if (num_nvl_bytes > 0) { // Barrier intranode::barrier( - task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream); - move_fifo_slots(); + barrier_signal_ptrs_gpu, nvl_rank, 
num_nvl_ranks, comm_stream); CUDA_CHECK(cudaDeviceSynchronize()); // Close remote IPC @@ -197,10 +192,6 @@ Buffer::~Buffer() noexcept(false) { CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); } -void Buffer::move_fifo_slots(int num_slots) { - head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; -} - bool Buffer::is_available() const { return available; } bool Buffer::is_internode_available() const { @@ -249,7 +240,7 @@ void Buffer::sync( // Sync IPC handles if (num_nvl_bytes > 0) { - EP_HOST_ASSERT(num_ranks == static_cast(device_ids.size())); + EP_HOST_ASSERT(num_ranks == device_ids.size()); EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) { @@ -261,8 +252,8 @@ void Buffer::sync( ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); CUDA_CHECK(cudaIpcOpenMemHandle( &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - task_fifo_ptrs[i] = reinterpret_cast( - reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes); + barrier_signal_ptrs[i] = reinterpret_cast( + static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), @@ -270,13 +261,13 @@ void Buffer::sync( } } - // Copy all buffer and task pointers to GPU + // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, - task_fifo_ptrs, + CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, + barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaDeviceSynchronize()); @@ -520,7 +511,7 @@ Buffer::intranode_dispatch( // FP8 scales checks float* x_scales_ptr = nullptr; - int num_scales = 0; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32); @@ -529,6 +520,8 @@ Buffer::intranode_dispatch( EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); x_scales_ptr = x_scales->data_ptr(); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set @@ -564,12 +557,10 @@ Buffer::intranode_dispatch( intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, - task_fifo_ptrs_gpu, - head, + barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); - move_fifo_slots(2); } else { rank_prefix_matrix = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_ranks, num_ranks}, @@ -604,12 +595,10 @@ Buffer::intranode_dispatch( num_memset_int, expert_alignment, buffer_ptrs_gpu, - task_fifo_ptrs_gpu, - head, + barrier_signal_ptrs_gpu, rank, comm_stream, num_channels); - move_fifo_slots(3); // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); @@ -719,10 +708,13 @@ Buffer::intranode_dispatch( is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), num_tokens, + 0, // num_worst_tokens (not exposed) static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, + scale_token_stride, + scale_hidden_stride, buffer_ptrs_gpu, rank, num_ranks, @@ -867,15 +859,11 @@ Buffer::intranode_combine( num_channels, num_recv_tokens, num_channels * num_ranks * 2, - task_fifo_ptrs_gpu, - head, + barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream); - // NOTES: this function uses two FIFO slots (barrier before and after) - move_fifo_slots(2); - // Combine data auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty( {num_recv_tokens, hidden}, x.dtype(), x.place())); @@ -895,6 +883,8 @@ Buffer::intranode_combine( recv_topk_weights_ptr, x.data_ptr(), topk_weights_ptr, + nullptr, // bias_ptrs[0] (not exposed) + nullptr, // bias_ptrs[1] (not exposed) src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), @@ -1084,7 +1074,7 @@ Buffer::internode_dispatch( // FP8 scales checks float* x_scales_ptr = nullptr; - int num_scales = 0; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32); @@ -1093,6 +1083,8 @@ Buffer::internode_dispatch( EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); x_scales_ptr = x_scales->data_ptr(); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set @@ -1144,15 +1136,13 @@ Buffer::internode_dispatch( config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, - head, + barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, true, low_latency_mode); - move_fifo_slots(2); } else { rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor( paddle::experimental::empty({num_rdma_ranks, num_channels}, @@ -1196,14 +1186,12 @@ Buffer::internode_dispatch( config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, - head, + barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, low_latency_mode); - move_fifo_slots(3); // Synchronize total received tokens and tokens per expert auto start_time = std::chrono::high_resolution_clock::now(); @@ -1320,12 +1308,14 @@ Buffer::internode_dispatch( recv_rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), + is_token_in_rank.data_ptr(), num_tokens, hidden_int4, num_scales, num_topk, num_experts, - is_token_in_rank.data_ptr(), + scale_token_stride, + scale_hidden_stride, rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, @@ -1523,15 +1513,13 @@ Buffer::internode_combine( config.num_max_rdma_chunked_recv_tokens, buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, - head, + barrier_signal_ptrs_gpu, rank, comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), num_nvl_bytes, false, low_latency_mode); - move_fifo_slots(2); // Launch data combine auto combined_x = @@ -1543,6 +1531,8 @@ Buffer::internode_combine( is_combined_token_in_rank.data_ptr(), x.data_ptr(), topk_weights_ptr, + nullptr, // bias_ptrs[0] (not exposed) + nullptr, // bias_ptrs[1] (not exposed) combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), src_meta.data_ptr(), diff --git a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp index 9733416c8611e2..ad82d08c16439d 100644 --- a/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp +++ b/paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp @@ -77,10 +77,9 @@ struct Buffer { // After IPC/NVSHMEM synchronization, this flag will be true bool available = false; - // Task fifo - int head = 0; - int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - int** task_fifo_ptrs_gpu = nullptr; + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; // Workspace void* workspace = nullptr; @@ -97,9 +96,6 @@ struct Buffer { volatile int* moe_recv_rdma_counter = nullptr; int* moe_recv_rdma_counter_mapped = nullptr; - private: - void move_fifo_slots(int num_slots = 1); - public: Buffer(int rank, int num_ranks, diff --git a/paddle/fluid/distributed/collective/deep_ep/include/types.h b/paddle/fluid/distributed/collective/deep_ep/include/types.h index a06d5ecec86656..7eae49ca723c45 100644 --- a/paddle/fluid/distributed/collective/deep_ep/include/types.h +++ b/paddle/fluid/distributed/collective/deep_ep/include/types.h 
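The hunks above replace the task-FIFO fields with per-rank barrier signals: each rank's single cudaMalloc now lays out [num_nvl_bytes of data | NUM_MAX_NVL_PEERS ints of barrier signals | buffer pointers | barrier-signal pointers], and after `sync`, `barrier_signal_ptrs[i]` is the IPC-mapped view of rank i's signal array. The device-side `intranode::barrier` / `barrier_block` that consumes these pointers is not shown in this section, so the following is only a minimal CUDA sketch of one way such a signal-based intra-node barrier could work; the name `barrier_block_sketch` is hypothetical, acquire/release ordering details are omitted, and a single participating block per rank is assumed.

__device__ __forceinline__ void barrier_block_sketch(int** barrier_signal_ptrs,
                                                     int rank,
                                                     int num_ranks) {
  const auto thread_id = static_cast<int>(threadIdx.x);
  // Make this block's prior writes visible before signalling the peers.
  __syncthreads();
  if (thread_id < num_ranks) {
    // Tell peer `thread_id` that `rank` has arrived at the barrier ...
    atomicAdd_system(barrier_signal_ptrs[thread_id] + rank, 1);
    // ... then wait for peer `thread_id`'s matching signal and consume it,
    // so the counters are ready for the next barrier round.
    while (atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, 0) == 0) {
    }
    atomicSub_system(barrier_signal_ptrs[rank] + thread_id, 1);
  }
  __syncthreads();
}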
@@ -73,6 +73,8 @@ struct Tensor { } int64_t element_size() const { return phi::SizeOf(raw_tensor_.dtype()); } + + int64_t stride(int64_t d) const { return raw_tensor_.strides().at(d); } }; } // namespace deep_ep::detail diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh index 48441020df7b5b..65b1f7ded134f0 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh @@ -26,8 +26,7 @@ namespace deep_ep { // Intranode runtime namespace intranode { -void barrier(int** task_fifo_ptrs, - int head, +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); @@ -83,8 +82,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int num_memset_int, int expert_alignment, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int num_sms); @@ -92,8 +90,7 @@ void notify_dispatch(const int* num_tokens_per_rank, void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); @@ -112,10 +109,13 @@ void dispatch(void* recv_x, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, + int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, + int scale_token_stride, + int scale_hidden_stride, void** buffer_ptrs, int rank, int num_ranks, @@ -129,8 +129,7 @@ void cached_notify_combine(void** buffer_ptrs, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); @@ -140,6 +139,8 @@ void combine(cudaDataType_t type, float* recv_topk_weights, const void* x, const float* topk_weights, + const void* bias_0, + const void* bias_1, const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, @@ -187,8 +188,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, @@ -212,12 +212,14 @@ void dispatch(void* recv_x, const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, - const bool* is_token_in_rank, + int scale_token_stride, + int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, @@ -246,8 +248,7 @@ void cached_notify(int hidden_int4, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, @@ -261,6 +262,8 @@ void combine(cudaDataType_t type, const bool* is_combined_token_in_rank, const void* x, const float* topk_weights, + const void* bias_0, + const void* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh index ecdee5cc217233..4d2036b55e53d4 100644 --- 
a/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh @@ -21,15 +21,20 @@ #define NUM_MAX_NVL_PEERS 8 #define NUM_MAX_RDMA_PEERS 20 -#define NUM_MAX_FIFO_SLOTS 32768 #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) #define NUM_MAX_LOCAL_EXPERTS 1024 #define NUM_BUFFER_ALIGNMENT_BYTES 128 #define FINISHED_SUM_TAG 1024 +#define NUM_WAIT_NANOSECONDS 500 + +#ifndef ENABLE_FAST_DEBUG #define NUM_CPU_TIMEOUT_SECS 100 #define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s -#define NUM_WAIT_NANOSECONDS 500 +#else +#define NUM_CPU_TIMEOUT_SECS 10 +#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s +#endif #define LOW_LATENCY_SEND_PHASE 1 #define LOW_LATENCY_RECV_PHASE 2 @@ -38,11 +43,6 @@ #ifdef __CLION_IDE__ #define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) #define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) -__host__ __device__ __forceinline__ void host_device_printf(const char* format, - ...) { - asm volatile("trap;"); -} -#define printf host_device_printf #endif #ifdef __CUDA_NO_HALF_CONVERSIONS__ diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh index 88d66b93c0fe12..d135695db6a1d3 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh @@ -99,7 +99,9 @@ __device__ static __forceinline__ nvshmemi_ibgda_device_qp_t *ibgda_get_rc( int pe, int id) { auto state = ibgda_get_state(); const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe; - return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe]; + return &state->globalmem + .rcs[pe * num_rc_per_pe * state->num_devices_initialized + + id % (num_rc_per_pe * state->num_devices_initialized)]; } __device__ static __forceinline__ void ibgda_lock_acquire(int *lock) { @@ -244,22 +246,27 @@ ibgda_get_lkey_and_rkey(uint64_t laddr, uint64_t raddr, int dst_pe, uint64_t *out_raddr, - __be32 *out_rkey) { + __be32 *out_rkey, + uint32_t dev_idx) { auto state = ibgda_get_state(); auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); auto log2_cumem_granularity = state->log2_cumem_granularity; // Local key - uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity; + uint64_t idx = ((laddr - heap_start) >> log2_cumem_granularity) * + state->num_devices_initialized + + dev_idx; auto device_key = state->constmem.lkeys[idx]; auto lchunk_size = device_key.next_addr - laddr; *lkey = device_key.key; // Remote key uint64_t roffset = raddr - heap_start; - idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + - dst_pe; + + idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) * + state->num_devices_initialized + + dst_pe * state->num_devices_initialized + dev_idx; if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { device_key = state->constmem.rkeys[idx]; } else { @@ -278,15 +285,17 @@ ibgda_get_lkey_and_rkey(uint64_t laddr, __device__ static __forceinline__ void ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, - __be32 *out_rkey) { + __be32 *out_rkey, + uint32_t dev_idx) { auto state = ibgda_get_state(); auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); uint64_t roffset = addr - heap_start; - uint64_t idx = ((roffset >> state->log2_cumem_granularity) * - nvshmemi_device_state_d.npes) + - dst_pe; + uint64_t idx = + ((roffset >> 
state->log2_cumem_granularity) * + nvshmemi_device_state_d.npes * state->num_devices_initialized) + + dst_pe * state->num_devices_initialized + dev_idx; nvshmemi_ibgda_device_key_t device_key; if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) device_key = state->constmem.rkeys[idx]; @@ -324,10 +333,11 @@ __device__ static __forceinline__ void nvshmemi_ibgda_rma_p( // NOTES: the `p` operation will not cross multiple remote chunks __be32 rkey; uint64_t raddr; - ibgda_get_rkey(reinterpret_cast(rptr), dst_pe, &raddr, &rkey); + auto qp = ibgda_get_rc(dst_pe, qp_id); + ibgda_get_rkey( + reinterpret_cast(rptr), dst_pe, &raddr, &rkey, qp->dev_idx); // Write WQEs - auto qp = ibgda_get_rc(dst_pe, qp_id); uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); void *wqe_ptrs; wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx); @@ -426,17 +436,21 @@ __device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp( uint64_t my_raddr = 0; uint64_t my_chunk_size = 0; + auto qp = ibgda_get_rc(dst_pe, qp_id); + // Decide how many messages (theoretically 3 for maximum) auto remaining_bytes = bytes; while (remaining_bytes > 0) { - if (lane_id == num_wqes) + if (lane_id == num_wqes) { my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, - &my_rkey)); + &my_rkey, + qp->dev_idx)); + } // Move one more message auto chunk_size = @@ -449,7 +463,6 @@ __device__ static __forceinline__ void nvshmemi_ibgda_put_nbi_warp( EP_DEVICE_ASSERT(num_wqes <= 32); // Process WQE - auto qp = ibgda_get_rc(dst_pe, qp_id); uint64_t base_wqe_idx = 0; if (lane_id == 0) base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); @@ -539,15 +552,14 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add( int qp_id, bool is_local_copy = false) { if (is_local_copy) { - // Fallback to NVSHMEM legacy API - nvshmemx_signal_op( - static_cast(rptr), value, NVSHMEM_SIGNAL_ADD, pe); + atomicAdd(static_cast(rptr), value); } else { nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); __be32 rkey; uint64_t raddr; - ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey); + ibgda_get_rkey( + reinterpret_cast(rptr), pe, &raddr, &rkey, qp->dev_idx); uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); @@ -565,4 +577,56 @@ __device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add( } } +__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t &ptr, + const int &rank, + const int &dst_rank) { + // Local rank, no need for mapping + if (rank == dst_rank) return ptr; + auto peer_base = __ldg( + reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_p2p) + + dst_rank); + + // RDMA connected + if (peer_base == 0) return 0; + + // NVLink P2P is enabled + return peer_base + + (ptr - reinterpret_cast(nvshmemi_device_state_d.heap_base)); +} + +// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. +// Note that this implementation does not guarantee thread safety, +// so we must ensure that no other threads are concurrently using the same QP. +__device__ static __forceinline__ void ibgda_poll_cq( + nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) { + const auto cqe64 = static_cast(cq->cqe); + const uint32_t ncqes = cq->ncqes; + memory_fence_cta(); + + // NOTES: this while loop is part of do-while below. + // `wqe_counter` is the HW consumer index. However, we always maintain `index + // + 1`. 
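// (Worked example of the `- 2` trick described below, assuming idx == 10 and
//  ncqes far below 2^16: once wqe_counter reaches 9, (uint16_t)(10 - 9 - 2)
//  wraps to 0xFFFF >= ncqes and polling stops; at wqe_counter == 8,
//  (uint16_t)(10 - 8 - 2) == 0 < ncqes and polling continues.)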
To be able to compare with the index, we need to use `wqe_counter + + // 1`. Because `wqe_counter` is `uint16_t`, it may be overflow. Still, we know + // for sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1 is less + // than idx, and thus we need to wait. We don't need to wait when `idx == + // wqe_counter + 1` That's why we use `- 2` here to make this case overflow. + uint16_t wqe_counter; + do { + wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)); + } while ((static_cast(static_cast(idx) - wqe_counter - + static_cast(2)) < ncqes)); + *cq->cons_idx = idx; + + // Prevent reordering of this function and later instructions + memory_fence_cta(); +} + +// Wait until wqe `idx - 1` is completed. +__device__ static __forceinline__ void nvshmemi_ibgda_quiet(int dst_pe, + int qp_id) { + auto qp = ibgda_get_rc(dst_pe, qp_id); + uint64_t prod_idx = ld_na_relaxed(qp->tx_wq.prod_idx); + ibgda_poll_cq(qp->tx_wq.cq, prod_idx); +} + } // namespace deep_ep diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu index afdd0009833009..a6c4ce7cd41a82 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu @@ -46,7 +46,6 @@ struct SourceMeta { __forceinline__ SourceMeta() = default; - // TODO(Xreki): faster encoding __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) { src_rdma_rank = rdma_rank; @@ -66,7 +65,7 @@ EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, int get_source_meta_bytes() { return sizeof(SourceMeta); } -__host__ __device__ __forceinline__ int get_num_bytes_per_rdma_token( +__host__ __device__ __forceinline__ int get_num_bytes_per_token( int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) { return static_cast( align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + @@ -82,13 +81,13 @@ __host__ __device__ __forceinline__ std::pair get_rdma_clean_meta( int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, - int num_sms) { + int num_channels) { // Return `int32_t` offset and count to clean - return {(get_num_bytes_per_rdma_token( + return {(get_num_bytes_per_token( hidden_int4, num_scales, num_topk_idx, num_topk_weights) * - num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / + num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_channels) / sizeof(int), - (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms}; + (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_channels}; } __host__ __device__ __forceinline__ std::pair get_nvl_clean_meta( @@ -99,18 +98,19 @@ __host__ __device__ __forceinline__ std::pair get_nvl_clean_meta( int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, - int num_sms) { + int num_channels, + bool is_dispatch) { // Return `int32_t` offset and to clean EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + return { (num_nvl_recv_buffer_tokens * - (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + - num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + - sizeof(SourceMeta)) * - num_nvl_ranks * num_sms) / + get_num_bytes_per_token( + hidden_int4, num_scales, num_topk_idx, num_topk_weights) * + num_nvl_ranks * num_channels) / sizeof(int), - num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms, + num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels, }; } @@ -122,9 +122,9 @@ __forceinline__ __device__ int 
translate_dst_rdma_rank(const int dst_rdma_rank, } template -__forceinline__ __device__ void nvshmem_barrier_with_same_gpu_idx( +__forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx( const nvshmem_team_t& rdma_team) { - kLowLatencyMode ? void(nvshmem_barrier(rdma_team)) : nvshmem_barrier_all(); + kLowLatencyMode ? void(nvshmem_sync(rdma_team)) : nvshmem_sync_all(); } template @@ -150,8 +150,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int* recv_gbl_rank_prefix_sum, void* rdma_buffer_ptr, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, const nvshmem_team_t rdma_team) { auto sm_id = static_cast(blockIdx.x); @@ -166,18 +165,16 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, if (sm_id == 0) { // Communication with others - // Global barrier: the first warp do intra-node sync, the second warp do + // Global barrier: the first warp does intra-node sync, the second warp does // internode sync EP_DEVICE_ASSERT(num_warps > 1); EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); if (thread_id == 32) - nvshmem_barrier_with_same_gpu_idx(rdma_team); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); + nvshmem_sync_with_same_gpu_idx(rdma_team); + barrier_block(barrier_signal_ptrs, nvl_rank); // Send numbers of tokens per rank/expert to RDMA ranks - auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, @@ -208,18 +205,39 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, __syncthreads(); // Issue send - // TODO(Xreki): more light fence or barrier or signaling - // TODO(Xreki): overlap EP barrier and NVL cleaning - if (thread_id < kNumRDMARanks) { - nvshmem_int_put_nbi( - rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), - rdma_recv_num_tokens_mixed.send_buffer(thread_id), - NUM_MAX_NVL_PEERS + num_rdma_experts + 1, - translate_dst_rdma_rank(thread_id, nvl_rank)); + for (int i = warp_id; i < kNumRDMARanks; i += num_warps) { + if (i != rdma_rank) { + nvshmemi_ibgda_put_nbi_warp( + reinterpret_cast( + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)), + reinterpret_cast( + rdma_recv_num_tokens_mixed.send_buffer(i)), + (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int), + translate_dst_rdma_rank(i, nvl_rank), + 0, + lane_id, + 0); + } else { + UNROLLED_WARP_COPY(1, + lane_id, + NUM_MAX_NVL_PEERS + num_rdma_experts + 1, + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(i), + ld_volatile_global, + st_na_global); + } } __syncthreads(); + + // Wait previous operations to be finished + if (thread_id < kNumRDMARanks && thread_id != rdma_rank) + nvshmemi_ibgda_quiet( + translate_dst_rdma_rank(thread_id, nvl_rank), 0); + __syncthreads(); + + // Barrier if (thread_id == 0) - nvshmem_barrier_with_same_gpu_idx(rdma_team); + nvshmem_sync_with_same_gpu_idx(rdma_team); __syncthreads(); // NVL buffers @@ -239,7 +257,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); // Clean up for later data dispatch - auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + 
nvl_send_num_tokens_per_expert.total_bytes <= @@ -249,7 +267,6 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; // Reduce number of tokens per expert into the NVL send buffer - // TODO(Xreki): may use NVSHMEM reduction EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); if (thread_id < num_rdma_experts) { int sum = 0; @@ -287,13 +304,9 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; } - memory_fence(); - __syncthreads(); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, nvl_rank); - // Reduce number of tokens per rank/expert + // Reduce the number of tokens per rank/expert EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); if (thread_id == 0) { int sum = 0; @@ -321,11 +334,9 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, } // Finally barrier - __syncthreads(); if (thread_id == 32) - nvshmem_barrier_with_same_gpu_idx(rdma_team); - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); + nvshmem_sync_with_same_gpu_idx(rdma_team); + barrier_block(barrier_signal_ptrs, nvl_rank); } else { // Calculate meta data int dst_rdma_rank = sm_id - 1; @@ -412,8 +423,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, @@ -448,8 +458,7 @@ void notify_dispatch(const int* num_tokens_per_rank, recv_gbl_rank_prefix_sum, \ rdma_buffer_ptr, \ buffer_ptrs, \ - task_fifo_ptrs, \ - head, \ + barrier_signal_ptrs, \ rank, \ cpu_rdma_team); \ } \ @@ -473,7 +482,8 @@ void notify_dispatch(const int* num_tokens_per_rank, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, - num_channels); + num_channels, + true); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); @@ -496,6 +506,7 @@ constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { template __global__ void __launch_bounds__( @@ -517,12 +528,14 @@ __global__ void __launch_bounds__( const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, - const bool* is_token_in_rank, + int scale_token_stride, + int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, @@ -539,18 +552,19 @@ __global__ void __launch_bounds__( kNVLReceivers }; + const auto num_sms = static_cast(gridDim.x); const auto sm_id = static_cast(blockIdx.x); const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); - const auto num_channels = static_cast(gridDim.x) / 2, - channel_id = sm_id / 2; + const auto num_channels = num_sms / 2, channel_id = sm_id / 2; const bool is_forwarder = sm_id % 2 == 0; const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; - EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_channels); + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels || + ibgda_get_state()->num_rc_per_pe >= 
num_sms); const auto role_meta = [=]() -> std::pair { if (is_forwarder) { @@ -582,14 +596,15 @@ __global__ void __launch_bounds__( EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); auto hidden_bytes = hidden_int4 * sizeof(int4); - auto num_bytes_per_rdma_token = - get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); - auto rdma_channel_data = SymBuffer( - rdma_buffer_ptr, - num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, - kNumRDMARanks, - channel_id, - num_channels); + auto scale_bytes = num_scales * sizeof(float); + auto num_bytes_per_token = + get_num_bytes_per_token(hidden_int4, num_scales, num_topk, num_topk); + auto rdma_channel_data = + SymBuffer(rdma_buffer_ptr, + num_max_rdma_chunked_recv_tokens * num_bytes_per_token, + kNumRDMARanks, + channel_id, + num_channels); auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, @@ -616,44 +631,12 @@ __global__ void __launch_bounds__( // Allocate buffers auto nvl_channel_x = - AsymBuffer(ws_rr_buffer_ptr, - num_max_nvl_chunked_recv_tokens * hidden_int4, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - rs_wr_rank) - .advance_also(rs_wr_buffer_ptr); - auto nvl_channel_src_meta = - AsymBuffer(ws_rr_buffer_ptr, - num_max_nvl_chunked_recv_tokens, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - rs_wr_rank) - .advance_also(rs_wr_buffer_ptr); - auto nvl_channel_x_scales = - AsymBuffer(ws_rr_buffer_ptr, - num_max_nvl_chunked_recv_tokens * num_scales, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - rs_wr_rank) - .advance_also(rs_wr_buffer_ptr); - auto nvl_channel_topk_idx = - AsymBuffer(ws_rr_buffer_ptr, - num_max_nvl_chunked_recv_tokens * num_topk, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - rs_wr_rank) - .advance_also(rs_wr_buffer_ptr); - auto nvl_channel_topk_weights = - AsymBuffer(ws_rr_buffer_ptr, - num_max_nvl_chunked_recv_tokens * num_topk, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - rs_wr_rank) + AsymBuffer(ws_rr_buffer_ptr, + num_max_nvl_chunked_recv_tokens * num_bytes_per_token, + NUM_MAX_NVL_PEERS, + channel_id, + num_channels, + rs_wr_rank) .advance_also(rs_wr_buffer_ptr); auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, @@ -685,14 +668,32 @@ __global__ void __launch_bounds__( .advance_also(rs_wr_buffer_ptr); // RDMA sender warp synchronization - __shared__ volatile int rdma_send_next_token_idx; - __shared__ volatile int rdma_send_channel_tail[kNumRDMARanks]; - __shared__ volatile int rdma_send_channel_next_tail[kNumRDMARanks]; + // NOTES: `rdma_send_channel_tail` means the latest released tail + // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status + __shared__ int rdma_send_channel_lock[kNumRDMARanks]; + __shared__ int rdma_send_channel_tail[kNumRDMARanks]; + __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; auto sync_rdma_sender_smem = []() { asm volatile( "bar.sync 0, %0;" ::"r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; + // TMA stuffs + extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + auto tma_buffer = smem_tma_buffer + target_rank * kNumTMABytesPerWarp; + auto tma_mbarrier = reinterpret_cast(tma_buffer + hidden_bytes); + uint32_t tma_phase = 0; + if ((warp_role == WarpRole::kRDMAAndNVLForwarder || + warp_role == WarpRole::kNVLReceivers) && + lane_id == 0) { + mbarrier_init(tma_mbarrier, 1); + fence_view_async_shared(); + fence_barrier_init(); + EP_DEVICE_ASSERT(num_bytes_per_token + 
sizeof(uint64_t) <= + kNumTMABytesPerWarp); + } + __syncwarp(); + // Forward warp synchronization __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS] [kNumRDMARanks]; @@ -707,18 +708,6 @@ __global__ void __launch_bounds__( get_channel_task_range( num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); - // Clean shared memory - EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); - (warp_id == 0 && lane_id == 0) - ? (rdma_send_next_token_idx = token_start_idx) - : 0; - (warp_id == 0 && lane_id < kNumRDMARanks) - ? (rdma_send_channel_tail[lane_id] = 0) - : 0; - (warp_id == 0 && lane_id < kNumRDMARanks) - ? (rdma_send_channel_next_tail[lane_id] = 0) - : 0; - // Send number of tokens in this channel by `-value - 1` EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); @@ -757,6 +746,7 @@ __global__ void __launch_bounds__( 1; } __syncwarp(); + // Issue RDMA for non-local ranks if (dst_rdma_rank != rdma_rank) { nvshmemi_ibgda_put_nbi_warp( @@ -775,32 +765,49 @@ __global__ void __launch_bounds__( // Iterate over tokens and copy into buffer int64_t token_idx; - int cached_rdma_channel_head = 0, last_rdma_tail_idx = -1; + int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0; auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); - for (token_idx = token_start_idx + warp_id; token_idx < token_end_idx; - token_idx += kNumDispatchRDMASenderWarps) { + for (token_idx = token_start_idx; token_idx < token_end_idx; ++token_idx) { // Read RDMA rank existence uint64_t is_token_in_rank_uint64 = 0; - if (lane_id < kNumRDMARanks) - is_token_in_rank_uint64 = *reinterpret_cast( + if (lane_id < kNumRDMARanks) { + is_token_in_rank_uint64 = __ldg(reinterpret_cast( is_token_in_rank + token_idx * num_ranks + - lane_id * NUM_MAX_NVL_PEERS); - - // Acquire sequential lock - while (lane_id == 0 && rdma_send_next_token_idx != token_idx) { + lane_id * NUM_MAX_NVL_PEERS)); + global_rdma_tail_idx += (is_token_in_rank_uint64 != 0); } __syncwarp(); - // Acquire next tail - int rdma_tail_idx = -1; - if (is_token_in_rank_uint64 != 0) { - rdma_tail_idx = rdma_send_channel_next_tail[lane_id]++; - while (rdma_tail_idx - cached_rdma_channel_head >= - num_max_rdma_chunked_recv_tokens) - cached_rdma_channel_head = static_cast( - ld_volatile_global(rdma_channel_head.buffer(lane_id))); + // Skip the token which does not belong to this warp + if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != + warp_id) + continue; + auto rdma_tail_idx = + is_token_in_rank_uint64 == 0 ? 
-1 : global_rdma_tail_idx - 1; + + // Wait the remote buffer to be released + auto start_time = clock64(); + while (is_token_in_rank_uint64 != 0 && + rdma_tail_idx - cached_rdma_channel_head >= + num_max_rdma_chunked_recv_tokens) { + cached_rdma_channel_head = static_cast( + ld_volatile_global(rdma_channel_head.buffer(lane_id))); + + // Timeout check + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf( + "DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, " + "nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + cached_rdma_channel_head, + rdma_tail_idx); + trap(); + } } __syncwarp(); @@ -808,15 +815,6 @@ __global__ void __launch_bounds__( if (lane_id < kNumRDMARanks && !kCachedMode) send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; - // Update last token tail - if (last_rdma_tail_idx >= 0) - st_release_cta(const_cast(rdma_send_channel_tail + lane_id), - last_rdma_tail_idx + 1); - last_rdma_tail_idx = rdma_tail_idx; - - // Release sequential lock - lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; - // Broadcast tails SourceMeta src_meta; int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; @@ -834,7 +832,7 @@ __global__ void __launch_bounds__( src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); dst_send_buffers[num_topk_ranks++] = reinterpret_cast(broadcast(send_buffer, i)) + - slot_idx * num_bytes_per_rdma_token; + slot_idx * num_bytes_per_token; } EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); @@ -857,19 +855,11 @@ __global__ void __launch_bounds__( dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; - // Copy source metadata into symmetric send buffer - if (lane_id < num_topk_ranks) - st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), - src_meta); -#pragma unroll - for (int i = 0; i < num_topk_ranks; ++i) - dst_send_buffers[i] = - reinterpret_cast(dst_send_buffers[i]) + 1; - // Copy `x_scales` into symmetric send buffer #pragma unroll for (int i = lane_id; i < num_scales; i += 32) { - auto value = ld_nc_global(x_scales + token_idx * num_scales + i); + auto offset = token_idx * scale_token_stride + i * scale_hidden_stride; + auto value = ld_nc_global(x_scales + offset); #pragma unroll for (int j = 0; j < num_topk_ranks; ++j) st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, @@ -880,6 +870,15 @@ __global__ void __launch_bounds__( dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; + // Copy source metadata into symmetric send buffer + if (lane_id < num_topk_ranks) + st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), + src_meta); +#pragma unroll + for (int i = 0; i < num_topk_ranks; ++i) + dst_send_buffers[i] = + reinterpret_cast(dst_send_buffers[i]) + 1; + // Copy `topk_idx` and `topk_weights` into symmetric send buffer #pragma unroll for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) { @@ -895,27 +894,49 @@ __global__ void __launch_bounds__( num_topk + copy_idx, weight_value); } - } + __syncwarp(); - // Epilogue - // Acquire sequential lock - while (lane_id == 0 && rdma_send_next_token_idx != token_idx) { - } - __syncwarp(); + // Release the transaction in the window + if (is_token_in_rank_uint64 != 0) { + // Acquire lock first + acquire_lock(rdma_send_channel_lock + lane_id); + auto latest_tail = rdma_send_channel_tail[lane_id]; + auto offset = rdma_tail_idx - latest_tail; + while (offset >= 32) { + release_lock(rdma_send_channel_lock + lane_id); + acquire_lock(rdma_send_channel_lock + lane_id); + 
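// While this lane holds the lock, no other warp can release earlier slots, so
// drop and re-take it (above) and then re-read the tail: once other warps have
// advanced `rdma_send_channel_tail`, `offset` falls below 32 and this slot
// fits in the 32-entry window.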
latest_tail = rdma_send_channel_tail[lane_id]; + offset = rdma_tail_idx - latest_tail; + } - // Update last token tail - if (last_rdma_tail_idx >= 0) - st_release_cta(const_cast(rdma_send_channel_tail + lane_id), - last_rdma_tail_idx + 1); + // Release the transaction slot + // Add the bit and move the ones if possible + auto window = rdma_send_channel_window[lane_id] | (1u << offset); + if (offset == 0) { + auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1; + st_release_cta(rdma_send_channel_tail + lane_id, + latest_tail + num_empty_slots); + window >>= num_empty_slots; + } + rdma_send_channel_window[lane_id] = window; - // Release sequential lock - lane_id == 0 ? (rdma_send_next_token_idx += 1) : 0; + // Release lock + release_lock(rdma_send_channel_lock + lane_id); + } + __syncwarp(); + } } else if (warp_role == WarpRole::kRDMASenderCoordinator) { - // NOTES: in case of splitting the issued put at the end of the buffer + // NOTES: in case of splitting, the issued put at the end of the buffer EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0; + // Synchronize shared memory sync_rdma_sender_smem(); @@ -931,20 +952,39 @@ __global__ void __launch_bounds__( // Iterate all RDMA ranks int last_issued_tail = 0; + auto start_time = clock64(); while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES && + lane_id < kNumRDMARanks) { + printf( + "DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl " + "%d, dst IB: %d, tail: %d, remaining: %d\n", + channel_id, + rdma_rank, + nvl_rank, + lane_id, + last_issued_tail, + num_tokens_to_send); + trap(); + } + for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++i) { // To mitigate incast congestion, shuffle the starting index of target - // rank for different ranks and channel + // rank for different ranks and channels int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank); if (synced_num_tokens_to_send == 0) continue; - // Read progress + // Read the latest progress + // NOTES: `rdma_send_channel_tail` does not need to be protected by lock + auto processed_tail = + __shfl_sync(0xffffffff, + ld_acquire_cta(rdma_send_channel_tail + dst_rdma_rank), + 0); auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); - auto processed_tail = ld_acquire_cta( - const_cast(rdma_send_channel_tail + dst_rdma_rank)); auto num_tokens_processed = processed_tail - synced_last_issued_tail; if (num_tokens_processed != synced_num_tokens_to_send && num_tokens_processed < num_max_rdma_chunked_send_tokens) @@ -961,13 +1001,13 @@ __global__ void __launch_bounds__( EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); const size_t num_bytes_per_msg = - num_bytes_per_rdma_token * num_tokens_to_issue; + num_bytes_per_token * num_tokens_to_issue; const auto dst_ptr = reinterpret_cast( rdma_channel_data.recv_buffer(rdma_rank) + - dst_slot_idx * num_bytes_per_rdma_token); + dst_slot_idx * num_bytes_per_token); const auto src_ptr = reinterpret_cast( 
rdma_channel_data.send_buffer(dst_rdma_rank) + - dst_slot_idx * num_bytes_per_rdma_token); + dst_slot_idx * num_bytes_per_token); nvshmemi_ibgda_put_nbi_warp( dst_ptr, src_ptr, @@ -980,9 +1020,9 @@ __global__ void __launch_bounds__( // Lighter fence for local RDMA rank memory_fence(); } + __syncwarp(); // Update tails - __syncwarp(); if (lane_id == dst_rdma_rank) { last_issued_tail += num_tokens_to_issue; num_tokens_to_send -= num_tokens_to_issue; @@ -993,15 +1033,12 @@ __global__ void __launch_bounds__( channel_id, dst_rdma_rank == rdma_rank); } + __syncwarp(); } } } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { // RDMA consumers and NVL producers const auto dst_nvl_rank = target_rank; - const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; - const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); - const auto dst_rank_expert_end = - dst_rank_expert_begin + (num_experts / num_ranks); // Wait counters to arrive int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; @@ -1079,15 +1116,17 @@ __global__ void __launch_bounds__( while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { // Check destination queue emptiness, or wait a buffer to be released start_time = clock64(); - while (lane_id == 0) { - int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; + while (true) { + const int num_used_slots = + cached_nvl_channel_tail - cached_nvl_channel_head; if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) break; - cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); + cached_nvl_channel_head = __shfl_sync( + 0xffffffffu, ld_volatile_global(nvl_channel_head.buffer()), 0); // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + if (lane_id == 0 && clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf( "DeepEP dispatch forwarder timeout (NVL check), channel: %d, " "RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", @@ -1100,7 +1139,6 @@ __global__ void __launch_bounds__( trap(); } } - __syncwarp(); // Find next source RDMA rank (round-robin) start_time = clock64(); @@ -1144,10 +1182,10 @@ __global__ void __launch_bounds__( // Iterate over every token from the RDMA buffer for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++i) { auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; - void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + - rdma_slot_idx * num_bytes_per_rdma_token; + auto shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + + rdma_slot_idx * num_bytes_per_token; auto src_meta = ld_nc_global(reinterpret_cast( - reinterpret_cast(shifted) + hidden_bytes)); + shifted + hidden_bytes + scale_bytes)); lane_id == src_rdma_rank ? 
(num_tokens_to_recv_from_rdma -= 1) : 0; bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); if (lane_id == src_rdma_rank) { @@ -1160,61 +1198,28 @@ __global__ void __launch_bounds__( // Get an empty slot int dst_slot_idx = (cached_nvl_channel_tail++) % num_max_nvl_chunked_recv_tokens; + auto dst_shifted = + nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token; // Copy data - UNROLLED_WARP_COPY(5, - lane_id, - hidden_int4, - nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, - reinterpret_cast(shifted), - ld_nc_global, - st_na_global); - shifted = reinterpret_cast(shifted) + hidden_int4; - - // Copy source meta - if (lane_id == 0) - st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); - shifted = reinterpret_cast(shifted) + 1; - - // Copy `x_scales` - UNROLLED_WARP_COPY( - 1, - lane_id, - num_scales, - nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, - reinterpret_cast(shifted), - ld_nc_global, - st_na_global); - shifted = reinterpret_cast(shifted) + num_scales; - - // Copy `topk_idx` and `topk_weights` - // NOTES: do not use `shifted` after this `if`, because only several - // lanes are shifted - if (lane_id < num_topk) { - // Read - auto idx_value = - ld_nc_global(reinterpret_cast(shifted) + lane_id); - shifted = reinterpret_cast(shifted) + num_topk; - auto weight_value = - ld_nc_global(reinterpret_cast(shifted) + lane_id); - - // Transform and write - idx_value = (idx_value >= dst_rank_expert_begin && - idx_value < dst_rank_expert_end) - ? idx_value - dst_rank_expert_begin - : -1; - st_na_global( - nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, - idx_value); - weight_value = idx_value >= 0 ? weight_value : 0.0f; - st_na_global(nvl_channel_topk_weights.buffer() + - dst_slot_idx * num_topk + lane_id, - weight_value); + if (lane_id == 0) { + tma_load_1d( + tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false); + mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token); } + __syncwarp(); + mbarrier_wait(tma_mbarrier, tma_phase); + if (lane_id == 0) + tma_store_1d(tma_buffer, dst_shifted, num_bytes_per_token); + __syncwarp(); // In case of insufficient NVL buffers, early stopping if ((++num_tokens_sent) == num_max_nvl_chunked_send_tokens) src_rdma_tail = i + 1; + + // Wait TMA to be finished + tma_store_wait(); + __syncwarp(); } // Sync head index @@ -1266,7 +1271,7 @@ __global__ void __launch_bounds__( rdma_channel_head.buffer(rdma_rank), min_head - last_head, translate_dst_rdma_rank(lane_id, nvl_rank), - channel_id, + channel_id + num_channels, lane_id == rdma_rank); last_head = min_head; } @@ -1279,6 +1284,9 @@ __global__ void __launch_bounds__( // Retrieve rank offset from barrier results (each lane's register stores an // RDMA rank) int src_nvl_rank = target_rank, total_offset = 0; + const int local_expert_begin = rank * (num_experts / num_ranks); + const int local_expert_end = local_expert_begin + (num_experts / num_ranks); + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); if (lane_id < kNumRDMARanks && lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) @@ -1328,14 +1336,14 @@ __global__ void __launch_bounds__( while (num_tokens_to_recv > 0) { // Check channel status by lane 0 start_time = clock64(); - while (lane_id == 0) { + while (true) { // Ready to copy if (cached_channel_head_idx != cached_channel_tail_idx) break; - cached_channel_tail_idx = - ld_acquire_sys_global(nvl_channel_tail.buffer()); + cached_channel_tail_idx = __shfl_sync( + 0xffffffff, 
ld_acquire_sys_global(nvl_channel_tail.buffer()), 0); // Timeout check - if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + if (lane_id == 0 && clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf( "DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, " "nvl: %d, src NVL: %d, head: %d, tail: %d\n", @@ -1349,61 +1357,86 @@ __global__ void __launch_bounds__( } } - // Sync queue tail - cached_channel_tail_idx = - __shfl_sync(0xffffffff, cached_channel_tail_idx, 0); - // Copy data int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++chunk_idx, --num_tokens_to_recv) { int token_idx_in_buffer = (cached_channel_head_idx++) % num_max_nvl_chunked_recv_tokens; - auto meta = - ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); + auto shifted = + nvl_channel_x.buffer() + token_idx_in_buffer * num_bytes_per_token; + auto meta = ld_nc_global(reinterpret_cast( + shifted + hidden_bytes + scale_bytes)); int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; + bool scale_aligned = (scale_bytes % 16 == 0); + auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0); + // Copy data - UNROLLED_WARP_COPY( - 5, - lane_id, - hidden_int4, - recv_x + recv_token_idx * hidden_int4, - nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, - ld_nc_global, - st_na_global); + if (lane_id == 0) { + tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes); + } + __syncwarp(); + mbarrier_wait(tma_mbarrier, tma_phase); + if (lane_id == 0) + tma_store_1d(tma_buffer, + recv_x + recv_token_idx * hidden_int4, + hidden_bytes, + false); + __syncwarp(); + shifted += hidden_bytes; + + // Copy scales + if (scale_aligned) { + tma_store_1d(tma_buffer + hidden_bytes, + recv_x_scales + recv_token_idx * num_scales, + scale_bytes, + false); + } else { + UNROLLED_WARP_COPY(1, + lane_id, + num_scales, + recv_x_scales + recv_token_idx * num_scales, + reinterpret_cast(shifted), + ld_nc_global, + st_na_global); + } + shifted += scale_bytes; // Copy source meta if (lane_id == 0 && !kCachedMode) st_na_global(recv_src_meta + recv_token_idx, meta); - - // Copy scales - UNROLLED_WARP_COPY( - 1, - lane_id, - num_scales, - recv_x_scales + recv_token_idx * num_scales, - nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, - ld_nc_global, - st_na_global); + shifted += sizeof(SourceMeta); // Copy `topk_idx` and `topk_weights` if (lane_id < num_topk) { + // Read + auto idx_value = static_cast( + ld_nc_global(reinterpret_cast(shifted) + lane_id)); + auto weight_value = ld_nc_global( + reinterpret_cast(shifted + sizeof(int) * num_topk) + + lane_id); auto recv_idx = recv_token_idx * num_topk + lane_id; - auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; - st_na_global(recv_topk_idx + recv_idx, - static_cast(ld_nc_global( - nvl_channel_topk_idx.buffer() + buffer_idx))); - st_na_global( - recv_topk_weights + recv_idx, - ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); + + // Transform and write + idx_value = + (idx_value >= local_expert_begin && idx_value < local_expert_end) + ? idx_value - local_expert_begin + : -1; + weight_value = idx_value >= 0 ? 
weight_value : 0.0f; + st_na_global(recv_topk_idx + recv_idx, idx_value); + st_na_global(recv_topk_weights + recv_idx, weight_value); } + + // Wait TMA to be finished + tma_store_wait(); + __syncwarp(); } // Move queue - __syncwarp(); if (lane_id == 0) st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); @@ -1428,12 +1461,14 @@ void dispatch(void* recv_x, const int* recv_rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, - const bool* is_token_in_rank, + int scale_token_stride, + int scale_hidden_stride, void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, @@ -1447,6 +1482,12 @@ void dispatch(void* recv_x, int num_channels, bool low_latency_mode) { constexpr int kNumDispatchRDMASenderWarps = 7; + constexpr int kNumTMABytesPerWarp = 16384; + constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; + + // Make sure never OOB + EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < + std::numeric_limits::max()); #define DISPATCH_LAUNCH_CASE(num_rdma_ranks) \ { \ @@ -1455,19 +1496,24 @@ void dispatch(void* recv_x, ? (is_cached_dispatch ? dispatch \ : dispatch) \ : (is_cached_dispatch ? dispatch \ : dispatch); \ + SET_SHARED_MEMORY_FOR_TMA(dispatch_func); \ LAUNCH_KERNEL(&cfg, \ dispatch_func, \ reinterpret_cast(recv_x), \ @@ -1487,12 +1533,14 @@ void dispatch(void* recv_x, recv_rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, \ recv_gbl_rank_prefix_sum, \ + is_token_in_rank, \ num_tokens, \ hidden_int4, \ num_scales, \ num_topk, \ num_experts, \ - is_token_in_rank, \ + scale_token_stride, \ + scale_hidden_stride, \ rdma_buffer_ptr, \ num_max_rdma_chunked_send_tokens, \ num_max_rdma_chunked_recv_tokens, \ @@ -1528,8 +1576,7 @@ __global__ void cached_notify(const int rdma_clean_offset, int* combined_nvl_head, void* rdma_buffer_ptr, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, int num_ranks, bool is_cached_dispatch, @@ -1547,39 +1594,30 @@ __global__ void cached_notify(const int rdma_clean_offset, // Using two SMs, which clean the RDMA/NVL buffer respectively if (sm_id == 0) { // Barrier for RDMA - if (thread_id == 0) - nvshmem_barrier_with_same_gpu_idx(rdma_team); - __syncthreads(); + if (thread_id == 32) + nvshmem_sync_with_same_gpu_idx(rdma_team); - // Clean - auto rdma_buffer_ptr_int = reinterpret_cast(rdma_buffer_ptr); + // Barrier for NVL + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Clean RDMA buffer + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); #pragma unroll for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; - nvshmem_fence(); - __syncthreads(); - - // Barrier again - if (thread_id == 0) - nvshmem_barrier_with_same_gpu_idx(rdma_team); - } else if (sm_id == 1) { - // Barrier for NVL - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - __syncthreads(); - // Clean - auto nvl_buffer_ptr_int = reinterpret_cast(buffer_ptrs[nvl_rank]); + // Clean NVL buffer + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); #pragma unroll for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; - memory_fence(); __syncthreads(); // Barrier again - barrier_device(task_fifo_ptrs, head, nvl_rank); - move_fifo_slots(head); - } else if (sm_id == 2) { + 
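// Note: the old pattern (FIFO barrier, clean, fence, FIFO barrier again, with
// the RDMA and NVL cleanups on separate SMs) collapses onto a single SM here:
// thread 32 (second warp) performs the RDMA-team sync while `barrier_block`
// synchronizes the NVL peers, so the later `sm_id` branches shift down by one.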
if (thread_id == 32) + nvshmem_sync_with_same_gpu_idx(rdma_team); + barrier_block(barrier_signal_ptrs, nvl_rank); + } else if (sm_id == 1) { if (is_cached_dispatch) return; EP_DEVICE_ASSERT(num_warps >= num_channels); @@ -1617,8 +1655,8 @@ __global__ void cached_notify(const int rdma_clean_offset, EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); if (lane_id < NUM_MAX_NVL_PEERS && warp_id < num_channels) { - for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; - dst_rdma_rank += num_channels * 2 - 3) { + for (int dst_rdma_rank = sm_id - 2; dst_rdma_rank < num_rdma_ranks; + dst_rdma_rank += num_channels * 2 - 2) { // Iterate in reverse order int token_start_idx = warp_id == 0 @@ -1665,8 +1703,7 @@ void cached_notify(int hidden_int4, int num_max_rdma_chunked_recv_tokens, void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int64_t num_rdma_bytes, @@ -1691,7 +1728,8 @@ void cached_notify(int hidden_int4, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, - num_channels); + num_channels, + is_cached_dispatch); EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); @@ -1719,8 +1757,7 @@ void cached_notify(int hidden_int4, combined_nvl_head, rdma_buffer_ptr, buffer_ptrs, - task_fifo_ptrs, - head, + barrier_signal_ptrs, rank, num_ranks, is_cached_dispatch, @@ -1728,6 +1765,7 @@ void cached_notify(int hidden_int4, } template (&bias_0_value_int4); + auto bias_1_values = reinterpret_cast(&bias_1_value_int4); +#pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) + values[j] = static_cast(bias_0_values[j]) + + static_cast(bias_1_values[j]); + } + +// Reduce all-to-all results #pragma unroll for (int j = 0; j < num_topk_ranks; ++j) { auto recv_value_dtypes = @@ -1805,19 +1864,21 @@ template < int kNumRDMARanks, typename dtype_t, int kNumCombineForwarderWarps, + int kNumTMABytesPerWarp, int kNumTopkRDMARanks = get_num_topk_rdma_ranks(kNumRDMARanks), int kNumWarpsPerForwarder = (kNumCombineForwarderWarps / kNumRDMARanks > 0) ? 
kNumCombineForwarderWarps / kNumRDMARanks : 1, int kNumForwarders = kNumRDMARanks* kNumWarpsPerForwarder, - int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> -__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, - 1) + int kNumRDMAReceivers = kNumForwarders - NUM_MAX_NVL_PEERS> +__global__ void __launch_bounds__((kNumForwarders + 1) * 32, 1) combine(int4* combined_x, float* combined_topk_weights, const bool* is_combined_token_in_rank, const int4* x, const float* topk_weights, + const int4* bias_0, + const int4* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const SourceMeta* src_meta, @@ -1849,32 +1910,34 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; - const bool is_rdma_receiver_sm = sm_id % 2 == 1; + const bool is_forwarder_sm = sm_id % 2 == 1; EP_DEVICE_ASSERT(num_topk <= 32); EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); + const auto hidden_bytes = hidden_int4 * sizeof(int4); + const auto num_bytes_per_token = + get_num_bytes_per_token(hidden_int4, 0, 0, num_topk); // NOTES: we decouple a channel into 2 SMs const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; auto role_meta = [=]() -> std::pair { auto warp_id = thread_id / 32; - if (!is_rdma_receiver_sm) { + if (!is_forwarder_sm) { if (warp_id < NUM_MAX_NVL_PEERS) { auto shuffled_warp_id = warp_id; shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; return {WarpRole::kNVLSender, shuffled_warp_id}; - } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { - auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; - shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; - return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else if (warp_id < kNumForwarders) { + return {WarpRole::kRDMAReceiver, warp_id - NUM_MAX_NVL_PEERS}; } else { return {WarpRole::kCoordinator, 0}; } } else { - if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { - return {WarpRole::kRDMAReceiver, warp_id}; + if (warp_id < kNumForwarders) { + auto shuffled_warp_id = (warp_id + channel_id) % kNumForwarders; + return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; } else { return {WarpRole::kCoordinator, 0}; } @@ -1883,7 +1946,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, auto warp_role = role_meta.first; auto warp_id = role_meta.second; - EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); + EP_DEVICE_ASSERT(num_warps == kNumForwarders + 1); auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; @@ -1896,30 +1959,14 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, // sources auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; - auto nvl_channel_x = - AsymBuffer(dst_buffer_ptr, - num_max_nvl_chunked_recv_tokens * hidden_int4, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - nvl_rank) - .advance_also(local_buffer_ptr); - auto nvl_channel_src_meta = - AsymBuffer(dst_buffer_ptr, - num_max_nvl_chunked_recv_tokens, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - nvl_rank) - .advance_also(local_buffer_ptr); - auto nvl_channel_topk_weights = - AsymBuffer(dst_buffer_ptr, - 
num_max_nvl_chunked_recv_tokens * num_topk, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels, - nvl_rank) - .advance_also(local_buffer_ptr); + auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, + num_max_nvl_chunked_recv_tokens * + num_bytes_per_token, + NUM_MAX_NVL_PEERS, + channel_id, + num_channels, + nvl_rank) + .advance_also(local_buffer_ptr); auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, @@ -1935,6 +1982,19 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, nvl_rank) .advance_also(local_buffer_ptr); + // TMA stuffs + extern __shared__ __align__(1024) uint8_t smem_tma_buffer[]; + auto tma_buffer = smem_tma_buffer + dst_nvl_rank * kNumTMABytesPerWarp; + auto tma_mbarrier = reinterpret_cast(tma_buffer + hidden_bytes); + uint32_t tma_phase = 0; + if (lane_id == 0) { + mbarrier_init(tma_mbarrier, 1); + fence_view_async_shared(); + fence_barrier_init(); + EP_DEVICE_ASSERT(hidden_bytes + sizeof(uint64_t) <= kNumTMABytesPerWarp); + } + __syncwarp(); + // Get tasks for each RDMA lane int token_start_idx = 0, token_end_idx = 0; if (lane_id < kNumRDMARanks) { @@ -1954,11 +2014,12 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); // Iterate over all tokens and send by chunks + int current_rdma_idx = channel_id % kNumRDMARanks; while (true) { // Exit if possible if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) break; - // Decide next RDMA buffer to send + // Decide the next RDMA buffer to send bool is_lane_ready = false; auto start_time = clock64(); while (true) { @@ -1995,8 +2056,8 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, } // Sync token start index and count - for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; - ++current_rdma_idx) { + for (int i = 0; i < kNumRDMARanks; ++i) { + current_rdma_idx = (current_rdma_idx + 1) % kNumRDMARanks; if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) || (!is_lane_ready), current_rdma_idx)) @@ -2026,29 +2087,36 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); - // Copy data + // Load data auto shifted_x_buffers = - nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; + nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token; auto shifted_x = x + token_idx * hidden_int4; - UNROLLED_WARP_COPY(5, - lane_id, - hidden_int4, - shifted_x_buffers, - shifted_x, - ld_nc_global, - st_na_global); + if (lane_id == 0) { + tma_store_wait(); + tma_load_1d(tma_buffer, shifted_x, tma_mbarrier, hidden_bytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier, hidden_bytes); + } + __syncwarp(); + mbarrier_wait(tma_mbarrier, tma_phase); - // Copy source meta - if (lane_id == 0) - st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, - ld_nc_global(src_meta + token_idx)); + // Load source meta + if (lane_id == num_topk) + *reinterpret_cast(tma_buffer + hidden_bytes) = + ld_nc_global(src_meta + token_idx); - // Copy `topk_weights` + // Load `topk_weights` if (lane_id < num_topk) - st_na_global( - nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + - lane_id, - ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); + *reinterpret_cast(tma_buffer + hidden_bytes + + sizeof(SourceMeta) + + lane_id * sizeof(float)) = + ld_nc_global(topk_weights + token_idx * num_topk + lane_id); + + // Issue TMA store + 
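// At this point the warp has staged one full token in `tma_buffer`: the
// hidden states arrived via `tma_load_1d` (the warp waited on the mbarrier
// for all `hidden_bytes`), lane `num_topk` wrote the `SourceMeta`, and lanes
// below `num_topk` wrote their top-k weights. The resulting slot layout,
// matching the receiver-side offsets further down, is:
//   [0, hidden_bytes)                     hidden states (int4-packed)
//   [hidden_bytes, +sizeof(SourceMeta))   SourceMeta
//   [..., +num_topk * sizeof(float))      top-k weights
// `tma_store_fence()` (fence.proxy.async.shared::cta) makes these ordinary
// shared-memory writes visible to the async proxy; after the warp sync,
// lane 0 then issues a single `tma_store_1d` of `num_bytes_per_token` bytes
// into the destination slot of the peer's NVL channel buffer.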
tma_store_fence(); + __syncwarp(); + if (lane_id == 0) + tma_store_1d( + tma_buffer, shifted_x_buffers, num_bytes_per_token, false); } lane_id == current_rdma_idx ? (token_start_idx = static_cast(token_idx)) @@ -2056,6 +2124,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, } // Move queue tail + tma_store_wait(); __syncwarp(); if (lane_id < kNumRDMARanks && is_lane_ready) st_release_sys_global(nvl_channel_tail.buffer() + lane_id, @@ -2064,12 +2133,9 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, } else { // Combiners and coordinators // RDMA symmetric layout - auto hidden_bytes = hidden_int4 * sizeof(int4); - auto num_bytes_per_rdma_token = - get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); auto rdma_channel_data = SymBuffer( rdma_buffer_ptr, - num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, + num_max_rdma_chunked_recv_tokens * num_bytes_per_token, kNumRDMARanks, channel_id, num_channels); @@ -2083,27 +2149,13 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, void* nvl_buffers[NUM_MAX_NVL_PEERS]; #pragma unroll for (int i = 0; i < NUM_MAX_NVL_PEERS; ++i) nvl_buffers[i] = buffer_ptrs[i]; - auto nvl_channel_x = - AsymBuffer(local_nvl_buffer, - num_max_nvl_chunked_recv_tokens * hidden_int4, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels) - .advance_also(nvl_buffers); - auto nvl_channel_src_meta = - AsymBuffer(local_nvl_buffer, - num_max_nvl_chunked_recv_tokens, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels) - .advance_also(nvl_buffers); - auto nvl_channel_topk_weights = - AsymBuffer(local_nvl_buffer, - num_max_nvl_chunked_recv_tokens * num_topk, - NUM_MAX_NVL_PEERS, - channel_id, - num_channels) - .advance_also(nvl_buffers); + auto nvl_channel_x = AsymBuffer(local_nvl_buffer, + num_max_nvl_chunked_recv_tokens * + num_bytes_per_token, + NUM_MAX_NVL_PEERS, + channel_id, + num_channels) + .advance_also(nvl_buffers); auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, @@ -2155,11 +2207,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, // Advance to the corresponding NVL buffer nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * - hidden_int4); - nvl_channel_src_meta.advance(dst_rdma_rank * - num_max_nvl_chunked_recv_tokens_per_rdma); - nvl_channel_topk_weights.advance( - dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); + num_bytes_per_token); nvl_channel_head.advance(dst_rdma_rank); nvl_channel_tail.advance(dst_rdma_rank); @@ -2262,27 +2310,33 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, // Combine current token auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; - void* shifted = - send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; + void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_token; auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { - return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + - slot_idx * hidden_int4 + hidden_int4_idx); + return ld_nc_global( + reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + + slot_idx * num_bytes_per_token) + + hidden_int4_idx); }; auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { - return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + - slot_idx * num_topk + topk_idx); + return ld_nc_global( + reinterpret_cast(nvl_channel_x.buffer(src_nvl_rank) + + slot_idx * num_bytes_per_token 
+ + hidden_bytes + sizeof(SourceMeta)) + + topk_idx); }; - combine_token( + combine_token( expected_head >= 0, expected_head, lane_id, hidden_int4, num_topk, - reinterpret_cast(shifted), - reinterpret_cast(reinterpret_cast(shifted) + + static_cast(shifted), + reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + nullptr, + nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); @@ -2301,13 +2355,13 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; const size_t num_bytes_per_msg = - num_chunked_tokens * num_bytes_per_rdma_token; + num_chunked_tokens * num_bytes_per_token; const auto dst_ptr = reinterpret_cast( rdma_channel_data.recv_buffer(rdma_rank) + - rdma_slot_idx * num_bytes_per_rdma_token); + rdma_slot_idx * num_bytes_per_token); const auto src_ptr = reinterpret_cast( rdma_channel_data.send_buffer(dst_rdma_rank) + - rdma_slot_idx * num_bytes_per_rdma_token); + rdma_slot_idx * num_bytes_per_token); nvshmemi_ibgda_put_nbi_warp( dst_ptr, src_ptr, @@ -2323,7 +2377,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, // Write new RDMA tail __syncwarp(); - if (lane_id == 0) + if (lane_id == 0) { nvshmemi_ibgda_amo_nonfetch_add( rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, @@ -2331,6 +2385,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); + } } } @@ -2398,18 +2453,18 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast( rdma_channel_data.recv_buffer(src_rdma_rank) + - slot_idx * num_bytes_per_rdma_token) + + slot_idx * num_bytes_per_token) + hidden_int4_idx); }; auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast( rdma_channel_data.recv_buffer(src_rdma_rank) + - slot_idx * num_bytes_per_rdma_token + + slot_idx * num_bytes_per_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx); }; - combine_token( + combine_token( expected_head >= 0, expected_head, lane_id, @@ -2417,6 +2472,8 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, num_topk, combined_x + token_idx * hidden_int4, combined_topk_weights + token_idx * num_topk, + bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4, + bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4, num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); @@ -2428,7 +2485,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, } else { // Coordinator // Sync shared memory status - is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); + is_forwarder_sm ? 
sync_forwarder_smem() : sync_rdma_receiver_smem(); const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; int last_rdma_head = 0; @@ -2439,18 +2496,17 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, "Invalid number of forwarder warps"); while (true) { // Retired - if (is_rdma_receiver_sm && - __all_sync( - 0xffffffff, - lane_id >= kNumRDMAReceivers || rdma_receiver_retired[lane_id])) + if (!is_forwarder_sm && __all_sync(0xffffffff, + lane_id >= kNumRDMAReceivers || + rdma_receiver_retired[lane_id])) break; - if (!is_rdma_receiver_sm && + if (is_forwarder_sm && __all_sync(0xffffffff, lane_id >= kNumForwarders || forwarder_retired[lane_id])) break; // Find minimum head for RDMA ranks - if (is_rdma_receiver_sm) { + if (!is_forwarder_sm) { int min_head = std::numeric_limits::max(); #pragma unroll for (int i = 0; i < kNumRDMAReceivers; ++i) @@ -2465,7 +2521,7 @@ __global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, min_head - last_rdma_head, translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), - channel_id, + channel_id + num_channels, dst_rdma_rank == rdma_rank); last_rdma_head = min_head; } @@ -2501,6 +2557,8 @@ void combine(cudaDataType_t type, const bool* is_combined_token_in_rank, const void* x, const float* topk_weights, + const void* bias_0, + const void* bias_1, const int* combined_rdma_head, const int* combined_nvl_head, const void* src_meta, @@ -2523,50 +2581,57 @@ void combine(cudaDataType_t type, int num_channels, bool low_latency_mode) { constexpr int kNumCombineForwarderWarps = 16; + constexpr int kNumTMABytesPerWarp = 16384; + constexpr int smem_size = kNumTMABytesPerWarp * NUM_MAX_NVL_PEERS; -#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ - { \ - auto combine_func = low_latency_mode ? combine \ - : combine; \ - LAUNCH_KERNEL(&cfg, \ - combine_func, \ - reinterpret_cast(combined_x), \ - combined_topk_weights, \ - is_combined_token_in_rank, \ - reinterpret_cast(x), \ - topk_weights, \ - combined_rdma_head, \ - combined_nvl_head, \ - reinterpret_cast(src_meta), \ - rdma_channel_prefix_matrix, \ - rdma_rank_prefix_sum, \ - gbl_channel_prefix_matrix, \ - num_tokens, \ - num_combined_tokens, \ - hidden, \ - num_topk, \ - rdma_buffer_ptr, \ - num_max_rdma_chunked_send_tokens, \ - num_max_rdma_chunked_recv_tokens, \ - buffer_ptrs, \ - num_max_nvl_chunked_send_tokens, \ - num_max_nvl_chunked_recv_tokens, \ - rank, \ - num_ranks); \ - } \ +#define COMBINE_LAUNCH_CASE(num_rdma_ranks) \ + { \ + auto combine_func = low_latency_mode ? 
combine \ + : combine; \ + SET_SHARED_MEMORY_FOR_TMA(combine_func); \ + LAUNCH_KERNEL(&cfg, \ + combine_func, \ + reinterpret_cast(combined_x), \ + combined_topk_weights, \ + is_combined_token_in_rank, \ + reinterpret_cast(x), \ + topk_weights, \ + reinterpret_cast(bias_0), \ + reinterpret_cast(bias_1), \ + combined_rdma_head, \ + combined_nvl_head, \ + reinterpret_cast(src_meta), \ + rdma_channel_prefix_matrix, \ + rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, \ + num_tokens, \ + num_combined_tokens, \ + hidden, \ + num_topk, \ + rdma_buffer_ptr, \ + num_max_rdma_chunked_send_tokens, \ + num_max_rdma_chunked_recv_tokens, \ + buffer_ptrs, \ + num_max_nvl_chunked_send_tokens, \ + num_max_nvl_chunked_recv_tokens, \ + rank, \ + num_ranks); \ + } \ break int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; - EP_HOST_ASSERT(num_forwarder_warps > 0 && + EP_HOST_ASSERT(num_forwarder_warps > NUM_MAX_NVL_PEERS && num_forwarder_warps % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > @@ -2574,9 +2639,7 @@ void combine(cudaDataType_t type, num_max_nvl_chunked_send_tokens)); EP_HOST_ASSERT(type == CUDA_R_16BF); - SETUP_LAUNCH_CONFIG(num_channels * 2, - (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, - stream); + SETUP_LAUNCH_CONFIG(num_channels * 2, (num_forwarder_warps + 1) * 32, stream); SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); #undef COMBINE_LAUNCH_CASE } diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/intranode.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/intranode.cu index 10b8664fcd1fe2..e16016bbe26cc1 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/intranode.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/intranode.cu @@ -43,8 +43,7 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, int num_memset_int, int expert_alignment, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank) { auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x), @@ -54,13 +53,11 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, if (sm_id == 0) { // Barrier first - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); int *per_rank_buffer, *per_expert_buffer; if (thread_id < kNumRanks) { - per_rank_buffer = reinterpret_cast(buffer_ptrs[thread_id]); + per_rank_buffer = static_cast(buffer_ptrs[thread_id]); per_expert_buffer = per_rank_buffer + kNumRanks * kNumRanks; } @@ -79,16 +76,13 @@ __global__ void notify_dispatch(const int* num_tokens_per_rank, per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i]; } - __syncthreads(); // Wait for all ranks to be finished - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); // Sum per-rank counts and return to CPU // Also pre-compute the prefix sum for data sending - auto local_per_rank_buffer = reinterpret_cast(buffer_ptrs[rank]); + auto local_per_rank_buffer = static_cast(buffer_ptrs[rank]); if (thread_id < kNumRanks) { #pragma unroll for (int i = 1; i < kNumRanks; ++i) @@ -123,9 +117,7 @@ __global__ void notify_dispatch(const int* 
num_tokens_per_rank, local_per_expert_buffer[i] = 0; // Barrier - memory_fence(); - __syncthreads(); - barrier_device(task_fifo_ptrs, head, rank); + barrier_block(barrier_signal_ptrs, rank); } else { int dst_rank = sm_id - 1; for (int channel_id = warp_id; channel_id < num_channels; @@ -167,8 +159,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int num_memset_int, int expert_alignment, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, int num_channels) { @@ -188,8 +179,7 @@ void notify_dispatch(const int* num_tokens_per_rank, num_memset_int, \ expert_alignment, \ buffer_ptrs, \ - task_fifo_ptrs, \ - head, \ + barrier_signal_ptrs, \ rank); \ break @@ -207,36 +197,30 @@ template __global__ void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank) { // A simplified version for cached handles - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); // Copy and clean auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); - auto ptr = reinterpret_cast(buffer_ptrs[rank]); + auto ptr = static_cast(buffer_ptrs[rank]); #pragma unroll for (int i = thread_id; i < kNumRanks * kNumRanks; i += num_threads) ptr[i] = rank_prefix_matrix[i]; #pragma unroll for (int i = thread_id; i < num_memset_int; i += num_threads) ptr[kNumRanks * kNumRanks + i] = 0; - memory_fence(); - __syncthreads(); // Barrier after cleaning - barrier_device(task_fifo_ptrs, head, rank); + barrier_block(barrier_signal_ptrs, rank); } void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) { @@ -246,8 +230,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, rank_prefix_matrix, \ num_memset_int, \ buffer_ptrs, \ - task_fifo_ptrs, \ - head, \ + barrier_signal_ptrs, \ rank); \ break @@ -256,7 +239,7 @@ void cached_notify_dispatch(const int* rank_prefix_matrix, #undef CACHED_NOTIFY_DISPATCH_LAUNCH_CASE } -template +template __global__ void __launch_bounds__(kNumThreads, 1) dispatch(int4* recv_x, float* recv_x_scales, @@ -272,17 +255,20 @@ __global__ void __launch_bounds__(kNumThreads, 1) const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, + int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, + int scale_token_stride, + int scale_hidden_stride, void** buffer_ptrs, int rank, int num_max_send_tokens, int num_recv_buffer_tokens) { const auto num_sms = static_cast(gridDim.x), sm_id = static_cast(blockIdx.x); - const auto thread_id = static_cast(threadIdx.x); + const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); const bool is_sender = sm_id % 2 == 0; EP_DEVICE_ASSERT(num_sms % 2 == 0); @@ -304,8 +290,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) // Calculate pointers by the specific layout // `rank_prefix_matrix`: kNumRanks * kNumRanks * sizeof(int) auto ptr = reinterpret_cast( - reinterpret_cast( - buffer_ptrs[is_sender ? responsible_rank : rank]) + + static_cast(buffer_ptrs[is_sender ? responsible_rank : rank]) + kNumRanks * kNumRanks * sizeof(int)); int target_rank = is_sender ? 
rank : responsible_rank; auto num_channels_total = num_channels * kNumRanks; @@ -357,12 +342,31 @@ __global__ void __launch_bounds__(kNumThreads, 1) num_channels_total * num_recv_buffer_tokens * num_scales, channel_rank_offset * num_recv_buffer_tokens * num_scales); + // TMA stuffs +#ifndef DISABLE_SM90_FEATURES + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto half_hidden_int4 = hidden_int4 / 2; + auto half_hidden_bytes = half_hidden_int4 * static_cast(sizeof(int4)); + auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp; + auto tma_mbarrier = + reinterpret_cast(tma_buffer + half_hidden_bytes); + uint32_t tma_phase = 0; + if (lane_id == 0) { + mbarrier_init(tma_mbarrier, 1); + fence_view_async_shared(); + fence_barrier_init(); + EP_DEVICE_ASSERT(hidden_int4 % 2 == 0 && + half_hidden_bytes + sizeof(uint64_t) <= + kNumTMABytesPerWarp); + } + __syncwarp(); +#endif + if (is_sender) { // Workers for sending constexpr int num_send_warps = kNumThreads / 32; constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; const auto send_thread_id = thread_id; - const auto send_lane_id = send_thread_id % 32; const auto send_warp_id_in_rank = send_thread_id % num_threads_per_rank / 32; EP_DEVICE_ASSERT(kNumRanks <= 32); @@ -370,7 +374,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) // Send offset by `-value - 1`, e.g. 0 -> -1, 1 -> -2 // NOTES: this is for distinguishing zero tokens - if (send_lane_id == 0 && send_warp_id_in_rank == 0) { + if (lane_id == 0 && send_warp_id_in_rank == 0) { int value = responsible_channel > 0 ? channel_prefix_matrix[responsible_rank * num_channels + responsible_channel - 1] @@ -397,7 +401,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) // (rare cases) NOTES: the head index received by different warps may not // be the same auto start_time = clock64(); - while (send_lane_id == 0) { + while (lane_id == 0) { // NOTES: we only consider the worst case, because counting the real // numbers are time-consuming int num_used_slots = cached_channel_tail_idx - @@ -421,8 +425,8 @@ __global__ void __launch_bounds__(kNumThreads, 1) while (chunk_token_idx < num_max_send_tokens && token_idx < token_end_idx) { // NOTES: for the same token, the warp assigned to save `send_head` may - // be different from the warp assigned to send subsequent data - if (send_lane_id == 0 && + // be different from the warp assigned to send the following data + if (lane_id == 0 && token_idx % num_send_warps_per_rank == send_warp_id_in_rank) send_head[token_idx * kNumRanks + responsible_rank] = is_token_in_rank[token_idx * kNumRanks + responsible_rank] @@ -444,7 +448,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; auto shifted_x = x + token_idx * hidden_int4; UNROLLED_WARP_COPY(5, - send_lane_id, + lane_id, hidden_int4, shifted_channel_x_buffers, shifted_x, @@ -452,36 +456,38 @@ __global__ void __launch_bounds__(kNumThreads, 1) st_na_global); // Copy source index - if (send_lane_id == 0) + if (lane_id == 0) channel_src_idx_buffers[dst_slot_idx] = static_cast(token_idx); // Copy `topk_idx` and `topk_weights` with transformed index - if (send_lane_id < num_topk) { + if (lane_id < num_topk) { // Top-k index int recv_expert_begin = responsible_rank * num_experts_per_rank, recv_expert_end = (responsible_rank + 1) * num_experts_per_rank; - auto idx_value = - __ldg(topk_idx + token_idx * num_topk + send_lane_id); + auto idx_value = __ldg(topk_idx + token_idx * num_topk + lane_id); idx_value = 
(idx_value >= recv_expert_begin && idx_value < recv_expert_end) ? idx_value - recv_expert_begin : -1; - channel_topk_idx_buffers[dst_slot_idx * num_topk + send_lane_id] = + channel_topk_idx_buffers[dst_slot_idx * num_topk + lane_id] = idx_value; // Top-k weights auto weight_value = - __ldg(topk_weights + token_idx * num_topk + send_lane_id); + __ldg(topk_weights + token_idx * num_topk + lane_id); weight_value = (idx_value >= 0) ? weight_value : 0.0f; - channel_topk_weights_buffers[dst_slot_idx * num_topk + - send_lane_id] = weight_value; + channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = + weight_value; } // Copy `x_scales` #pragma unroll - for (int i = send_lane_id; i < num_scales; i += 32) + for (int i = lane_id; i < num_scales; i += 32) { + auto offset = + token_idx * scale_token_stride + i * scale_hidden_stride; channel_x_scales_buffers[dst_slot_idx * num_scales + i] = - __ldg(x_scales + token_idx * num_scales + i); + __ldg(x_scales + offset); + } } // Move token index @@ -492,7 +498,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) // NOTES: here all warps should share the same new tail asm volatile("bar.sync %0, %1;" ::"r"(responsible_rank), "r"(num_threads_per_rank)); - if (send_warp_id_in_rank == 0 && send_lane_id == 0) + if (send_warp_id_in_rank == 0 && lane_id == 0) st_release_sys_global(channel_tail_idx.buffer(), cached_channel_tail_idx); } @@ -501,14 +507,13 @@ __global__ void __launch_bounds__(kNumThreads, 1) constexpr int num_recv_warps = kNumThreads / 32; constexpr int num_recv_warps_per_rank = num_recv_warps / kNumRanks; const auto recv_thread_id = thread_id; - const auto recv_lane_id = recv_thread_id % 32; const auto recv_thread_id_in_rank = recv_thread_id % num_threads_per_rank; const auto recv_warp_id_in_rank = recv_thread_id_in_rank / 32; EP_DEVICE_ASSERT(kNumRanks <= 32); EP_DEVICE_ASSERT(recv_thread_id >= 0 && num_recv_warps % kNumRanks == 0); // Calculate offset first - auto rank_prefix_matrix = reinterpret_cast(buffer_ptrs[rank]); + auto rank_prefix_matrix = static_cast(buffer_ptrs[rank]); int rank_offset = responsible_rank > 0 ? 
rank_prefix_matrix[(responsible_rank - 1) * kNumRanks + rank] @@ -516,13 +521,13 @@ __global__ void __launch_bounds__(kNumThreads, 1) // Receive channel offset int total_offset, num_tokens_to_recv; - while (recv_lane_id == 0 && (total_offset = ld_volatile_global( - channel_start_offset.buffer())) == 0) { + while (lane_id == 0 && (total_offset = ld_volatile_global( + channel_start_offset.buffer())) == 0) { } - while (recv_lane_id == 0 && (num_tokens_to_recv = ld_volatile_global( - channel_end_offset.buffer())) == 0) { + while (lane_id == 0 && (num_tokens_to_recv = ld_volatile_global( + channel_end_offset.buffer())) == 0) { } - if (recv_lane_id == 0) { + if (lane_id == 0) { total_offset = -total_offset - 1, num_tokens_to_recv = -num_tokens_to_recv - 1; if (recv_warp_id_in_rank == 0) @@ -541,11 +546,10 @@ __global__ void __launch_bounds__(kNumThreads, 1) int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; while (num_tokens_to_recv > 0) { // NOTES: unlike the sender, the receiver must ensure that the tail - // indices hold by different warps are same + // indices hold by different warps are the same while (recv_thread_id_in_rank == 0) { cached_channel_tail_idx = ld_acquire_sys_global(channel_tail_idx.buffer()); - {} // Ready to copy if (cached_channel_head_idx != cached_channel_tail_idx) { @@ -581,13 +585,32 @@ __global__ void __launch_bounds__(kNumThreads, 1) auto shifted_recv_x_int4 = recv_x + static_cast(total_offset + chunk_idx) * hidden_int4; +#ifndef DISABLE_SM90_FEATURES +#pragma unroll + for (int i = 0; i < 2; ++i) + if (lane_id == 0) { + tma_store_wait(); + tma_load_1d(tma_buffer, + shifted_buffer_x_int4 + i * half_hidden_int4, + tma_mbarrier, + half_hidden_bytes); + mbarrier_arrive_and_expect_tx(tma_mbarrier, half_hidden_bytes); + mbarrier_wait(tma_mbarrier, tma_phase); + tma_store_1d(tma_buffer, + shifted_recv_x_int4 + i * half_hidden_int4, + half_hidden_bytes, + false); + } + __syncwarp(); +#else UNROLLED_WARP_COPY(5, - recv_lane_id, + lane_id, hidden_int4, shifted_recv_x_int4, shifted_buffer_x_int4, ld_nc_global, st_na_global); +#endif } // Copy `src_idx` @@ -635,14 +658,31 @@ __global__ void __launch_bounds__(kNumThreads, 1) total_offset += num_recv_tokens; asm volatile("bar.sync %0, %1;" ::"r"(responsible_rank), "r"(num_threads_per_rank)); - if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 && - recv_lane_id == 0) + if (recv_warp_id_in_rank == num_recv_warps_per_rank - 1 && lane_id == 0) st_relaxed_sys_global(channel_head_idx.buffer(), cached_channel_head_idx); // Exit num_tokens_to_recv -= num_recv_tokens; } + + // Make TMA store visible to the next kernel +#ifndef DISABLE_SM90_FEATURES + if (lane_id == 0) tma_store_wait(); +#endif + } + + // Clean unused `recv_topk_idx` as -1 + if (num_worst_tokens > 0) { + auto rank_prefix_matrix = static_cast(buffer_ptrs[rank]); + const auto num_recv_tokens = + rank_prefix_matrix[(kNumRanks - 1) * kNumRanks + rank]; + const auto clean_start = num_recv_tokens * num_topk + sm_id * kNumThreads; + const auto clean_end = num_worst_tokens * num_topk; + const auto clean_stride = num_sms * kNumThreads; +#pragma unroll + for (int i = clean_start + thread_id; i < clean_end; i += clean_stride) + recv_topk_idx[i] = -1; } } @@ -660,10 +700,13 @@ void dispatch(void* recv_x, const bool* is_token_in_rank, const int* channel_prefix_matrix, int num_tokens, + int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, + int scale_token_stride, + int scale_hidden_stride, void** buffer_ptrs, int rank, int num_ranks, @@ 
-671,33 +714,48 @@ void dispatch(void* recv_x, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { - constexpr int kNumThreads = 512; - -#define DISPATCH_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, \ - dispatch, \ - reinterpret_cast(recv_x), \ - recv_x_scales, \ - recv_src_idx, \ - recv_topk_idx, \ - recv_topk_weights, \ - recv_channel_offset, \ - send_head, \ - reinterpret_cast(x), \ - x_scales, \ - topk_idx, \ - topk_weights, \ - is_token_in_rank, \ - channel_prefix_matrix, \ - num_tokens, \ - hidden_int4, \ - num_topk, \ - num_experts, \ - num_scales, \ - buffer_ptrs, \ - rank, \ - num_max_send_tokens, \ - num_recv_buffer_tokens); \ + constexpr int kNumThreads = 768; + constexpr int kNumTMABytesPerWarp = 8192; +#ifndef DISABLE_SM90_FEATURES + constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32); +#endif + + // Make sure never OOB + EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < + std::numeric_limits::max()); + +#define DISPATCH_LAUNCH_CASE(ranks) \ + { \ + auto kernel = dispatch; \ + SET_SHARED_MEMORY_FOR_TMA(kernel); \ + LAUNCH_KERNEL(&cfg, \ + kernel, \ + reinterpret_cast(recv_x), \ + recv_x_scales, \ + recv_src_idx, \ + recv_topk_idx, \ + recv_topk_weights, \ + recv_channel_offset, \ + send_head, \ + reinterpret_cast(x), \ + x_scales, \ + topk_idx, \ + topk_weights, \ + is_token_in_rank, \ + channel_prefix_matrix, \ + num_tokens, \ + num_worst_tokens, \ + hidden_int4, \ + num_topk, \ + num_experts, \ + num_scales, \ + scale_token_stride, \ + scale_hidden_stride, \ + buffer_ptrs, \ + rank, \ + num_max_send_tokens, \ + num_recv_buffer_tokens); \ + } \ break // Even-numbered blocks for sending, odd-numbered blocks for receiving. @@ -713,27 +771,22 @@ __global__ void cached_notify_combine(void** buffer_ptrs, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank) { const auto sm_id = static_cast(blockIdx.x); if (sm_id == 0) { // Barrier before cleaning - barrier_device(task_fifo_ptrs, head, rank); - move_fifo_slots(head); - __syncthreads(); + barrier_block(barrier_signal_ptrs, rank); // Clean auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); - auto ptr = reinterpret_cast(buffer_ptrs[rank]); + auto ptr = static_cast(buffer_ptrs[rank]); #pragma unroll for (int i = thread_id; i < num_memset_int; i += num_threads) ptr[i] = 0; - memory_fence(); - __syncthreads(); // Barrier after cleaning - barrier_device(task_fifo_ptrs, head, rank); + barrier_block(barrier_signal_ptrs, rank); } else { const auto channel_id = sm_id - 1; const auto thread_id = static_cast(threadIdx.x); @@ -760,7 +813,7 @@ __global__ void cached_notify_combine(void** buffer_ptrs, ? 
__ldg(send_head + token_idx * kNumRanks + rank_id) : -1; for (int i = 0; i < min(32, token_idx_tail - token_start_idx + 1); ++i) { - head = __shfl_sync(0xffffffff, current_head, i); + const int head = __shfl_sync(0xffffffff, current_head, i); if (head < 0) { if (lane_id == i) expected_head = -last_head - 1; } else { @@ -778,8 +831,7 @@ void cached_notify_combine(void** buffer_ptrs, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, - int head, + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) { @@ -791,8 +843,7 @@ void cached_notify_combine(void** buffer_ptrs, num_channels, \ num_recv_tokens, \ num_memset_int, \ - task_fifo_ptrs, \ - head, \ + barrier_signal_ptrs, \ rank); \ break @@ -805,12 +856,17 @@ void cached_notify_combine(void** buffer_ptrs, #undef CACHED_NOTIFY_COMBINE } -template +template __global__ void __launch_bounds__(kNumThreads, 1) combine(dtype_t* recv_x, float* recv_topk_weights, const dtype_t* x, const float* topk_weights, + const dtype_t* bias_0, + const dtype_t* bias_1, const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, @@ -825,7 +881,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) int num_recv_buffer_tokens) { const auto num_sms = static_cast(gridDim.x); const auto thread_id = static_cast(threadIdx.x); - const auto sm_id = static_cast(blockIdx.x); + const auto sm_id = static_cast(blockIdx.x), lane_id = get_lane_id(); const auto num_channels = num_sms / 2; const bool is_sender = sm_id % 2 == 0; const int responsible_channel = sm_id / 2; @@ -834,23 +890,31 @@ __global__ void __launch_bounds__(kNumThreads, 1) constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4); auto x_int4 = reinterpret_cast(x); + auto bias_0_int4 = reinterpret_cast(bias_0); + auto bias_1_int4 = reinterpret_cast(bias_1); auto recv_int4 = reinterpret_cast(recv_x); + // TMA stuffs +#ifndef DISABLE_SM90_FEATURES + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto tma_buffer = smem_buffer + (thread_id / 32) * kNumTMABytesPerWarp; +#endif + if (is_sender) { // Workers for sending // Several warps are responsible for a single rank - constexpr int num_send_warps = kNumThreads / 32; - constexpr int num_send_warps_per_rank = num_send_warps / kNumRanks; + constexpr int num_send_warps_per_rank = (kNumThreads / 32) / kNumRanks; + constexpr int num_send_warps = num_send_warps_per_rank * kNumRanks; const auto num_threads_per_rank = num_send_warps_per_rank * 32; const auto send_thread_id = thread_id; - const auto send_lane_id = send_thread_id % 32; - const auto send_rank_id = thread_id / num_threads_per_rank; - const auto send_warp_id_in_rank = - send_thread_id % num_threads_per_rank / 32; + const auto send_warp_id = send_thread_id / 32; + const auto send_rank_id = (responsible_channel + send_warp_id) % kNumRanks; + const auto send_warp_id_in_rank = send_warp_id / kNumRanks; + EP_STATIC_ASSERT(num_send_warps * 32 == kNumThreads, "Invalid warp count"); // Calculate pointers by the specific layout auto ptr = reinterpret_cast( - reinterpret_cast(buffer_ptrs[send_rank_id])); + static_cast(buffer_ptrs[send_rank_id])); auto num_channels_total = num_channels * kNumRanks; auto channel_rank_offset = responsible_channel * kNumRanks + rank; @@ -905,7 +969,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) auto start_time = clock64(); int num_round_tokens = min(num_max_send_tokens, token_end_idx - static_cast(token_idx)); - while (send_lane_id == 0) 
{ + while (lane_id == 0) { // NOTES: we only consider the worst case, because counting the real // numbers are time-consuming int num_used_slots = current_channel_tail_idx - @@ -937,7 +1001,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) channel_x_buffers.buffer() + dst_slot_idx * hidden_int4; auto shifted_x = x_int4 + (token_idx + i) * hidden_int4; UNROLLED_WARP_COPY(4, - send_lane_id, + lane_id, hidden_int4, shifted_x_buffers, shifted_x, @@ -945,14 +1009,14 @@ __global__ void __launch_bounds__(kNumThreads, 1) st_na_global); // Send source index - if (send_lane_id == 0) + if (lane_id == 0) channel_src_idx_buffers[dst_slot_idx] = __ldg(src_idx + token_idx + i); // Send `topk_weights` - if (num_topk > 0 && send_lane_id < num_topk) - channel_topk_weights_buffers[dst_slot_idx * num_topk + send_lane_id] = - __ldg(topk_weights + (token_idx + i) * num_topk + send_lane_id); + if (num_topk > 0 && lane_id < num_topk) + channel_topk_weights_buffers[dst_slot_idx * num_topk + lane_id] = + __ldg(topk_weights + (token_idx + i) * num_topk + lane_id); } token_idx += num_round_tokens; current_channel_tail_idx += num_round_tokens; @@ -960,7 +1024,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) // Move tail index asm volatile("bar.sync %0, %1;" ::"r"(send_rank_id), "r"(num_threads_per_rank)); - if (send_lane_id == 0 && send_warp_id_in_rank == 0) + if (lane_id == 0 && send_warp_id_in_rank == 0) st_release_sys_global(channel_tail_idx.buffer(), current_channel_tail_idx); } @@ -969,7 +1033,6 @@ __global__ void __launch_bounds__(kNumThreads, 1) // One warp for moving the queue head, others for reduction constexpr int num_recv_warps = kNumThreads / 32; const auto recv_warp_id = thread_id / 32; - const auto recv_lane_id = thread_id % 32; EP_DEVICE_ASSERT(kNumRanks <= 32 && kNumThreads > 32); EP_DEVICE_ASSERT(thread_id >= 0 && kNumThreads % 32 == 0); @@ -978,21 +1041,19 @@ __global__ void __launch_bounds__(kNumThreads, 1) __shared__ volatile int channel_tail_idx[kNumRanks]; __shared__ volatile bool warp_retired[num_recv_warps]; if (thread_id < num_recv_warps) warp_retired[thread_id] = false; - if (recv_lane_id < kNumRanks) - warp_channel_head_idx[recv_warp_id][recv_lane_id] = 0; + if (lane_id < kNumRanks) warp_channel_head_idx[recv_warp_id][lane_id] = 0; if (thread_id < kNumRanks) channel_tail_idx[thread_id] = 0; asm volatile("bar.sync 0, %0;" ::"r"(kNumThreads)); if (thread_id < 32) { - int* channel_head_idx_ptr = reinterpret_cast(buffer_ptrs[rank]) + - responsible_channel * kNumRanks + - recv_lane_id; + int* channel_head_idx_ptr = static_cast(buffer_ptrs[rank]) + + responsible_channel * kNumRanks + lane_id; int* channel_tail_idx_ptr = channel_head_idx_ptr + num_channels * kNumRanks; // Queue head updater int last_head = 0; - while (recv_lane_id < kNumRanks) { + while (lane_id < kNumRanks) { // Check retired bool retired = true; #pragma unroll @@ -1001,15 +1062,14 @@ __global__ void __launch_bounds__(kNumThreads, 1) if (retired) break; // Update queue tail - channel_tail_idx[recv_lane_id] = - ld_acquire_sys_global(channel_tail_idx_ptr); + channel_tail_idx[lane_id] = ld_acquire_sys_global(channel_tail_idx_ptr); // Update minimum head int min_head = std::numeric_limits::max(); #pragma unroll for (int i = 1; i < num_recv_warps; ++i) if (!warp_retired[i]) - min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]); + min_head = min(min_head, warp_channel_head_idx[i][lane_id]); if (min_head != std::numeric_limits::max() && min_head > last_head) st_relaxed_sys_global(channel_head_idx_ptr, last_head = 
min_head); } @@ -1027,9 +1087,9 @@ __global__ void __launch_bounds__(kNumThreads, 1) auto channel_rank_offset = responsible_channel * kNumRanks + i; auto num_channels_total = num_channels * kNumRanks; // `head_idx` & `tail_idx`: kNumChannels * kNumRanks * sizeof(int) - auto ptr = reinterpret_cast( - reinterpret_cast(buffer_ptrs[rank]) + - 2 * num_channels * kNumRanks * sizeof(int)); + auto ptr = + reinterpret_cast(static_cast(buffer_ptrs[rank]) + + 2 * num_channels * kNumRanks * sizeof(int)); // `x_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens * // hidden_int4 * sizeof(int4) @@ -1040,7 +1100,7 @@ __global__ void __launch_bounds__(kNumThreads, 1) // `src_idx_buffers`: kNumChannels * kNumRanks * num_recv_buffer_tokens // * sizeof(int) - ptr = reinterpret_cast(reinterpret_cast(ptr) + + ptr = reinterpret_cast(static_cast(ptr) + num_channels_total * num_recv_buffer_tokens * sizeof(int)); @@ -1066,13 +1126,14 @@ __global__ void __launch_bounds__(kNumThreads, 1) token_idx += num_recv_warps - 1) { // Read expected head int expected_head = -1; - if (recv_lane_id < kNumRanks) { + if (lane_id < kNumRanks) expected_head = - ld_nc_global(send_head + token_idx * kNumRanks + recv_lane_id); - } + ld_nc_global(send_head + token_idx * kNumRanks + lane_id); + auto start_time = clock64(); - while (channel_tail_idx[recv_lane_id] <= expected_head && - expected_head >= 0) { + while (__any_sync( + 0xffffffff, + channel_tail_idx[lane_id] <= expected_head && expected_head >= 0)) { // Timeout check if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { printf( @@ -1098,9 +1159,28 @@ __global__ void __launch_bounds__(kNumThreads, 1) } } -// Reduce data + // Wait shared memory release +#ifndef DISABLE_SM90_FEATURES + if (lane_id == 0) tma_store_wait(); + __syncwarp(); +#endif + + // Reduce data with pipeline + constexpr int kNumStages = 8; + EP_STATIC_ASSERT(kNumStages * 32 * sizeof(int4) <= kNumTMABytesPerWarp, + "Invalid count"); #pragma unroll - for (int i = recv_lane_id; i < hidden_int4; i += 32) { + for (int i = lane_id; i < hidden_int4; i += 32) { + // Read bias + int4 bias_0_value_int4 = + bias_0_int4 != nullptr + ? __ldg(bias_0_int4 + token_idx * hidden_int4 + i) + : make_int4(0, 0, 0, 0); + int4 bias_1_value_int4 = + bias_1_int4 != nullptr + ? 
__ldg(bias_1_int4 + token_idx * hidden_int4 + i) + : make_int4(0, 0, 0, 0); + // Read buffers int4 recv_value_int4[kNumRanks]; #pragma unroll @@ -1109,8 +1189,18 @@ __global__ void __launch_bounds__(kNumThreads, 1) ld_nc_global(channel_x_buffers[topk_ranks[j]].buffer() + slot_indices[j] * hidden_int4 + i); - // Reduce all-to-all results - float values[kDtypePerInt4] = {0}; + // Reduce bias + float values[kDtypePerInt4]; + auto bias_0_values = + reinterpret_cast(&bias_0_value_int4); + auto bias_1_values = + reinterpret_cast(&bias_1_value_int4); +#pragma unroll + for (int j = 0; j < kDtypePerInt4; ++j) + values[j] = static_cast(bias_0_values[j]) + + static_cast(bias_1_values[j]); + +// Reduce all-to-all results #pragma unroll for (int j = 0; j < num_topk_ranks; ++j) { auto recv_value_dtypes = @@ -1120,34 +1210,66 @@ __global__ void __launch_bounds__(kNumThreads, 1) values[k] += static_cast(recv_value_dtypes[k]); } - // Cast back to `dtype_t` and write + // Cast back to `dtype_t` int4 out_int4; auto out_dtypes = reinterpret_cast(&out_int4); #pragma unroll for (int j = 0; j < kDtypePerInt4; ++j) out_dtypes[j] = static_cast(values[j]); + +#ifndef DISABLE_SM90_FEATURES + // Wait TMA arrival + if (lane_id == 0) tma_store_wait(); + __syncwarp(); + + // Write into TMA buffer + auto tma_stage_idx = (i / 32) % kNumStages; + reinterpret_cast(tma_buffer)[tma_stage_idx * 32 + lane_id] = + out_int4; + + // Issue TMA + tma_store_fence(); + __syncwarp(); + if (lane_id == 0) { + auto tma_bytes = + min(32, hidden_int4 - i) * static_cast(sizeof(int4)); + tma_store_1d( + reinterpret_cast(tma_buffer) + tma_stage_idx * 32, + recv_int4 + token_idx * hidden_int4 + i, + tma_bytes, + false); + } + __syncwarp(); +#else recv_int4[token_idx * hidden_int4 + i] = out_int4; +#endif } // Reduce `topk_weights` - if (recv_lane_id < num_topk) { + if (lane_id < num_topk) { float value = 0; #pragma unroll for (int i = 0; i < num_topk_ranks; ++i) value += ld_nc_global( channel_topk_weights_buffers[topk_ranks[i]].buffer() + - slot_indices[i] * num_topk + recv_lane_id); - recv_topk_weights[token_idx * num_topk + recv_lane_id] = value; + slot_indices[i] * num_topk + lane_id); + recv_topk_weights[token_idx * num_topk + lane_id] = value; } + // Update head - if (recv_lane_id < kNumRanks) - warp_channel_head_idx[recv_warp_id][recv_lane_id] = + if (lane_id < kNumRanks) + warp_channel_head_idx[recv_warp_id][lane_id] = (expected_head < 0) ? 
-expected_head - 1 : expected_head + 1; } // Retired __syncwarp(); - if (recv_lane_id == 0) warp_retired[recv_warp_id] = true; + if (lane_id == 0) warp_retired[recv_warp_id] = true; + + // Make TMA store visible to the next kernel +#ifndef DISABLE_SM90_FEATURES + if (lane_id == 0) tma_store_wait(); +#endif } } } @@ -1157,6 +1279,8 @@ void combine(cudaDataType_t type, float* recv_topk_weights, const void* x, const float* topk_weights, + const void* bias_0, + const void* bias_1, const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, @@ -1173,26 +1297,36 @@ void combine(cudaDataType_t type, int num_max_send_tokens, int num_recv_buffer_tokens) { constexpr int kNumThreads = 768; - -#define COMBINE_LAUNCH_CASE(dtype, ranks) \ - LAUNCH_KERNEL(&cfg, \ - (combine), \ - reinterpret_cast(recv_x), \ - recv_topk_weights, \ - reinterpret_cast(x), \ - topk_weights, \ - src_idx, \ - rank_prefix_matrix, \ - channel_prefix_matrix, \ - send_head, \ - num_tokens, \ - num_recv_tokens, \ - hidden, \ - num_topk, \ - buffer_ptrs, \ - rank, \ - num_max_send_tokens, \ - num_recv_buffer_tokens); \ + constexpr int kNumTMABytesPerWarp = 4096; +#ifndef DISABLE_SM90_FEATURES + constexpr int smem_size = kNumTMABytesPerWarp * (kNumThreads / 32); +#endif + +#define COMBINE_LAUNCH_CASE(dtype, ranks) \ + { \ + auto kernel = combine; \ + SET_SHARED_MEMORY_FOR_TMA(kernel); \ + LAUNCH_KERNEL(&cfg, \ + kernel, \ + reinterpret_cast(recv_x), \ + recv_topk_weights, \ + reinterpret_cast(x), \ + topk_weights, \ + reinterpret_cast(bias_0), \ + reinterpret_cast(bias_1), \ + src_idx, \ + rank_prefix_matrix, \ + channel_prefix_matrix, \ + send_head, \ + num_tokens, \ + num_recv_tokens, \ + hidden, \ + num_topk, \ + buffer_ptrs, \ + rank, \ + num_max_send_tokens, \ + num_recv_buffer_tokens); \ + } \ break #define COMBINE_DTYPE_LAUNCH_CASE(dtype) \ SWITCH_RANKS_WITH_DTYPE(dtype, COMBINE_LAUNCH_CASE); \ diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh index 6f2f8a49ca3fb2..7a5b677b51223b 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh @@ -40,6 +40,15 @@ CUDA_CHECK(cudaLaunchKernelEx(config, kernel, ##__VA_ARGS__)) #endif +#ifndef SET_SHARED_MEMORY_FOR_TMA +#define SET_SHARED_MEMORY_FOR_TMA(kernel) \ + EP_HOST_ASSERT( \ + cudaFuncSetAttribute(kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + smem_size) == cudaSuccess); \ + cfg.dynamicSmemBytes = smem_size; +#endif + #define SWITCH_RANKS(case_macro) \ switch (num_ranks) { \ case 2: \ diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/runtime.cu b/paddle/fluid/distributed/collective/deep_ep/kernels/runtime.cu index 51669f785f9d31..5ac200a57e4b71 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/runtime.cu +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/runtime.cu @@ -44,17 +44,16 @@ namespace deep_ep { namespace intranode { template -__global__ void barrier(int** task_fifo_ptrs, int head, int rank) { - barrier_device(task_fifo_ptrs, head, rank); +__global__ void barrier(int** barrier_signal_ptrs, int rank) { + barrier_block(barrier_signal_ptrs, rank); } -void barrier(int** task_fifo_ptrs, - int head, +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream) { -#define BARRIER_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ +#define BARRIER_LAUNCH_CASE(ranks) \ + 
LAUNCH_KERNEL(&cfg, barrier, barrier_signal_ptrs, rank); \ break SETUP_LAUNCH_CONFIG(1, 32, stream); @@ -105,17 +104,6 @@ int init(const std::vector& root_unique_id_val, EP_HOST_ASSERT(cpu_rdma_team != NVSHMEM_TEAM_INVALID); } - // TODO(DeepEP): we still use `nvshmem_barrier` under IBRC mode, which should - // be switch to IBGDA mode later - nvshmemi_device_host_state_t* dev_state_ptr = nullptr; - CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast(&dev_state_ptr), - nvshmemi_device_state_d)); - - bool ibgda_is_initialized = false; - CUDA_CHECK(cudaMemcpy(&dev_state_ptr->ibgda_is_initialized, - &ibgda_is_initialized, - sizeof(bool), - cudaMemcpyHostToDevice)); nvshmem_barrier_all(); return nvshmem_my_pe(); } @@ -138,16 +126,15 @@ void finalize() { #endif // PADDLE_WITH_NVSHMEM template -__global__ void __launch_bounds__(kNumThreads, 1) - get_dispatch_layout(const int64_t* topk_idx, - int* num_tokens_per_rank, - int* num_tokens_per_rdma_rank, - int* num_tokens_per_expert, - bool* is_token_in_rank, - int num_tokens, - int num_topk, - int num_ranks, - int num_experts) { +__global__ void get_dispatch_layout(const int64_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts) { auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x); @@ -274,11 +261,11 @@ void get_dispatch_layout(const int64_t* topk_idx, int num_ranks, int num_experts, cudaStream_t stream) { - constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8; + constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8; int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; - EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, - "Invalid number of experts per SM"); + EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, + "Invalid number of ranks per SM"); SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); LAUNCH_KERNEL( diff --git a/paddle/fluid/distributed/collective/deep_ep/kernels/utils.cuh b/paddle/fluid/distributed/collective/deep_ep/kernels/utils.cuh index 645fc54f4e0ce5..e9ec275c628304 100644 --- a/paddle/fluid/distributed/collective/deep_ep/kernels/utils.cuh +++ b/paddle/fluid/distributed/collective/deep_ep/kernels/utils.cuh @@ -66,6 +66,16 @@ struct VecInt<16> { using vec_t = int4; }; +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ explicit PatternVisitor(FuncT &&func) + : func(std::forward(func)) {} + + __device__ __host__ auto operator[](const uint32_t &i) { return func(i); } +}; + __device__ __forceinline__ void trap() { asm("trap;"); } __device__ __forceinline__ void memory_fence() { @@ -224,7 +234,7 @@ __device__ __forceinline__ int64_t ld_volatile_global(const uint64_t *ptr) { #ifndef DISABLE_AGGRESSIVE_PTX_INSTRS #define LD_NC_FUNC "ld.global.nc.L1::no_allocate.L2::256B" #else -#define LD_NC_FUNC "ld.volatile.global" +#define LD_NC_FUNC "ld.volatile.global.L2::256B" #endif // `ld.global.nc.L1::no_allocate` will be translated into @@ -396,14 +406,138 @@ __device__ __forceinline__ void st_na_global(const int4 *ptr, "r"(value.w)); } +__device__ __forceinline__ float log2f_approx(const float &x) { + float ret; + asm volatile("lg2.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +} + +__device__ __forceinline__ float exp2f_approx(const float &x) { + float ret; + asm volatile("ex2.approx.f32 %0, %1;" : "=f"(ret) : 
"f"(x)); + return ret; +} + +// TMA PTX instructions +#ifndef DISABLE_SM90_FEATURES + +__device__ __forceinline__ uint32_t elect_one_sync(int lane_id) { + uint32_t pred = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(lane_id), "+r"(pred) + : "r"(0xffffffff)); + return pred; +} + +__device__ __forceinline__ void fence_view_async_shared() { + asm volatile("fence.proxy.async.shared::cta; \n" ::); +} + +__device__ __forceinline__ void fence_barrier_init() { + asm volatile("fence.mbarrier_init.release.cluster; \n" ::); +} + +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar_ptr, + uint32_t arrive_count) { + auto mbar_int_ptr = static_cast(__cvta_generic_to_shared(mbar_ptr)); + asm volatile("mbarrier.init.shared::cta.b64 [%1], %0;" ::"r"(arrive_count), + "r"(mbar_int_ptr)); +} + +__device__ __forceinline__ void mbarrier_wait(uint64_t *mbar_ptr, + uint32_t &phase) { + auto mbar_int_ptr = static_cast(__cvta_generic_to_shared(mbar_ptr)); + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" ::"r"(mbar_int_ptr), + "r"(phase), + "r"(0x989680)); + phase ^= 1; +} + +__device__ __forceinline__ void mbarrier_arrive_and_expect_tx( + uint64_t *mbar_ptr, int num_bytes) { + auto mbar_int_ptr = static_cast(__cvta_generic_to_shared(mbar_ptr)); + asm volatile( + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" ::"r"( + num_bytes), + "r"(mbar_int_ptr)); +} + +__device__ __forceinline__ void tma_store_fence() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +constexpr uint64_t kEvictFirst = 0x12f0000000000000; +constexpr uint64_t kEvictNormal = 0x1000000000000000; + +__device__ __forceinline__ void tma_load_1d(const void *smem_ptr, + const void *gmem_ptr, + uint64_t *mbar_ptr, + int num_bytes, + bool evict_first = true) { + auto mbar_int_ptr = static_cast(__cvta_generic_to_shared(mbar_ptr)); + auto smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal; + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::" + "cache_hint [%0], [%1], %2, [%3], %4;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), + "r"(num_bytes), + "r"(mbar_int_ptr), + "l"(cache_hint) + : "memory"); +} + +__device__ __forceinline__ void tma_store_1d(const void *smem_ptr, + const void *gmem_ptr, + int num_bytes, + bool evict_first = true) { + auto smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + const auto cache_hint = evict_first ? 
+__device__ __forceinline__ void tma_store_1d(const void *smem_ptr,
+                                             const void *gmem_ptr,
+                                             int num_bytes,
+                                             bool evict_first = true) {
+  auto smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  const auto cache_hint = evict_first ? kEvictFirst : kEvictNormal;
+  asm volatile(
+      "cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], "
+      "%2, %3;\n" ::"l"(gmem_ptr),
+      "r"(smem_int_ptr),
+      "r"(num_bytes),
+      "l"(cache_hint)
+      : "memory");
+  asm volatile("cp.async.bulk.commit_group;");
+}
+
+template <int N = 0>
+__device__ __forceinline__ void tma_store_wait() {
+  asm volatile("cp.async.bulk.wait_group.read %0;" ::"n"(N) : "memory");
+}
+
+#endif
+
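// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the single-producer pattern the
// mbarrier/TMA helpers above are designed for, written as a hypothetical SM90
// kernel with the deep_ep helpers in scope. One thread initializes the barrier
// and issues the bulk copy; every thread then waits on the barrier before
// reading the staged data. Assumptions: gmem_src is 16-byte aligned (e.g. from
// cudaMalloc) and num_floats is a multiple of 4 and at most 1024, so the copy
// size meets the alignment/size requirements of cp.async.bulk.
__global__ void tma_load_example(const float *gmem_src,
                                 float *gmem_dst,
                                 int num_floats) {
  __shared__ __align__(16) float smem_buf[1024];  // staging tile
  __shared__ uint64_t mbar;                       // mbarrier word

  const int num_bytes = num_floats * static_cast<int>(sizeof(float));
  if (threadIdx.x == 0) {
    mbarrier_init(&mbar, 1);    // exactly one arriving thread
    fence_view_async_shared();  // make the initialized barrier visible
    fence_barrier_init();       // to the async proxy before use
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    mbarrier_arrive_and_expect_tx(&mbar, num_bytes);    // announce byte count
    tma_load_1d(smem_buf, gmem_src, &mbar, num_bytes);  // async bulk copy G->S
  }
  uint32_t phase = 0;
  mbarrier_wait(&mbar, phase);  // all threads spin until the transfer lands

  for (int i = threadIdx.x; i < num_floats; i += blockDim.x)
    gmem_dst[i] = smem_buf[i] * 2.0f;  // consume the staged tile
}
// ---------------------------------------------------------------------------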
 template <typename dtype_t>
-__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
+__host__ __device__ constexpr dtype_t ceil_div(dtype_t a, dtype_t b) {
   return (a + b - 1) / b;
 }
 
 template <typename dtype_t>
-__host__ __device__ dtype_t align(dtype_t a, dtype_t b) {
-  return cell_div(a, b) * b;
+__host__ __device__ constexpr dtype_t align(dtype_t a, dtype_t b) {
+  return ceil_div(a, b) * b;
+}
+
+template <typename dtype_t>
+__host__ __device__ dtype_t cell_div(dtype_t a, dtype_t b) {
+  return (a + b - 1) / b;
 }
 
 __forceinline__ __device__ void get_channel_task_range(int num_tokens,
@@ -411,7 +545,7 @@ __forceinline__ __device__ void get_channel_task_range(int num_tokens,
                                                        int sm_id,
                                                        int &token_start_idx,
                                                        int &token_end_idx) {
-  int num_tokens_per_sm = cell_div(num_tokens, num_sms);
+  int num_tokens_per_sm = ceil_div(num_tokens, num_sms);
   token_start_idx = min(num_tokens_per_sm * sm_id, num_tokens);
   token_end_idx = min(token_start_idx + num_tokens_per_sm, num_tokens);
 }
@@ -449,15 +583,6 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
   return *reinterpret_cast<dtype_t *>(recv_int_values);
 }
 
-__forceinline__ __device__ int warp_reduce_sum(int value) {
-  value += __shfl_xor_sync(0xffffffff, value, 16);
-  value += __shfl_xor_sync(0xffffffff, value, 8);
-  value += __shfl_xor_sync(0xffffffff, value, 4);
-  value += __shfl_xor_sync(0xffffffff, value, 2);
-  value += __shfl_xor_sync(0xffffffff, value, 1);
-  return value;
-}
-
 __forceinline__ __device__ float half_warp_reduce_max(float value) {
   auto mask = __activemask();
   // The mask be in `{0xffffffff, 0xffff}`
@@ -474,48 +599,166 @@ __forceinline__ __device__ int get_lane_id() {
   return lane_id;
 }
 
-template <int kNumRanks>
-__forceinline__ __device__ void move_fifo_slots(int &head) {
-  head = (head + kNumRanks) % NUM_MAX_FIFO_SLOTS;
+constexpr float kFP8Margin = 1e-4;
+constexpr float kFinfoAmaxE4M3 = 448.0f;
+constexpr float kFinfoAmaxInvE4M3 = 1 / 448.0f;
+
+__forceinline__ __device__ float fast_pow2(int x) {
+  // We can ensure `-126 <= x and x <= 127`
+  uint32_t bits_x = (x + 127) << 23;
+  return *reinterpret_cast<float *>(&bits_x);
 }
 
-template <int kNumRanks>
-__device__ __forceinline__ bool not_finished(int *task, int expected) {
-  auto result = false;
-  auto lane_id = threadIdx.x % 32;
-  if (lane_id < kNumRanks)
-    result = ld_volatile_global(task + lane_id) != expected;
-  return __any_sync(0xffffffff, result);
+__forceinline__ __device__ int fast_log2_ceil(float x) {
+  auto bits_x = *reinterpret_cast<uint32_t *>(&x);
+  auto exp_x = (bits_x >> 23) & 0xff;
+  auto man_bits = bits_x & ((1 << 23) - 1);
+  return exp_x - 127 + (man_bits != 0);
 }
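// ---------------------------------------------------------------------------
// Illustrative host-side sketch (not part of the patch): the same bit tricks
// as fast_pow2 / fast_log2_ceil above, spelled out on the CPU. For a normal,
// positive float, fast_pow2(x) simply writes (x + 127) into the exponent
// field, and fast_log2_ceil(v) reads that field back, adding 1 whenever the
// mantissa is non-zero (i.e. v is not already an exact power of two). All
// names here are hypothetical.
#include <cassert>
#include <cstdint>
#include <cstring>

static float host_fast_pow2(int x) {  // valid for -126 <= x <= 127
  uint32_t bits = static_cast<uint32_t>(x + 127) << 23;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

static int host_fast_log2_ceil(float v) {  // valid for normal, positive v
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  int exp = static_cast<int>((bits >> 23) & 0xff) - 127;
  bool has_mantissa = (bits & ((1u << 23) - 1)) != 0;
  return exp + (has_mantissa ? 1 : 0);
}

int main() {
  assert(host_fast_pow2(3) == 8.0f);       // exponent-field construction
  assert(host_fast_log2_ceil(8.0f) == 3);  // exact power of two
  assert(host_fast_log2_ceil(9.0f) == 4);  // rounded up: ceil(log2(9)) = 4
  return 0;
}
// ---------------------------------------------------------------------------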
 
-template <int kNumRanks>
-__forceinline__ __device__ void timeout_check(
-    int **task_fifo_ptrs, int head, int rank, int expected, int tag = 0) {
-  auto start_time = clock64();
-  while (not_finished(task_fifo_ptrs[rank] + head, expected)) {
-    if (clock64() - start_time > NUM_TIMEOUT_CYCLES and threadIdx.x == 0) {
-      printf("DeepEP timeout check failed: %d (rank = %d)\n", tag, rank);
-      trap();
-    }
+__forceinline__ __device__ void calculate_fp8_scales(float amax,
+                                                     float &scale,
+                                                     float &scale_inv,
+                                                     bool round_scale) {
+  if (round_scale) {
+    auto exp_scale_inv = fast_log2_ceil(amax * kFinfoAmaxInvE4M3);
+    scale = fast_pow2(-exp_scale_inv);
+    scale_inv = fast_pow2(exp_scale_inv);
+  } else {
+    scale_inv = amax * kFinfoAmaxInvE4M3;
+    scale = kFinfoAmaxE4M3 / amax;
   }
 }
 
-template <int kNumRanks>
-__forceinline__ __device__ void barrier_device(int **task_fifo_ptrs,
-                                               int head,
-                                               int rank,
-                                               int tag = 0) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
+template <bool kIsUE8M0,
+          typename out_dtype_t = std::conditional_t<kIsUE8M0, uint8_t, float>>
+__forceinline__ __device__ out_dtype_t
+extract_required_scale_format(float value) {
+  if constexpr (kIsUE8M0) {
+    return static_cast<out_dtype_t>((*reinterpret_cast<uint32_t *>(&value)) >>
+                                    23);
+  } else {
+    return value;
+  }
+}
+
+template <int kNumRanks, bool kSyncOnly = false>
+__forceinline__ __device__ void barrier_block(int **barrier_signal_ptrs,
+                                              int rank) {
   auto thread_id = static_cast<int>(threadIdx.x);
-  EP_DEVICE_ASSERT(kNumRanks <= 32);
-  if (thread_id < kNumRanks) {
-    atomicAdd_system(task_fifo_ptrs[rank] + head + thread_id, FINISHED_SUM_TAG);
+  // For non-sync-only cases, the memory operations by other threads in the
+  // block must be visible to the `sys` scope
+  if constexpr (not kSyncOnly) {
     memory_fence();
-    atomicSub_system(task_fifo_ptrs[thread_id] + head + rank, FINISHED_SUM_TAG);
+    __syncthreads();
   }
-  timeout_check<kNumRanks>(task_fifo_ptrs, head, rank, 0, tag);
-#endif
+
+  // Add self-ranks, sub other ranks
+  if (thread_id < kNumRanks) {
+    atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG);
+    atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG);
+  }
+  EP_DEVICE_ASSERT(kNumRanks <= blockDim.x);
+
+  // Check timeout
+  auto start_time = clock64();
+  while (true) {
+    auto value = thread_id < kNumRanks
+                     ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id)
+                     : 0;
+    if (__all_sync(0xffffffff, value <= 0)) break;
+
+    if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) {
+      printf(
+          "DeepEP timeout check failed: rank = %d, thread = %d, value = %d)\n",
+          rank,
+          thread_id,
+          value);
+      trap();
+    }
+  }
+  __syncthreads();
+}
+
+__forceinline__ __device__ int atomic_cas_cta_acquire(int *addr, int x, int y) {
+  int ret;
+  asm volatile("atom.acquire.cta.shared::cta.cas.b32 %0, [%1], %2, %3;"
+               : "=r"(ret)
+               : "l"(addr), "r"(x), "r"(y)
+               : "memory");
+  return ret;
+}
+
+__forceinline__ __device__ int atomic_exch_cta_release(int *addr, int x) {
+  int ret;
+  asm volatile("atom.release.cta.shared::cta.exch.b32 %0, [%1], %2;"
+               : "=r"(ret)
+               : "l"(addr), "r"(x)
+               : "memory");
+  return ret;
+}
+
+__forceinline__ __device__ void acquire_lock(int *mutex) {
+  // To make later memory operations valid, we must use `acquire` for memory
+  // semantics
+  while (atomic_cas_cta_acquire(mutex, 0, 1) != 0)
+    ;
+}
+
+__forceinline__ __device__ void release_lock(int *mutex) {
+  // To make previous memory operations visible to other threads, we must use
+  // `release` for memory semantics
+  atomic_exch_cta_release(mutex, 0);
+}
+
+// Operation functors
+template <typename T>
+struct ReduceSum {
+  __device__ T operator()(T a, T b) const { return a + b; }
+};
+template <typename T>
+struct ReduceMax {
+  __device__ T operator()(T a, T b) const { return a > b ? a : b; }
+};
+template <typename T>
+struct ReduceMin {
+  __device__ T operator()(T a, T b) const { return a < b ? a : b; }
+};
+
+// Unified reduction function
+template <int kNumLanes = 32, typename T, typename Op>
+__forceinline__ __device__ T warp_reduce(T value, Op op) {
+  EP_STATIC_ASSERT(kNumLanes == 32 or kNumLanes == 16 or kNumLanes == 8 or
+                       kNumLanes == 4 or kNumLanes == 2 or kNumLanes == 1,
+                   "Invalid number of lanes");
+
+  if constexpr (kNumLanes >= 32)
+    value = op(value, __shfl_xor_sync(0xffffffff, value, 16));
+  if constexpr (kNumLanes >= 16)
+    value = op(value, __shfl_xor_sync(0xffffffff, value, 8));
+  if constexpr (kNumLanes >= 8)
+    value = op(value, __shfl_xor_sync(0xffffffff, value, 4));
+  if constexpr (kNumLanes >= 4)
+    value = op(value, __shfl_xor_sync(0xffffffff, value, 2));
+  if constexpr (kNumLanes >= 2)
+    value = op(value, __shfl_xor_sync(0xffffffff, value, 1));
+  return value;
+}
+
+// Convenience aliases
+template <int kNumLanes = 32, typename T>
+__forceinline__ __device__ T warp_reduce_sum(T value) {
+  return warp_reduce<kNumLanes>(value, ReduceSum<T>{});
+}
+
+template <int kNumLanes = 32, typename T>
+__forceinline__ __device__ T warp_reduce_max(T value) {
+  return warp_reduce<kNumLanes>(value, ReduceMax<T>{});
+}
+
+template <int kNumLanes = 32, typename T>
+__forceinline__ __device__ T warp_reduce_min(T value) {
+  return warp_reduce<kNumLanes>(value, ReduceMin<T>{});
 }
 }  // namespace deep_ep
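// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): how the unified warp_reduce
// helpers above subsume the old int-only warp_reduce_sum. A hypothetical
// kernel in which each full warp reduces its partial sum and maximum and
// lane 0 writes the per-warp result; assumes the deep_ep helpers are in scope
// and that the grid is sized so every thread has a valid element.
__global__ void warp_reduce_example(const float *in,
                                    float *warp_sums,
                                    float *warp_maxs) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  float v = in[idx];

  // kNumLanes defaults to a full warp of 32 lanes; T is deduced as float.
  float sum = warp_reduce_sum(v);
  float max_v = warp_reduce_max(v);

  if (get_lane_id() == 0) {
    int warp_id = idx / 32;
    warp_sums[warp_id] = sum;
    warp_maxs[warp_id] = max_v;
  }
}
// ---------------------------------------------------------------------------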
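// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): a hypothetical use of the
// shared-memory spin lock (acquire_lock / release_lock) added above,
// serializing updates to a per-block running maximum. The lock word lives in
// shared memory and must start at 0; the per-thread spin relies on the
// independent thread scheduling of Volta and newer parts.
#include <climits>  // INT_MIN

__global__ void lock_example(const int *in, int *block_max, int n) {
  __shared__ int mutex;  // 0 = unlocked
  __shared__ int local_max;
  if (threadIdx.x == 0) {
    mutex = 0;
    local_max = INT_MIN;
  }
  __syncthreads();

  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    acquire_lock(&mutex);  // acquire: observe earlier critical sections
    if (in[idx] > local_max) local_max = in[idx];
    release_lock(&mutex);  // release: publish the update before unlocking
  }
  __syncthreads();
  if (threadIdx.x == 0) block_max[blockIdx.x] = local_max;
}
// ---------------------------------------------------------------------------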
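// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): how the FP8 helpers added above
// (calculate_fp8_scales, extract_required_scale_format, kFP8Margin) can fit
// together when quantizing a single value to e4m3 with a UE8M0 (exponent-only)
// scale. The clamping with kFP8Margin, the function name, and the assumption
// that the template's first parameter selects UE8M0 output are this sketch's
// own choices, not something the patch prescribes.
#include <cuda_fp8.h>

__device__ __forceinline__ __nv_fp8_e4m3
quantize_one_ue8m0(float v, uint8_t &packed_scale) {
  float amax = fmaxf(fabsf(v), kFP8Margin);  // avoid a zero/denormal amax
  float scale, scale_inv;
  calculate_fp8_scales(amax, scale, scale_inv, /*round_scale=*/true);
  // With round_scale=true the scale is an exact power of two, so storing only
  // the biased exponent of scale_inv (UE8M0) is lossless.
  packed_scale = extract_required_scale_format<true>(scale_inv);
  return __nv_fp8_e4m3(v * scale);  // |v * scale| <= 448
}
// ---------------------------------------------------------------------------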