@@ -22,7 +22,6 @@ limitations under the License.
2222#include < vector>
2323
2424#include " absl/status/status.h"
25- #include " absl/strings/str_format.h"
2625#include " absl/strings/substitute.h"
2726#include " mlir/IR/Value.h" // from @llvm-project
2827#include " xla/hlo/ir/hlo_instruction.h"
@@ -33,11 +32,11 @@ limitations under the License.
3332#include " xla/service/gpu/nccl_collective_thunk.h"
3433#include " xla/shape.h"
3534#include " xla/shape_util.h"
35+ #include " xla/status_macros.h"
36+ #include " xla/stream_executor/device_memory.h"
3637#include " tsl/platform/errors.h"
37-
38- #if XLA_ENABLE_XCCL
39- #include " xla/stream_executor/gpu/gpu_stream.h"
40- #endif
38+ #include " tsl/platform/logging.h"
39+ #include " tsl/platform/statusor.h"
4140
4241namespace xla {
4342namespace gpu {
@@ -151,60 +150,44 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective(
151150absl::Status RunAllToAll (bool has_split_dimension,
152151 std::vector<DeviceBufferPair>& buffers,
153152 se::Stream& stream, ncclComm_t comm) {
154- #if XLA_ENABLE_XCCL
155153 int device_ordinal = stream.parent ()->device_ordinal ();
156154 VLOG (3 ) << " Performing all-to-all from device ordinal: " << device_ordinal;
157155
158- se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue (&stream);
159-
160156 TF_ASSIGN_OR_RETURN (
161157 int32_t num_participants,
162158 NcclApi::CommCount (reinterpret_cast <NcclApi::NcclCommHandle>(comm)));
163159
164160 TF_RETURN_IF_ERROR (NcclApi::GroupStart ());
161+
165162 // AllToAll can operate in two modes. Either it specifies a split dimension,
166163 // in which case inputs are split and outputs concatenated in that dimension
167164 // (here, we only support dimension 0), or it takes a list of inputs
168165 // and produces a tuple of outputs.
169166 if (has_split_dimension) {
170- for (size_t i = 0 ; i < buffers.size (); ++i) {
171- DeviceBufferPair& buffer = buffers[i];
172- const uint8_t * send_buffer =
173- static_cast <uint8_t *>(buffer.source_buffer .opaque ());
174- uint8_t * recv_buffer =
175- static_cast <uint8_t *>(buffer.destination_buffer .opaque ());
176-
177- TF_ASSIGN_OR_RETURN (auto dtype_and_multiplier,
178- ToNcclDataTypeAndCountMultiplier (
179- buffer.element_type , Thunk::kNcclAllToAll ));
180- auto [dtype, multiplier] = dtype_and_multiplier;
181- int64_t element_count = buffer.element_count ;
182-
183- TF_RET_CHECK (element_count % num_participants == 0 )
167+ for (DeviceBufferPair& buffer : buffers) {
168+ TF_RET_CHECK (buffer.element_count % num_participants == 0 )
184169 << " Buffer was not an exact multiple of the number of participants." ;
185- size_t chunk_elements = element_count / num_participants;
186- size_t chunk_bytes = chunk_elements * ShapeUtil::ByteSizeOfPrimitiveType (
187- buffer.element_type );
188-
189- for (int rank = 0 ; rank < num_participants; ++rank) {
190- VLOG (3 ) << absl::StreamFormat (
191- " Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
192- " comm=%p, stream=%p)" ,
193- send_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank,
194- static_cast <const void *>(comm), gpu_stream);
195- XLA_NCCL_RETURN_IF_ERROR (ncclSend (send_buffer + rank * chunk_bytes,
196- chunk_elements * multiplier, dtype,
197- rank, comm, gpu_stream));
198-
199- VLOG (3 ) << absl::StreamFormat (
200- " Calling ncclRecv(recvbuff=%p, count=%d, peer=%d "
201- " comm=%p, stream=%p)" ,
202- recv_buffer + rank * chunk_bytes, chunk_elements * multiplier, rank,
203- static_cast <const void *>(comm), gpu_stream);
204-
205- XLA_NCCL_RETURN_IF_ERROR (ncclRecv (recv_buffer + rank * chunk_bytes,
206- chunk_elements * multiplier, dtype,
207- rank, comm, gpu_stream));
170+
171+ size_t chunk_elements = buffer.element_count / num_participants;
172+
173+ for (int peer = 0 ; peer < num_participants; ++peer) {
174+ TF_ASSIGN_OR_RETURN (
175+ se::DeviceMemoryBase send_slice,
176+ NcclApi::Slice (buffer.source_buffer , buffer.element_type ,
177+ peer * chunk_elements, chunk_elements));
178+
179+ TF_ASSIGN_OR_RETURN (
180+ se::DeviceMemoryBase recv_slice,
181+ NcclApi::Slice (buffer.destination_buffer , buffer.element_type ,
182+ peer * chunk_elements, chunk_elements));
183+
184+ TF_RETURN_IF_ERROR (NcclApi::Send (
185+ send_slice, buffer.element_type , chunk_elements, peer,
186+ reinterpret_cast <NcclApi::NcclCommHandle>(comm), &stream));
187+
188+ TF_RETURN_IF_ERROR (NcclApi::Recv (
189+ recv_slice, buffer.element_type , chunk_elements, peer,
190+ reinterpret_cast <NcclApi::NcclCommHandle>(comm), &stream));
208191 }
209192 }
210193 } else {
@@ -213,45 +196,18 @@ absl::Status RunAllToAll(bool has_split_dimension,
213196
214197 for (size_t i = 0 ; i < buffers.size (); ++i) {
215198 DeviceBufferPair& buffer = buffers[i];
216- const uint8_t * send_buffer =
217- static_cast <uint8_t *>(buffer.source_buffer .opaque ());
218- uint8_t * recv_buffer =
219- static_cast <uint8_t *>(buffer.destination_buffer .opaque ());
220-
221- TF_ASSIGN_OR_RETURN (auto dtype_and_multiplier,
222- ToNcclDataTypeAndCountMultiplier (
223- buffer.element_type , Thunk::kNcclAllToAll ));
224- auto [dtype, multiplier] = dtype_and_multiplier;
225- int64_t element_count = buffer.element_count * multiplier;
226-
227- VLOG (3 ) << absl::StreamFormat (
228- " Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
229- " comm=%p, stream=%p)" ,
230- send_buffer, element_count, i, static_cast <const void *>(comm),
231- gpu_stream);
232-
233- XLA_NCCL_RETURN_IF_ERROR (ncclSend (send_buffer, element_count, dtype,
234- /* rank=*/ i, comm, gpu_stream));
235-
236- VLOG (3 ) << absl::StreamFormat (
237- " Calling ncclRecv(recvbuff=%p, count=%d, peer=%d "
238- " comm=%p, stream=%p)" ,
239- recv_buffer, element_count, i, static_cast <const void *>(comm),
240- gpu_stream);
241-
242- XLA_NCCL_RETURN_IF_ERROR (ncclRecv (recv_buffer, element_count, dtype,
243- /* rank=*/ i, comm, gpu_stream));
199+
200+ TF_RETURN_IF_ERROR (NcclApi::Send (
201+ buffer.source_buffer , buffer.element_type , buffer.element_count , i,
202+ reinterpret_cast <NcclApi::NcclCommHandle>(comm), &stream));
203+
204+ TF_RETURN_IF_ERROR (NcclApi::Recv (
205+ buffer.destination_buffer , buffer.element_type , buffer.element_count ,
206+ i, reinterpret_cast <NcclApi::NcclCommHandle>(comm), &stream));
244207 }
245208 }
246- TF_RETURN_IF_ERROR (NcclApi::GroupEnd ());
247-
248- VLOG (3 ) << " Done performing all-to-all for ordinal: " << device_ordinal;
249- return absl::OkStatus ();
250- #else // XLA_ENABLE_XCCL
251- return Unimplemented (
252- " NCCL support is not available: this binary was not built with a CUDA "
253- " compiler, which is necessary to build the NCCL source library." );
254- #endif // XLA_ENABLE_XCCL
209+
210+ return NcclApi::GroupEnd ();
255211}
256212
257213} // namespace gpu
0 commit comments