Commit e2a8155

Revert "Update deep_ep intranode & internode kernels (#74284)" (#76090)
1 parent 8e10916 commit e2a8155

File tree: 11 files changed (+804, -1308 lines)

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

Lines changed: 60 additions & 50 deletions
@@ -83,11 +83,10 @@ Buffer::Buffer(int rank,
   calc_ctx = reinterpret_cast<phi::GPUContext*>(
       reinterpret_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
           ->GetDeviceContext(place, true));
-
-  // Metadata memory
-  int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
-  int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
-  int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*);
+  // Task fifo memory
+  int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS;
+  int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS;
+  int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS;
 
   // Common checks
   EP_HOST_ASSERT(
@@ -106,35 +105,40 @@ Buffer::Buffer(int rank,
   EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS || low_latency_mode);
 
   // Get ranks
+  // CUDA_CHECK(cudaGetDevice(&device_id));
   rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS;
-  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS);
+  num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS),
   num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
 
   // Get device info
   cudaDeviceProp device_prop = {};
   CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));
 
   if (num_nvl_bytes > 0) {
-    // Local IPC: alloc local memory and set local IPC handles
-    CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank],
-                          num_nvl_bytes + barrier_signal_bytes +
-                              buffer_ptr_bytes + barrier_signal_ptr_bytes));
+    // Local IPC: alloc local memory and set local IPC handle
+    CUDA_CHECK(cudaMalloc(
+        &buffer_ptrs[nvl_rank],
+        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes));
     CUDA_CHECK(
         cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
-    buffer_ptrs_gpu =
-        reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) +
-                                 num_nvl_bytes + barrier_signal_bytes);
-
-    // Set barrier signals
-    barrier_signal_ptrs[nvl_rank] = reinterpret_cast<int*>(
-        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
-    barrier_signal_ptrs_gpu = reinterpret_cast<int**>(
-        static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
-        barrier_signal_bytes + buffer_ptr_bytes);
+    buffer_ptrs_gpu = reinterpret_cast<void**>(
+        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        fifo_bytes);
+
+    // Set task fifo
+    EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0);
+    task_fifo_ptrs[nvl_rank] = reinterpret_cast<int*>(
+        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes);
+    task_fifo_ptrs_gpu = reinterpret_cast<int**>(
+        reinterpret_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes +
+        fifo_bytes + buffer_ptr_bytes);
 
     // No need to synchronize, will do a full device sync during `sync`
     CUDA_CHECK(cudaMemsetAsync(
-        barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream));
+        buffer_ptrs[nvl_rank],
+        0,
+        num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes,
+        comm_stream));
   }
 
   // Create 32 MiB workspace
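For orientation, the constructor hunk above carves four regions out of a single cudaMalloc allocation: the user NVLink buffer, the task FIFO, the buffer-pointer table, and the task-pointer table, in that order. The sketch below is not Paddle code; it only recomputes those offsets under assumed values for NUM_MAX_NVL_PEERS and NUM_MAX_FIFO_SLOTS to make the layout explicit.

// Minimal layout sketch, assuming placeholder constants (not Paddle's real values).
#include <cstdint>
#include <cstdio>

constexpr int kMaxNvlPeers = 8;       // stands in for NUM_MAX_NVL_PEERS
constexpr int kMaxFifoSlots = 32768;  // stands in for NUM_MAX_FIFO_SLOTS

int main() {
  int64_t num_nvl_bytes = 1 << 20;  // example user buffer size
  int64_t fifo_bytes = sizeof(int) * kMaxFifoSlots;
  int64_t buffer_ptr_bytes = sizeof(void*) * kMaxNvlPeers;
  int64_t task_ptr_bytes = sizeof(int*) * kMaxNvlPeers;

  // Offsets of each region inside the single allocation, mirroring the diff:
  int64_t task_fifo_offset = num_nvl_bytes;                     // task_fifo_ptrs[nvl_rank]
  int64_t buffer_ptrs_gpu_offset = num_nvl_bytes + fifo_bytes;  // buffer_ptrs_gpu
  int64_t task_fifo_ptrs_gpu_offset =
      num_nvl_bytes + fifo_bytes + buffer_ptr_bytes;            // task_fifo_ptrs_gpu
  int64_t total = num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes;

  std::printf("fifo @ %lld, buffer_ptrs_gpu @ %lld, task_fifo_ptrs_gpu @ %lld, total %lld\n",
              (long long)task_fifo_offset, (long long)buffer_ptrs_gpu_offset,
              (long long)task_fifo_ptrs_gpu_offset, (long long)total);
  return 0;
}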
@@ -180,7 +184,8 @@ Buffer::~Buffer() noexcept(false) {
   if (num_nvl_bytes > 0) {
     // Barrier
     intranode::barrier(
-        barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream);
+        task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream);
+    move_fifo_slots();
     CUDA_CHECK(cudaDeviceSynchronize());
 
     // Close remote IPC
@@ -211,6 +216,10 @@ Buffer::~Buffer() noexcept(false) {
   CUDA_CHECK(cudaFreeHost(const_cast<int*>(moe_recv_expert_counter)));
 }
 
+void Buffer::move_fifo_slots(int num_slots) {
+  head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS;
+}
+
 bool Buffer::is_available() const { return available; }
 
 bool Buffer::is_internode_available() const {
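The restored move_fifo_slots helper advances the host-side FIFO head by num_ranks slots per consumed barrier slot and wraps modulo NUM_MAX_FIFO_SLOTS. A standalone sketch of that arithmetic, with assumed constants rather than Paddle's real ones:

// Head-advance sketch; kMaxFifoSlots is an assumption, not the real constant.
#include <cstdio>

constexpr int kMaxFifoSlots = 32768;  // stands in for NUM_MAX_FIFO_SLOTS

int move_fifo_slots(int head, int num_ranks, int num_slots = 1) {
  // Each rank owns one slot per task, so a task that consumes `num_slots`
  // barrier slots advances the head by num_ranks * num_slots, wrapping around.
  return (head + num_ranks * num_slots) % kMaxFifoSlots;
}

int main() {
  int head = 0, num_ranks = 8;
  head = move_fifo_slots(head, num_ranks, 2);  // e.g. after cached_notify_dispatch
  head = move_fifo_slots(head, num_ranks, 3);  // e.g. after notify_dispatch
  std::printf("head = %d\n", head);            // 8*2 + 8*3 = 40
  return 0;
}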
@@ -259,7 +268,7 @@ void Buffer::sync(
 
   // Sync IPC handles
   if (num_nvl_bytes > 0) {
-    EP_HOST_ASSERT(num_ranks == device_ids.size());
+    EP_HOST_ASSERT(num_ranks == static_cast<int64_t>(device_ids.size()));
     EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size());
     for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks;
          ++i) {
@@ -271,22 +280,22 @@ void Buffer::sync(
             ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
         CUDA_CHECK(cudaIpcOpenMemHandle(
             &buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
-        barrier_signal_ptrs[i] = reinterpret_cast<int*>(
-            static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
+        task_fifo_ptrs[i] = reinterpret_cast<int*>(
+            reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
       } else {
         EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved,
                                    handle_str.c_str(),
                                    CUDA_IPC_HANDLE_SIZE) == 0);
       }
     }
 
-    // Copy all buffer and barrier signal pointers to GPU
+    // Copy all buffer and task pointers to GPU
    CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu,
                          buffer_ptrs,
                          sizeof(void*) * NUM_MAX_NVL_PEERS,
                          cudaMemcpyHostToDevice));
-    CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu,
-                          barrier_signal_ptrs,
+    CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu,
+                          task_fifo_ptrs,
                           sizeof(int*) * NUM_MAX_NVL_PEERS,
                           cudaMemcpyHostToDevice));
     CUDA_CHECK(cudaDeviceSynchronize());
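The sync path above serializes each rank's cudaIpcMemHandle_t into a string, gathers the strings across ranks, and reopens peer allocations with cudaIpcOpenMemHandle. The following is a generic CUDA IPC sketch of that export/import pattern; the transport that moves the handle bytes between processes is left abstract, and this is not the Paddle gather logic itself.

// Generic CUDA IPC handle exchange sketch (CUDA runtime API only).
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#define CUDA_CHECK(cmd)                                                  \
  do {                                                                   \
    cudaError_t e = (cmd);                                               \
    if (e != cudaSuccess) {                                              \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(e));   \
      std::exit(1);                                                      \
    }                                                                    \
  } while (0)

// Producer process: allocate device memory and export its handle as raw bytes.
void export_buffer(void** local_ptr, char out_bytes[CUDA_IPC_HANDLE_SIZE], size_t bytes) {
  CUDA_CHECK(cudaMalloc(local_ptr, bytes));
  cudaIpcMemHandle_t handle;
  CUDA_CHECK(cudaIpcGetMemHandle(&handle, *local_ptr));
  std::memcpy(out_bytes, handle.reserved, CUDA_IPC_HANDLE_SIZE);
}

// Consumer process: rebuild the handle from the received bytes and map the peer buffer.
void* import_buffer(const char in_bytes[CUDA_IPC_HANDLE_SIZE]) {
  cudaIpcMemHandle_t handle;
  std::memcpy(handle.reserved, in_bytes, CUDA_IPC_HANDLE_SIZE);
  void* peer_ptr = nullptr;
  CUDA_CHECK(cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess));
  return peer_ptr;  // release later with cudaIpcCloseMemHandle(peer_ptr)
}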
@@ -530,7 +539,7 @@ Buffer::intranode_dispatch(
 
   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
+  int num_scales = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -539,8 +548,6 @@
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
-    scale_token_stride = static_cast<int>(x_scales->stride(0));
-    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }
 
   // Allocate all tensors on comm stream if set
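After the revert, both dispatch paths keep only num_scales: a 1-D x_scales tensor means one FP8 scale per token, otherwise the second dimension counts the scale groups. A tiny illustration of that shape rule; the helper name is made up for the example and is not part of the codebase.

// Shape-rule sketch: dims stands in for x_scales->dim()/size().
#include <cstdio>
#include <vector>

int num_scales_from_shape(const std::vector<long long>& dims) {
  return dims.size() == 1 ? 1 : static_cast<int>(dims[1]);
}

int main() {
  std::printf("%d\n", num_scales_from_shape({4096}));      // 1-D: one scale per token
  std::printf("%d\n", num_scales_from_shape({4096, 56}));  // 2-D: 56 scale groups per token
  return 0;
}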
@@ -579,10 +586,12 @@
     intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr<int>(),
                                       num_memset_int,
                                       buffer_ptrs_gpu,
-                                      barrier_signal_ptrs_gpu,
+                                      task_fifo_ptrs_gpu,
+                                      head,
                                       rank,
                                       num_ranks,
                                       comm_stream);
+    move_fifo_slots(2);
   } else {
     rank_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_ranks, num_ranks},
@@ -617,10 +626,12 @@
                                num_memset_int,
                                expert_alignment,
                                buffer_ptrs_gpu,
-                               barrier_signal_ptrs_gpu,
+                               task_fifo_ptrs_gpu,
+                               head,
                                rank,
                                comm_stream,
                                num_channels);
+    move_fifo_slots(3);
 
     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -730,13 +741,10 @@
       is_token_in_rank.data_ptr<bool>(),
       channel_prefix_matrix.data_ptr<int>(),
       num_tokens,
-      0,  // num_worst_tokens (not exposed)
       static_cast<int>(hidden * recv_x.element_size() / sizeof(int4)),
       num_topk,
       num_experts,
       num_scales,
-      scale_token_stride,
-      scale_hidden_stride,
       buffer_ptrs_gpu,
       rank,
       num_ranks,
@@ -881,11 +889,15 @@ Buffer::intranode_combine(
       num_channels,
       num_recv_tokens,
       num_channels * num_ranks * 2,
-      barrier_signal_ptrs_gpu,
+      task_fifo_ptrs_gpu,
+      head,
       rank,
       num_ranks,
       comm_stream);
 
+  // NOTES: this function uses two FIFO slots (barrier before and after)
+  move_fifo_slots(2);
+
   // Combine data
   auto recv_x = ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
       {num_recv_tokens, hidden}, x.dtype(), x.place()));
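The restored note explains why the host advances two FIFO slots here: the notify step barriers once before and once after touching the shared buffers. Related bookkeeping appears in the constructor hunk, which re-adds EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0). The sketch below is assumed reasoning, not taken from the commit: it checks that when the slot count divides evenly by the rank count, the wrapped head always stays rank-aligned.

// Alignment check under an assumed slot count; kSlots stands in for NUM_MAX_FIFO_SLOTS.
#include <cassert>
#include <cstdio>
#include <initializer_list>

int main() {
  const int kSlots = 32768;
  for (int num_ranks : {2, 4, 8}) {
    assert(kSlots % num_ranks == 0);
    int head = 0;
    for (int step = 0; step < 100000; ++step) {
      head = (head + num_ranks * (step % 3 + 1)) % kSlots;  // tasks use 1..3 slots
      assert(head % num_ranks == 0);  // each rank's slot group stays aligned after wrap
    }
  }
  std::printf("head stays rank-aligned when the slot count divides evenly\n");
  return 0;
}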
@@ -905,8 +917,6 @@
       recv_topk_weights_ptr,
       x.data_ptr(),
       topk_weights_ptr,
-      nullptr,  // bias_ptrs[0] (not exposed)
-      nullptr,  // bias_ptrs[1] (not exposed)
       src_idx.data_ptr<int>(),
       rank_prefix_matrix.data_ptr<int>(),
       channel_prefix_matrix.data_ptr<int>(),
@@ -1096,7 +1106,7 @@ Buffer::internode_dispatch(
 
   // FP8 scales checks
   float* x_scales_ptr = nullptr;
-  int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0;
+  int num_scales = 0;
   if (x_scales.has_value()) {
     EP_HOST_ASSERT(x.element_size() == 1);
     EP_HOST_ASSERT(x_scales->scalar_type() == deep_ep::detail::kFloat32);
@@ -1105,8 +1115,6 @@
     EP_HOST_ASSERT(x_scales->size(0) == num_tokens);
     num_scales = x_scales->dim() == 1 ? 1 : static_cast<int>(x_scales->size(1));
     x_scales_ptr = x_scales->data_ptr<float>();
-    scale_token_stride = static_cast<int>(x_scales->stride(0));
-    scale_hidden_stride = static_cast<int>(x_scales->stride(1));
   }
 
   // Allocate all tensors on comm stream if set
@@ -1161,13 +1169,15 @@
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        barrier_signal_ptrs_gpu,
+        task_fifo_ptrs_gpu,
+        head,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         true,
         low_latency_mode);
+    move_fifo_slots(2);
   } else {
     rdma_channel_prefix_matrix = ConvertPaddleTensorToDetailTensor(
         paddle::experimental::empty({num_rdma_ranks, num_channels},
@@ -1211,12 +1221,14 @@
         config.num_max_rdma_chunked_recv_tokens,
         buffer_ptrs_gpu,
         config.num_max_nvl_chunked_recv_tokens,
-        barrier_signal_ptrs_gpu,
+        task_fifo_ptrs_gpu,
+        head,
         rank,
         comm_stream,
         config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
         num_nvl_bytes,
         low_latency_mode);
+    move_fifo_slots(3);
 
     // Synchronize total received tokens and tokens per expert
     auto start_time = std::chrono::high_resolution_clock::now();
@@ -1333,14 +1345,12 @@
       recv_rdma_rank_prefix_sum.data_ptr<int>(),
       gbl_channel_prefix_matrix.data_ptr<int>(),
       recv_gbl_rank_prefix_sum.data_ptr<int>(),
-      is_token_in_rank.data_ptr<bool>(),
       num_tokens,
       hidden_int4,
       num_scales,
       num_topk,
       num_experts,
-      scale_token_stride,
-      scale_hidden_stride,
+      is_token_in_rank.data_ptr<bool>(),
       rdma_buffer_ptr,
       config.num_max_rdma_chunked_send_tokens,
       config.num_max_rdma_chunked_recv_tokens,
@@ -1538,13 +1548,15 @@ Buffer::internode_combine(
       config.num_max_rdma_chunked_recv_tokens,
       buffer_ptrs_gpu,
       config.num_max_nvl_chunked_recv_tokens,
-      barrier_signal_ptrs_gpu,
+      task_fifo_ptrs_gpu,
+      head,
       rank,
       comm_stream,
       config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
       num_nvl_bytes,
       false,
       low_latency_mode);
+  move_fifo_slots(2);
 
   // Launch data combine
   auto combined_x =
@@ -1556,8 +1568,6 @@
       is_combined_token_in_rank.data_ptr<bool>(),
       x.data_ptr(),
       topk_weights_ptr,
-      nullptr,  // bias_ptrs[0] (not exposed)
-      nullptr,  // bias_ptrs[1] (not exposed)
      combined_rdma_head.data_ptr<int>(),
      combined_nvl_head.data_ptr<int>(),
      src_meta.data_ptr(),

paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp

Lines changed: 7 additions & 3 deletions
@@ -81,9 +81,10 @@ struct Buffer {
   // After IPC/NVSHMEM synchronization, this flag will be true
   bool available = false;
 
-  // Barrier signals
-  int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
-  int** barrier_signal_ptrs_gpu = nullptr;
+  // Task fifo
+  int head = 0;
+  int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr};
+  int** task_fifo_ptrs_gpu = nullptr;
 
   // Workspace
   void* workspace = nullptr;
@@ -100,6 +101,9 @@ struct Buffer {
   volatile int* moe_recv_rdma_counter = nullptr;
   int* moe_recv_rdma_counter_mapped = nullptr;
 
+ private:
+  void move_fifo_slots(int num_slots = 1);
+
  public:
   Buffer(int rank,
          int num_ranks,
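Taken together, the header changes restore the host-side FIFO state: a head cursor, a per-peer FIFO pointer table, its device-visible copy, and the private move_fifo_slots declaration. Below is a condensed, self-contained sketch of that state, not the real Buffer class; field names mirror the diff and the constants are placeholders, not Paddle's values.

// Host-side FIFO bookkeeping sketch, assuming placeholder constants.
constexpr int kMaxNvlPeers = 8;       // stands in for NUM_MAX_NVL_PEERS
constexpr int kMaxFifoSlots = 32768;  // stands in for NUM_MAX_FIFO_SLOTS

struct TaskFifoState {
  int head = 0;                                   // next free slot group (host copy)
  int* task_fifo_ptrs[kMaxNvlPeers] = {nullptr};  // one FIFO base per NVLink peer
  int** task_fifo_ptrs_gpu = nullptr;             // device-visible copy of the table
  int num_ranks = 1;

  // Mirrors Buffer::move_fifo_slots(): called after each collective that
  // consumed `num_slots` barrier slots per rank.
  void move_fifo_slots(int num_slots = 1) {
    head = (head + num_ranks * num_slots) % kMaxFifoSlots;
  }
};

int main() {
  TaskFifoState state;
  state.num_ranks = 8;
  state.move_fifo_slots(2);  // e.g. after a notify step that barriers twice
  return state.head == 16 ? 0 : 1;
}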

paddle/fluid/distributed/collective/deep_ep/include/types.h

Lines changed: 0 additions & 2 deletions
@@ -73,8 +73,6 @@ struct Tensor {
   }
 
   int64_t element_size() const { return phi::SizeOf(raw_tensor_.dtype()); }
-
-  int64_t stride(int64_t d) const { return raw_tensor_.strides().at(d); }
 };
 
 } // namespace deep_ep::detail
