Commit 9a1eb5d

ezhulenev authored and tensorflower-gardener committed
[xla:gpu] Do not use ncclSend and ncclRecv directly and use NcclApi
PiperOrigin-RevId: 599035702
1 parent 7e02e4b commit 9a1eb5d

File tree

third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.cc
third_party/xla/xla/service/gpu/nccl_api.cc
third_party/xla/xla/service/gpu/nccl_api.h
third_party/xla/xla/service/gpu/nccl_api_stub.cc

4 files changed: +118 -82 lines changed
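
The functional change is in RunAllToAll (nccl_all_to_all_thunk.cc): the raw ncclSend/ncclRecv calls, together with the #if XLA_ENABLE_XCCL guards around them, are replaced by the NcclApi wrapper, and a new NcclApi::Slice helper addresses per-peer chunks of a device buffer. A condensed before/after sketch of the pattern, extracted from the diff below (not compilable on its own):

// Before: call NCCL directly, with pointer/byte arithmetic done in the thunk.
XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
                                  chunk_elements * multiplier, dtype,
                                  rank, comm, gpu_stream));

// After: slice the buffer and go through NcclApi, which handles the dtype
// and count conversion plus logging internally.
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase send_slice,
                    NcclApi::Slice(buffer.source_buffer, buffer.element_type,
                                   peer * chunk_elements, chunk_elements));
TF_RETURN_IF_ERROR(NcclApi::Send(
    send_slice, buffer.element_type, chunk_elements, peer,
    reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));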

third_party/xla/xla/service/gpu/nccl_all_to_all_thunk.cc

Lines changed: 38 additions & 82 deletions
@@ -22,7 +22,6 @@ limitations under the License.
 #include <vector>
 
 #include "absl/status/status.h"
-#include "absl/strings/str_format.h"
 #include "absl/strings/substitute.h"
 #include "mlir/IR/Value.h"  // from @llvm-project
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -33,11 +32,11 @@ limitations under the License.
 #include "xla/service/gpu/nccl_collective_thunk.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_memory.h"
 #include "tsl/platform/errors.h"
-
-#if XLA_ENABLE_XCCL
-#include "xla/stream_executor/gpu/gpu_stream.h"
-#endif
+#include "tsl/platform/logging.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 namespace gpu {
@@ -151,60 +150,44 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective(
 absl::Status RunAllToAll(bool has_split_dimension,
                          std::vector<DeviceBufferPair>& buffers,
                          se::Stream& stream, ncclComm_t comm) {
-#if XLA_ENABLE_XCCL
   int device_ordinal = stream.parent()->device_ordinal();
   VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal;
 
-  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);
-
   TF_ASSIGN_OR_RETURN(
       int32_t num_participants,
       NcclApi::CommCount(reinterpret_cast<NcclApi::NcclCommHandle>(comm)));
 
   TF_RETURN_IF_ERROR(NcclApi::GroupStart());
+
   // AllToAll can operate in two modes. Either it specifies a split dimension,
   // in which case inputs are split and outputs concatenated in that dimension
   // (here, we only support dimension 0), or it takes a list of inputs
   // and produces a tuple of outputs.
   if (has_split_dimension) {
-    for (size_t i = 0; i < buffers.size(); ++i) {
-      DeviceBufferPair& buffer = buffers[i];
-      const uint8_t* send_buffer =
-          static_cast<uint8_t*>(buffer.source_buffer.opaque());
-      uint8_t* recv_buffer =
-          static_cast<uint8_t*>(buffer.destination_buffer.opaque());
-
-      TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
-                          ToNcclDataTypeAndCountMultiplier(
-                              buffer.element_type, Thunk::kNcclAllToAll));
-      auto [dtype, multiplier] = dtype_and_multiplier;
-      int64_t element_count = buffer.element_count;
-
-      TF_RET_CHECK(element_count % num_participants == 0)
+    for (DeviceBufferPair& buffer : buffers) {
+      TF_RET_CHECK(buffer.element_count % num_participants == 0)
           << "Buffer was not an exact multiple of the number of participants.";
-      size_t chunk_elements = element_count / num_participants;
-      size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType(
-                                                buffer.element_type);
-
-      for (int rank = 0; rank < num_participants; ++rank) {
-        VLOG(3) << absl::StreamFormat(
-            "Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
-            "comm=%p, stream=%p)",
-            send_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank,
-            static_cast<const void*>(comm), gpu_stream);
-        XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer + rank * chunk_bytes,
-                                          chunk_elements * multiplier, dtype,
-                                          rank, comm, gpu_stream));
-
-        VLOG(3) << absl::StreamFormat(
-            "Calling ncclRecv(recvbuff=%p, count=%d, peer=%d "
-            "comm=%p, stream=%p)",
-            recv_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank,
-            static_cast<const void*>(comm), gpu_stream);
-
-        XLA_NCCL_RETURN_IF_ERROR(ncclRecv(recv_buffer + rank * chunk_bytes,
-                                          chunk_elements * multiplier, dtype,
-                                          rank, comm, gpu_stream));
+
+      size_t chunk_elements = buffer.element_count / num_participants;
+
+      for (int peer = 0; peer < num_participants; ++peer) {
+        TF_ASSIGN_OR_RETURN(
+            se::DeviceMemoryBase send_slice,
+            NcclApi::Slice(buffer.source_buffer, buffer.element_type,
+                           peer * chunk_elements, chunk_elements));
+
+        TF_ASSIGN_OR_RETURN(
+            se::DeviceMemoryBase recv_slice,
+            NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
+                           peer * chunk_elements, chunk_elements));
+
+        TF_RETURN_IF_ERROR(NcclApi::Send(
+            send_slice, buffer.element_type, chunk_elements, peer,
+            reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
+
+        TF_RETURN_IF_ERROR(NcclApi::Recv(
+            recv_slice, buffer.element_type, chunk_elements, peer,
+            reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
       }
     }
   } else {
@@ -213,45 +196,18 @@ absl::Status RunAllToAll(bool has_split_dimension,
 
     for (size_t i = 0; i < buffers.size(); ++i) {
       DeviceBufferPair& buffer = buffers[i];
-      const uint8_t* send_buffer =
-          static_cast<uint8_t*>(buffer.source_buffer.opaque());
-      uint8_t* recv_buffer =
-          static_cast<uint8_t*>(buffer.destination_buffer.opaque());
-
-      TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
-                          ToNcclDataTypeAndCountMultiplier(
-                              buffer.element_type, Thunk::kNcclAllToAll));
-      auto [dtype, multiplier] = dtype_and_multiplier;
-      int64_t element_count = buffer.element_count * multiplier;
-
-      VLOG(3) << absl::StreamFormat(
-          "Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
-          "comm=%p, stream=%p)",
-          send_buffer, element_count, i, static_cast<const void*>(comm),
-          gpu_stream);
-
-      XLA_NCCL_RETURN_IF_ERROR(ncclSend(send_buffer, element_count, dtype,
-                                        /*rank=*/i, comm, gpu_stream));
-
-      VLOG(3) << absl::StreamFormat(
-          "Calling ncclRecv(recvbuff=%p, count=%d, peer=%d "
-          "comm=%p, stream=%p)",
-          recv_buffer, element_count, i, static_cast<const void*>(comm),
-          gpu_stream);
-
-      XLA_NCCL_RETURN_IF_ERROR(ncclRecv(recv_buffer, element_count, dtype,
-                                        /*rank=*/i, comm, gpu_stream));
+
+      TF_RETURN_IF_ERROR(NcclApi::Send(
+          buffer.source_buffer, buffer.element_type, buffer.element_count, i,
+          reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
+
+      TF_RETURN_IF_ERROR(NcclApi::Recv(
+          buffer.destination_buffer, buffer.element_type, buffer.element_count,
+          i, reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
     }
   }
-  TF_RETURN_IF_ERROR(NcclApi::GroupEnd());
-
-  VLOG(3) << "Done performing all-to-all for ordinal: " << device_ordinal;
-  return absl::OkStatus();
-#else  // XLA_ENABLE_XCCL
-  return Unimplemented(
-      "NCCL support is not available: this binary was not built with a CUDA "
-      "compiler, which is necessary to build the NCCL source library.");
-#endif  // XLA_ENABLE_XCCL
+
+  return NcclApi::GroupEnd();
 }
 
 }  // namespace gpu
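
In the split-dimension branch above, each buffer is divided evenly across the participants and peer p exchanges the p-th chunk; NcclApi::Slice is what turns the per-peer element offset into a sub-buffer. A small standalone sketch of that offset arithmetic, using made-up sizes and plain C++ instead of the XLA types (illustration only):

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sizes: 1024 elements split across 8 participants.
  const size_t element_count = 1024;
  const size_t num_participants = 8;
  const size_t chunk_elements = element_count / num_participants;  // 128

  // Peer p sends and receives the slice starting at p * chunk_elements;
  // in the thunk this range is materialized via NcclApi::Slice.
  for (size_t peer = 0; peer < num_participants; ++peer) {
    std::printf("peer %zu: elements [%zu, %zu)\n", peer,
                peer * chunk_elements, (peer + 1) * chunk_elements);
  }
  return 0;
}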

third_party/xla/xla/service/gpu/nccl_api.cc

Lines changed: 43 additions & 0 deletions
@@ -29,6 +29,7 @@ limitations under the License.
 #include "xla/primitive_util.h"
 #include "xla/service/collective_ops_utils.h"
 #include "xla/service/gpu/nccl_clique_key.h"
+#include "xla/shape_util.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_stream.h"
 #include "xla/stream_executor/stream.h"
@@ -172,6 +173,14 @@ static ncclUniqueId AsNcclUniqueId(const NcclCliqueId& clique_id) {
   return id;
 }
 
+absl::StatusOr<se::DeviceMemoryBase> NcclApi::Slice(se::DeviceMemoryBase buff,
+                                                    PrimitiveType dtype,
+                                                    size_t offset,
+                                                    size_t count) {
+  size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype);
+  return buff.GetByteSlice(offset * multiplier, count * multiplier);
+}
+
 absl::StatusOr<NcclCliqueId> NcclApi::GetUniqueId() {
   VLOG(3) << "Get NCCL unique id";
   ncclUniqueId id;
@@ -289,4 +298,38 @@ absl::Status NcclApi::AllGather(se::DeviceMemoryBase send_buffer,
       nccl_dtype, Cast(comm), se::gpu::AsGpuStreamValue(stream)));
 }
 
+absl::Status NcclApi::Send(se::DeviceMemoryBase send_buffer,
+                           PrimitiveType dtype, size_t count, int32_t peer,
+                           NcclCommHandle comm, se::Stream* stream) {
+  VLOG(3) << absl::StreamFormat(
+      "Launch NCCL Send operation on device #%d; send_buffer=%p; dtype=%s; "
+      "count=%d; peer=%d; comm=%p; stream=%p",
+      stream->parent()->device_ordinal(), send_buffer.opaque(),
+      primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm,
+      stream);
+
+  TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false));
+
+  return XLA_NCCL_STATUS(
+      ncclSend(send_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype,
+               peer, Cast(comm), se::gpu::AsGpuStreamValue(stream)));
+}
+
+absl::Status NcclApi::Recv(se::DeviceMemoryBase recv_buffer,
+                           PrimitiveType dtype, size_t count, int32_t peer,
+                           NcclCommHandle comm, se::Stream* stream) {
+  VLOG(3) << absl::StreamFormat(
+      "Launch NCCL Recv operation on device #%d; recv_buffer=%p; dtype=%s; "
+      "count=%d; peer=%d; comm=%p; stream=%p",
+      stream->parent()->device_ordinal(), recv_buffer.opaque(),
+      primitive_util::LowercasePrimitiveTypeName(dtype), count, peer, comm,
+      stream);
+
+  TF_ASSIGN_OR_RETURN(ncclDataType_t nccl_dtype, ToNcclDataType(dtype, false));
+
+  return XLA_NCCL_STATUS(
+      ncclRecv(recv_buffer.opaque(), ToNcclCount(dtype, count), nccl_dtype,
+               peer, Cast(comm), se::gpu::AsGpuStreamValue(stream)));
+}
+
 }  // namespace xla::gpu
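
NcclApi::Slice takes its offset and count in elements of dtype and scales both by the element's byte size before calling se::DeviceMemoryBase::GetByteSlice. A minimal worked example of that conversion (plain C++ with a hard-coded element size standing in for ShapeUtil::ByteSizeOfPrimitiveType; illustration only):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t bytes_per_element = 4;  // e.g. an F32 element is 4 bytes wide
  const size_t offset_elements = 256;  // slice start, in elements
  const size_t count_elements = 128;   // slice length, in elements

  // GetByteSlice(offset_bytes, size_bytes) is then called with:
  const size_t offset_bytes = offset_elements * bytes_per_element;  // 1024
  const size_t size_bytes = count_elements * bytes_per_element;     // 512

  std::printf("byte slice: offset=%zu, size=%zu\n", offset_bytes, size_bytes);
  return 0;
}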

third_party/xla/xla/service/gpu/nccl_api.h

Lines changed: 21 additions & 0 deletions
@@ -45,6 +45,13 @@ struct NcclApi {
   // Convenience handles for defining API functions.
   using NcclCommHandle = NcclComm*;
 
+  // Returns a slice of device memory `buff` containing `count` values of data
+  // type `dtype` starting from `offset`.
+  static absl::StatusOr<se::DeviceMemoryBase> Slice(se::DeviceMemoryBase buff,
+                                                    PrimitiveType dtype,
+                                                    size_t offset,
+                                                    size_t count);
+
   // Creates a new unique clique id.
   //
   // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclgetuniqueid
@@ -111,6 +118,20 @@ struct NcclApi {
                                 se::DeviceMemoryBase recv_buffer,
                                 PrimitiveType dtype, size_t count,
                                 NcclCommHandle comm, se::Stream* stream);
+
+  // Send data from `send_buff` to rank `peer`.
+  //
+  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend
+  static absl::Status Send(se::DeviceMemoryBase send_buffer,
+                           PrimitiveType dtype, size_t count, int32_t peer,
+                           NcclCommHandle comm, se::Stream* stream);
+
+  // Receive data from rank `peer` into `recv_buff`.
+  //
+  // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv
+  static absl::Status Recv(se::DeviceMemoryBase recv_buffer,
+                           PrimitiveType dtype, size_t count, int32_t peer,
+                           NcclCommHandle comm, se::Stream* stream);
 };
 
 //===----------------------------------------------------------------------===//

third_party/xla/xla/service/gpu/nccl_api_stub.cc

Lines changed: 16 additions & 0 deletions
@@ -26,6 +26,12 @@ limitations under the License.
 
 namespace xla::gpu {
 
+absl::StatusOr<se::DeviceMemoryBase> NcclApi::Slice(se::DeviceMemoryBase,
+                                                    PrimitiveType, size_t,
+                                                    size_t) {
+  return absl::UnimplementedError("XLA compiled without NCCL support");
+}
+
 absl::StatusOr<NcclCliqueId> NcclApi::GetUniqueId() {
   return absl::UnimplementedError("XLA compiled without NCCL support");
 }
@@ -74,4 +80,14 @@ absl::Status NcclApi::AllGather(se::DeviceMemoryBase, se::DeviceMemoryBase,
   return absl::UnimplementedError("XLA compiled without NCCL support");
 }
 
+absl::Status NcclApi::Send(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t,
+                           NcclCommHandle, se::Stream*) {
+  return absl::UnimplementedError("XLA compiled without NCCL support");
+}
+
+absl::Status NcclApi::Recv(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t,
+                           NcclCommHandle, se::Stream*) {
+  return absl::UnimplementedError("XLA compiled without NCCL support");
+}
+
 }  // namespace xla::gpu
