@@ -69,14 +69,15 @@ Buffer::Buffer(int rank,
   calc_ctx = reinterpret_cast<phi::GPUContext*>(
       reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
          ->GetDeviceContext(place, true));
-  // Task fifo memory
-  int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
-  int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
-  int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
+
+  // Metadata memory
+  int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
+  int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
+  int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);
 
   // Common checks
   EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
-                 (num_nvl_bytes <= std::numeric_limits<int64_t>::max() ||
+                 (num_nvl_bytes <= std::numeric_limits<int>::max() ||
                   num_rdma_bytes == 0));
   EP_HOST_ASSERT(
       num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 &&
@@ -90,40 +91,35 @@ Buffer::Buffer(int rank,
   EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode);
 
   // Get ranks
-  // CUDA_CHECK(cudaGetDevice(&device_id));
   rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
-  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS),
+  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
   num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
 
   // Get device info
   cudaDeviceProp device_prop = {};
   CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
 
   if (num_nvl_bytes > 0) {
-    // Local IPC: alloc local memory and set local IPC handle
-    CUDA_CHECK(cudaMalloc(
-        &buffer_ptrs[nvl_rank],
-        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
+    // Local IPC: alloc local memory and set local IPC handles
+    CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank],
+                          num_nvl_bytes + barrier_signal_bytes +
+                              buffer_ptr_bytes + barrier_signal_ptr_bytes));
     CUDA_CHECK(
         cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
-    buffer_ptrs_gpu = reinterpret_cast<void**>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        fifo_bytes);
-
-    // Set task fifo
-    EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
-    task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
-    task_fifo_ptrs_gpu = reinterpret_cast<int**>(
-        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        fifo_bytes + buffer_ptr_bytes);
+    buffer_ptrs_gpu =
+        reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) +
+                                 num_nvl_bytes + barrier_signal_bytes);
+
+    // Set barrier signals
+    barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(
+        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
+    barrier_signal_ptrs_gpu = reinterpret_cast<int**>(
+        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        barrier_signal_bytes + buffer_ptr_bytes);
 
     // No need to synchronize, will do a full device sync during `sync`
     CUDA_CHECK(cudaMemsetAsync(
-        buffer_ptrs[nvl_rank],
-        0,
-        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes,
-        comm_stream));
+        barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
   }
 
   // Create 32 MiB workspace
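The allocation above packs the user buffer and all bookkeeping into a single `cudaMalloc` region. A minimal sketch of how the offsets fall out of the new byte counts (assuming `NUM_MAX_NVL_PEERS = 8`; the struct and function are illustrative, not part of the patch):

```cpp
#include <cstdint>

constexpr int NUM_MAX_NVL_PEERS = 8;  // assumed value

// Layout: [ num_nvl_bytes | barrier signals | buffer ptrs | signal ptrs ]
struct MetadataOffsets {
  int64_t barrier_signals;      // this rank's int flag per peer
  int64_t buffer_ptrs;          // table of every peer's buffer pointer
  int64_t barrier_signal_ptrs;  // table of every peer's signal array
};

MetadataOffsets compute_offsets(int64_t num_nvl_bytes) {
  const int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
  const int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
  return {num_nvl_bytes,
          num_nvl_bytes + barrier_signal_bytes,
          num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes};
}
```

Note that the memset now clears only the `barrier_signal_bytes` slice rather than the whole region: only the signals must start at zero.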
@@ -165,8 +161,7 @@ Buffer::~Buffer() noexcept(false) {
   if (num_nvl_bytes > 0) {
     // Barrier
     intranode::barrier(
-        task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
-    move_fifo_slots();
+        barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
     CUDA_CHECK(cudaDeviceSynchronize());
 
     // Close remote IPC
@@ -197,10 +192,6 @@ Buffer::~Buffer() noexcept(false) {
   CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
 }
 
-void Buffer::move_fifo_slots(int num_slots) {
-  head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
-}
-
 bool Buffer::is_available() const { return available; }
 
 bool Buffer::is_internode_available() const {
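With `move_fifo_slots` gone, callers no longer have to thread a `head` cursor through every collective; a barrier is just one signal exchange per peer. A hypothetical sketch of that pattern (this is not the library's `intranode::barrier`, whose exact protocol lives in the kernels):

```cuda
// One block, one thread per peer. Each rank bumps its slot in every peer's
// signal array, waits until every peer has bumped its own array, then
// consumes the signal so the arrays are clean for the next barrier.
__global__ void barrier_sketch(int** barrier_signal_ptrs, int rank,
                               int num_ranks) {
  const int peer = static_cast<int>(threadIdx.x);
  if (peer < num_ranks) {
    atomicAdd_system(&barrier_signal_ptrs[peer][rank], 1);  // arrive
    while (atomicAdd_system(&barrier_signal_ptrs[rank][peer], 0) == 0) {
    }                                                       // wait
    atomicSub_system(&barrier_signal_ptrs[rank][peer], 1);  // consume
  }
}
```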
@@ -249,7 +240,7 @@ void Buffer::sync(
 
   // Sync IPC handles
   if (num_nvl_bytes > 0) {
-    EP_HOST_ASSERT(num_ranks == static_cast<int64_t>(device_ids.size()));
+    EP_HOST_ASSERT(num_ranks == device_ids.size());
    EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
     for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks;
          ++i) {
@@ -261,22 +252,22 @@ void Buffer::sync(
            ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
        CUDA_CHECK(cudaIpcOpenMemHandle(
            &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
-        task_fifo_ptrs[i] = reinterpret_cast<int*>(
-            reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+        barrier_signal_ptrs[i] = reinterpret_cast<int*>(
+            static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
      } else {
        EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved,
                                   handle_str.c_str(),
                                   CUDA_IPC_HANDLE_SIZE) == 0);
      }
    }
 
-    // Copy all buffer and task pointers to GPU
+    // Copy all buffer and barrier signal pointers to GPU
    CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu,
                          buffer_ptrs,
                          sizeof(void*) * NUM_MAX_NVL_PEERS,
                          cudaMemcpyHostToDevice));
-    CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu,
-                          task_fifo_ptrs,
+    CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu,
+                          barrier_signal_ptrs,
                          sizeof(int*) * NUM_MAX_NVL_PEERS,
                          cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaDeviceSynchronize());
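For readers unfamiliar with the CUDA IPC flow this block leans on: each rank serializes its `cudaIpcMemHandle_t`, the handles are all-gathered across the node, and every remote handle is mapped into the local address space. A minimal sketch using only standard runtime calls (error handling elided):

```cpp
#include <cuda_runtime.h>
#include <cstring>
#include <string>

// Producer side: export a handle for a local cudaMalloc allocation.
std::string export_handle(void* local_ptr) {
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, local_ptr);
  return std::string(handle.reserved, CUDA_IPC_HANDLE_SIZE);
}

// Consumer side: map a peer's allocation from its serialized handle.
void* open_handle(const std::string& serialized) {
  cudaIpcMemHandle_t handle;
  std::memcpy(handle.reserved, serialized.c_str(), CUDA_IPC_HANDLE_SIZE);
  void* peer_ptr = nullptr;
  cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess);
  return peer_ptr;
}
```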
@@ -520,7 +511,7 @@ Buffer::intranode_dispatch(
 
   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0;
+  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -529,6 +520,8 @@ Buffer::intranode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
+    scale_token_stride = static_cast<int>(x_scales->stride(0));
+    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }
 
   // Allocate all tensors on comm stream if set
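Capturing both strides lets the dispatch kernel read scales from a non-contiguous tensor (for example, a transposed scale layout) instead of assuming a row-major `[num_tokens, num_scales]` view. Illustrative addressing only; the helper name is made up:

```cpp
#include <cstdint>

// Hypothetical helper showing how the two strides would be consumed.
inline float load_scale(const float* x_scales_ptr, int token_idx,
                        int scale_idx, int scale_token_stride,
                        int scale_hidden_stride) {
  return x_scales_ptr[static_cast<int64_t>(token_idx) * scale_token_stride +
                      static_cast<int64_t>(scale_idx) * scale_hidden_stride];
}
```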
@@ -564,12 +557,10 @@ Buffer::intranode_dispatch(
    intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(),
                                      num_memset_int,
                                      buffer_ptrs_gpu,
-                                      task_fifo_ptrs_gpu,
-                                      head,
+                                      barrier_signal_ptrs_gpu,
                                      rank,
                                      num_ranks,
                                      comm_stream);
-    move_fifo_slots(2);
  } else {
    rank_prefix_matrix = ConvertPaddleTensorToDetailTensor(
        paddle::experimental::empty({num_ranks, num_ranks},
@@ -604,12 +595,10 @@ Buffer::intranode_dispatch(
                               num_memset_int,
                               expert_alignment,
                               buffer_ptrs_gpu,
-                               task_fifo_ptrs_gpu,
-                               head,
+                               barrier_signal_ptrs_gpu,
                               rank,
                               comm_stream,
                               num_channels);
-    move_fifo_slots(3);
 
    // Synchronize total received tokens and tokens per expert
    auto start_time = std::chrono::high_resolution_clock::now();
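The `start_time` captured above feeds a host-side spin on pinned counters that the notify kernel fills in. A hedged sketch of that wait (the `-1` sentinel and the timeout value are assumptions from context, not copied code):

```cpp
#include <chrono>
#include <stdexcept>

// Assumed: the counter lives in cudaHostAlloc'ed memory, is preset to -1,
// and the notify kernel writes the real token count when it finishes.
void wait_counter(const volatile int* counter, int timeout_secs = 30) {
  auto start = std::chrono::high_resolution_clock::now();
  while (*counter < 0) {
    auto elapsed = std::chrono::high_resolution_clock::now() - start;
    if (std::chrono::duration_cast<std::chrono::seconds>(elapsed).count() >=
        timeout_secs)
      throw std::runtime_error("timed out waiting for dispatch counters");
  }
}
```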
@@ -719,10 +708,13 @@ Buffer::intranode_dispatch(
      is_token_in_rank.data_ptr<bool>(),
      channel_prefix_matrix.data_ptr<int>(),
      num_tokens,
+      0,  // num_worst_tokens (not exposed)
      static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
      num_topk,
      num_experts,
      num_scales,
+      scale_token_stride,
+      scale_hidden_stride,
      buffer_ptrs_gpu,
      rank,
      num_ranks,
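The `hidden * recv_x.element_size() / sizeof(int4)` argument converts the hidden dimension to a count of 16-byte vectors, the granularity the kernel copies at (`int4` is CUDA's built-in 16-byte vector type, from `vector_types.h`). A worked instance of the arithmetic:

```cpp
#include <vector_types.h>  // CUDA's int4

// Bytes-per-row to int4-count conversion used in the call above.
constexpr int hidden_int4_count(int hidden, int element_size) {
  return hidden * element_size / static_cast<int>(sizeof(int4));
}
// e.g. 7168 bf16 (2-byte) values -> 896 int4 transfers per token row.
static_assert(hidden_int4_count(7168, 2) == 896, "illustrative check");
```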
@@ -867,15 +859,11 @@ Buffer::intranode_combine(
      num_channels,
      num_recv_tokens,
      num_channels * num_ranks * 2,
-      task_fifo_ptrs_gpu,
-      head,
+      barrier_signal_ptrs_gpu,
      rank,
      num_ranks,
      comm_stream);
 
-  // NOTES: this function uses two FIFO slots (barrier before and after)
-  move_fifo_slots(2);
-
   // Combine data
   auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
       {num_recv_tokens, hidden}, x.dtype(), x.place()));
@@ -895,6 +883,8 @@ Buffer::intranode_combine(
      recv_topk_weights_ptr,
      x.data_ptr(),
      topk_weights_ptr,
+      nullptr,  // bias_ptrs[0] (not exposed)
+      nullptr,  // bias_ptrs[1] (not exposed)
      src_idx.data_ptr<int>(),
      rank_prefix_matrix.data_ptr<int>(),
      channel_prefix_matrix.data_ptr<int>(),
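The two `nullptr` arguments are the new optional bias inputs; when present, the combine reduction would fold them in per element. A sketch of the intended arithmetic (hypothetical helper, not the kernel itself):

```cpp
#include <cstdint>

// combined = sum(partials) + bias0 + bias1, with either bias optional;
// passing nullptr, as in the call above, skips the addition.
inline float combine_elem(float partial_sum,
                          const float* bias0,
                          const float* bias1,
                          int64_t idx) {
  float v = partial_sum;
  if (bias0 != nullptr) v += bias0[idx];
  if (bias1 != nullptr) v += bias1[idx];
  return v;
}
```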
@@ -1084,7 +1074,7 @@ Buffer::internode_dispatch(
 
   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0;
+  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -1093,6 +1083,8 @@ Buffer::internode_dispatch(
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
+    scale_token_stride = static_cast<int>(x_scales->stride(0));
+    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }
 
   // Allocate all tensors on comm stream if set
@@ -1144,15 +1136,13 @@ Buffer::internode_dispatch(
        config.num_max_rdma_chunked_recv_tokens,
        buffer_ptrs_gpu,
        config.num_max_nvl_chunked_recv_tokens,
-        task_fifo_ptrs_gpu,
-        head,
+        barrier_signal_ptrs_gpu,
        rank,
        comm_stream,
        config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
        num_nvl_bytes,
        true,
        low_latency_mode);
-    move_fifo_slots(2);
  } else {
    rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
        paddle::experimental::empty({num_rdma_ranks, num_channels},
@@ -1196,14 +1186,12 @@ Buffer::internode_dispatch(
        config.num_max_rdma_chunked_recv_tokens,
        buffer_ptrs_gpu,
        config.num_max_nvl_chunked_recv_tokens,
-        task_fifo_ptrs_gpu,
-        head,
+        barrier_signal_ptrs_gpu,
        rank,
        comm_stream,
        config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
        num_nvl_bytes,
        low_latency_mode);
-    move_fifo_slots(3);
 
    // Synchronize total received tokens and tokens per expert
    auto start_time = std::chrono::high_resolution_clock::now();
@@ -1320,12 +1308,14 @@ Buffer::internode_dispatch(
      recv_rdma_rank_prefix_sum.data_ptr<int>(),
      gbl_channel_prefix_matrix.data_ptr<int>(),
      recv_gbl_rank_prefix_sum.data_ptr<int>(),
+      is_token_in_rank.data_ptr<bool>(),
      num_tokens,
      hidden_int4,
      num_scales,
      num_topk,
      num_experts,
-      is_token_in_rank.data_ptr<bool>(),
+      scale_token_stride,
+      scale_hidden_stride,
      rdma_buffer_ptr,
      config.num_max_rdma_chunked_send_tokens,
      config.num_max_rdma_chunked_recv_tokens,
@@ -1523,15 +1513,13 @@ Buffer::internode_combine(
      config.num_max_rdma_chunked_recv_tokens,
      buffer_ptrs_gpu,
      config.num_max_nvl_chunked_recv_tokens,
-      task_fifo_ptrs_gpu,
-      head,
+      barrier_signal_ptrs_gpu,
      rank,
      comm_stream,
      config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
      num_nvl_bytes,
      false,
      low_latency_mode);
-  move_fifo_slots(2);
 
   // Launch data combine
   auto combined_x =
@@ -1543,6 +1531,8 @@ Buffer::internode_combine(
      is_combined_token_in_rank.data_ptr<bool>(),
      x.data_ptr(),
      topk_weights_ptr,
+      nullptr,  // bias_ptrs[0] (not exposed)
+      nullptr,  // bias_ptrs[1] (not exposed)
      combined_rdma_head.data_ptr<int>(),
      combined_nvl_head.data_ptr<int>(),
      src_meta.data_ptr(),