@@ -83,11 +83,10 @@ Buffer::Buffer(int rank,
   calc_ctx = reinterpret_cast<phi::GPUContext*>(
       reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
           ->GetDeviceContext(place, true));
-
-  // Metadata memory
-  int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
-  int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
-  int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);
+  // Task fifo memory
+  int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
+  int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
+  int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
 
   // Common checks
   EP_HOST_ASSERT(
@@ -106,35 +105,40 @@ Buffer::Buffer(int rank,
   EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode);
 
   // Get ranks
+  // CUDA_CHECK(cudaGetDevice(&device_id));
   rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
-  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
+  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS),
   num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
 
   // Get device info
   cudaDeviceProp device_prop = {};
   CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
 
   if (num_nvl_bytes > 0) {
-    // Local IPC: alloc local memory and set local IPC handles
-    CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank],
-                          num_nvl_bytes + barrier_signal_bytes +
-                              buffer_ptr_bytes + barrier_signal_ptr_bytes));
+    // Local IPC: alloc local memory and set local IPC handle
+    CUDA_CHECK(cudaMalloc(
+        &buffer_ptrs[nvl_rank],
+        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
     CUDA_CHECK(
         cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
-    buffer_ptrs_gpu =
-        reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) +
-                                 num_nvl_bytes + barrier_signal_bytes);
-
-    // Set barrier signals
-    barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(
-        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
-    barrier_signal_ptrs_gpu = reinterpret_cast<int**>(
-        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        barrier_signal_bytes + buffer_ptr_bytes);
+    buffer_ptrs_gpu = reinterpret_cast<void**>(
+        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        fifo_bytes);
+
+    // Set task fifo
+    EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
+    task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(
+        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
+    task_fifo_ptrs_gpu = reinterpret_cast<int**>(
+        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        fifo_bytes + buffer_ptr_bytes);
 
     // No need to synchronize, will do a full device sync during `sync`
     CUDA_CHECK(cudaMemsetAsync(
-        barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
+        buffer_ptrs[nvl_rank],
+        0,
+        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes,
+        comm_stream));
   }
 
   // Create 32 MiB workspace
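Aside (not part of the diff): the single `cudaMalloc` above backs four contiguous regions, and the pointer arithmetic in this hunk carves them out in order. A minimal sketch of that layout math, using only names visible in the diff:

```cpp
// One allocation, four regions (offsets exactly as computed in the hunk):
//   [0, num_nvl_bytes)      data buffer, IPC-shared with NVLink peers
//   [+fifo_bytes)           task fifo slots      -> task_fifo_ptrs[nvl_rank]
//   [+buffer_ptr_bytes)     table of peer buffer pointers -> buffer_ptrs_gpu
//   [+task_ptr_bytes)       table of peer fifo pointers   -> task_fifo_ptrs_gpu
auto* base = reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]);
int* fifo = reinterpret_cast<int*>(base + num_nvl_bytes);
void** buffer_table =
    reinterpret_cast<void**>(base + num_nvl_bytes + fifo_bytes);
int** fifo_table = reinterpret_cast<int**>(
    base + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes);
```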
@@ -180,7 +184,8 @@ Buffer::~Buffer() noexcept(false) {
   if (num_nvl_bytes > 0) {
     // Barrier
     intranode::barrier(
-        barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
+        task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
+    move_fifo_slots();
     CUDA_CHECK(cudaDeviceSynchronize());
 
     // Close remote IPC
@@ -211,6 +216,10 @@ Buffer::~Buffer() noexcept(false) {
   CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
 }
 
+void Buffer::move_fifo_slots(int num_slots) {
+  head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
+}
+
 bool Buffer::is_available() const { return available; }
 
 bool Buffer::is_internode_available() const {
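A worked example of the new `move_fifo_slots` arithmetic (the slot-count constant is an assumed value, for illustration only): each logical slot consumed by a collective advances `head` by `num_ranks` entries, wrapping at `NUM_MAX_FIFO_SLOTS`; the constructor's `NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0` assert keeps the wraparound aligned for every rank.

```cpp
#include <cassert>

int main() {
  const int kNumMaxFifoSlots = 32768;  // assumed value, for illustration only
  int num_ranks = 8;
  int head = 32760;  // near the end of the fifo
  // A notify_dispatch consuming 3 slots advances head by 8 * 3 = 24 entries:
  head = (head + num_ranks * 3) % kNumMaxFifoSlots;
  assert(head == 16);  // 32784 % 32768: wraps cleanly since 32768 % 8 == 0
  return 0;
}
```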
@@ -259,7 +268,7 @@ void Buffer::sync(
 
   // Sync IPC handles
   if (num_nvl_bytes > 0) {
-    EP_HOST_ASSERT(num_ranks == device_ids.size());
+    EP_HOST_ASSERT(num_ranks == static_cast<int64_t>(device_ids.size()));
     EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
     for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks;
          ++i) {
@@ -271,22 +280,22 @@ void Buffer::sync(
             ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
         CUDA_CHECK(cudaIpcOpenMemHandle(
             &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
-        barrier_signal_ptrs[i] = reinterpret_cast<int*>(
-            static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+        task_fifo_ptrs[i] = reinterpret_cast<int*>(
+            reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
       } else {
         EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved,
                                    handle_str.c_str(),
                                    CUDA_IPC_HANDLE_SIZE) == 0);
       }
     }
 
-    // Copy all buffer and barrier signal pointers to GPU
+    // Copy all buffer and task pointers to GPU
     CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu,
                           buffer_ptrs,
                           sizeof(void*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
-    CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu,
-                          barrier_signal_ptrs,
+    CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu,
+                          task_fifo_ptrs,
                           sizeof(int*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
     CUDA_CHECK(cudaDeviceSynchronize());
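For context, a hedged sketch of the CUDA IPC pattern `sync` relies on here: each rank exports its allocation as a `cudaIpcMemHandle_t`, the serialized handles are all-gathered (the `all_gathered_handles` strings above), and each peer handle is mapped with `cudaIpcOpenMemHandle`. Error handling and the transport are elided; the helper names are hypothetical:

```cpp
#include <cuda_runtime.h>

// Export the local cudaMalloc'd buffer as a portable 64-byte handle.
cudaIpcMemHandle_t export_local_buffer(void* local_ptr) {
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, local_ptr);  // as in the constructor hunk
  return handle;
}

// Map a peer's buffer into this process; peer access is enabled lazily.
void* open_peer_buffer(const cudaIpcMemHandle_t& peer_handle) {
  void* peer_ptr = nullptr;
  cudaIpcOpenMemHandle(&peer_ptr, peer_handle, cudaIpcMemLazyEnablePeerAccess);
  return peer_ptr;  // peer_ptr + num_nvl_bytes is then the peer's task fifo
}
```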
@@ -530,7 +539,7 @@ Buffer::intranode_dispatch(
 
   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
+  int num_scales = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -539,8 +548,6 @@ Buffer::intranode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
-    scale_token_stride = static_cast<int>(x_scales->stride(0));
-    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }
 
   // Allocate all tensors on comm stream if set
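A note on the simplification above: with `scale_token_stride` / `scale_hidden_stride` gone, the kernels presumably take `x_scales` as contiguous. For illustration, assuming the common 128-wide FP8 quantization group (the group width is an assumption, not in the diff), the accepted shapes look like:

```cpp
int main() {
  // Assumed example, not from the diff:
  //   x        : [num_tokens, hidden] with element_size() == 1 (FP8)
  //   x_scales : [num_tokens, hidden / 128], float32
  int hidden = 7168;
  int group_size = 128;                  // assumed quantization group width
  int num_scales = hidden / group_size;  // 56 == x_scales->size(1)
  // A 1-D x_scales of shape [num_tokens] instead yields num_scales == 1.
  return num_scales == 56 ? 0 : 1;
}
```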
@@ -579,10 +586,12 @@ Buffer::intranode_dispatch(
     intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(),
                                       num_memset_int,
                                       buffer_ptrs_gpu,
-                                      barrier_signal_ptrs_gpu,
+                                      task_fifo_ptrs_gpu,
+                                      head,
                                       rank,
                                       num_ranks,
                                       comm_stream);
+    move_fifo_slots(2);
   } else {
     rank_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_ranks, num_ranks},
@@ -617,10 +626,12 @@ Buffer::intranode_dispatch(
         num_memset_int,
         expert_alignment,
         buffer_ptrs_gpu,
-        barrier_signal_ptrs_gpu,
+        task_fifo_ptrs_gpu,
+        head,
         rank,
         comm_stream,
         num_channels);
+    move_fifo_slots(3);
 
     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
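The "synchronize total received tokens" step that begins here busy-waits on pinned host counters that the notify kernel fills (the destructor frees them with `cudaFreeHost`). A hedged sketch of that host-side pattern; the sentinel, the 30-second timeout, and the helper name are assumptions:

```cpp
#include <chrono>

// `counter` is volatile because the GPU updates the pinned mapping while
// the host spins on it.
int wait_for_gpu_counter(volatile int* counter) {
  auto start_time = std::chrono::high_resolution_clock::now();
  while (*counter < 0) {  // assumed sentinel: kernel writes a count >= 0
    auto elapsed = std::chrono::high_resolution_clock::now() - start_time;
    EP_HOST_ASSERT(
        std::chrono::duration_cast<std::chrono::seconds>(elapsed).count() <
        30);  // timeout value assumed
  }
  return *counter;
}
```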
@@ -730,13 +741,10 @@ Buffer::intranode_dispatch(
       is_token_in_rank.data_ptr<bool>(),
       channel_prefix_matrix.data_ptr<int>(),
       num_tokens,
-      0,  // num_worst_tokens (not exposed)
       static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
       num_topk,
       num_experts,
       num_scales,
-      scale_token_stride,
-      scale_hidden_stride,
       buffer_ptrs_gpu,
       rank,
       num_ranks,
@@ -881,11 +889,15 @@ Buffer::intranode_combine(
       num_channels,
       num_recv_tokens,
       num_channels * num_ranks * 2,
-      barrier_signal_ptrs_gpu,
+      task_fifo_ptrs_gpu,
+      head,
       rank,
       num_ranks,
       comm_stream);
 
+  // NOTES: this function uses two FIFO slots (barrier before and after)
+  move_fifo_slots(2);
+
   // Combine data
   auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
       {num_recv_tokens, hidden}, x.dtype(), x.place()));
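The two-slot accounting added above (a barrier before and after the combine notify) reflects how the fifo-backed barrier presumably operates. A device-side sketch of such a rank-symmetric barrier over the IPC-mapped fifos; this illustrates the mechanism only and is not the actual `intranode::barrier` kernel:

```cpp
// Each rank tags its slot in every peer's fifo window at `head`, then spins
// until all peers have tagged its own window. Slot reset/reuse is omitted.
__global__ void fifo_barrier_sketch(int** task_fifo_ptrs,
                                    int head,
                                    int rank,
                                    int num_ranks) {
  if (threadIdx.x < num_ranks) {
    // Signal peer `threadIdx.x`: mark this rank's presence in its fifo.
    atomicAdd(task_fifo_ptrs[threadIdx.x] + head + rank, 1);
    // Wait until peer `threadIdx.x` has marked this rank's fifo.
    while (atomicAdd(task_fifo_ptrs[rank] + head + threadIdx.x, 0) == 0) {
    }
  }
  __syncthreads();
}
```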
@@ -905,8 +917,6 @@ Buffer::intranode_combine(
       recv_topk_weights_ptr,
       x.data_ptr(),
       topk_weights_ptr,
-      nullptr,  // bias_ptrs[0] (not exposed)
-      nullptr,  // bias_ptrs[1] (not exposed)
       src_idx.data_ptr<int>(),
       rank_prefix_matrix.data_ptr<int>(),
       channel_prefix_matrix.data_ptr<int>(),
@@ -1096,7 +1106,7 @@ Buffer::internode_dispatch(
 
   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
+  int num_scales = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -1105,8 +1115,6 @@ Buffer::internode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
-    scale_token_stride = static_cast<int>(x_scales->stride(0));
-    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }
 
   // Allocate all tensors on comm stream if set
@@ -1161,13 +1169,15 @@ Buffer::internode_dispatch(
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        barrier_signal_ptrs_gpu,
+        task_fifo_ptrs_gpu,
+        head,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         true,
         low_latency_mode);
+    move_fifo_slots(2);
   } else {
     rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_rdma_ranks, num_channels},
@@ -1211,12 +1221,14 @@ Buffer::internode_dispatch(
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        barrier_signal_ptrs_gpu,
+        task_fifo_ptrs_gpu,
+        head,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         low_latency_mode);
+    move_fifo_slots(3);
 
     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -1333,14 +1345,12 @@ Buffer::internode_dispatch(
       recv_rdma_rank_prefix_sum.data_ptr<int>(),
       gbl_channel_prefix_matrix.data_ptr<int>(),
       recv_gbl_rank_prefix_sum.data_ptr<int>(),
-      is_token_in_rank.data_ptr<bool>(),
       num_tokens,
       hidden_int4,
       num_scales,
       num_topk,
       num_experts,
-      scale_token_stride,
-      scale_hidden_stride,
+      is_token_in_rank.data_ptr<bool>(),
       rdma_buffer_ptr,
       config.num_max_rdma_chunked_send_tokens,
       config.num_max_rdma_chunked_recv_tokens,
@@ -1538,13 +1548,15 @@ Buffer::internode_combine(
       config.num_max_rdma_chunked_recv_tokens,
       buffer_ptrs_gpu,
       config.num_max_nvl_chunked_recv_tokens,
-      barrier_signal_ptrs_gpu,
+      task_fifo_ptrs_gpu,
+      head,
       rank,
       comm_stream,
       config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
       num_nvl_bytes,
       false,
       low_latency_mode);
+  move_fifo_slots(2);
 
   // Launch data combine
   auto combined_x =
@@ -1556,8 +1568,6 @@ Buffer::internode_combine(
       is_combined_token_in_rank.data_ptr<bool>(),
       x.data_ptr(),
       topk_weights_ptr,
-      nullptr,  // bias_ptrs[0] (not exposed)
-      nullptr,  // bias_ptrs[1] (not exposed)
       combined_rdma_head.data_ptr<int>(),
       combined_nvl_head.data_ptr<int>(),
       src_meta.data_ptr(),