Skip to content

Commit 628e1ac

Browse files
ezhulenev authored and tensorflower-gardener committed
[xla:gpu] Do not use ncclSend and ncclRecv directly and use NcclApi part #2
PiperOrigin-RevId: 599037622
1 parent 9a1eb5d commit 628e1ac

File tree

2 files changed

+14
-53
lines changed

2 files changed

+14
-53
lines changed

third_party/xla/xla/service/gpu/nccl_recv_thunk.cc

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,18 @@ limitations under the License.
1515

1616
#include "xla/service/gpu/nccl_recv_thunk.h"
1717

18+
#include <cstdint>
1819
#include <optional>
1920
#include <string>
20-
#include <utility>
2121
#include <vector>
2222

2323
#include "absl/status/status.h"
2424
#include "absl/strings/string_view.h"
2525
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
2626
#include "xla/service/collective_ops_utils.h"
27+
#include "xla/service/gpu/nccl_api.h"
2728
#include "xla/stream_executor/stream.h"
28-
29-
#if XLA_ENABLE_XCCL
30-
#include "xla/stream_executor/gpu/gpu_stream.h"
31-
#endif
29+
#include "tsl/platform/errors.h"
3230

3331
namespace xla {
3432
namespace gpu {
@@ -102,7 +100,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
102100
DeviceBufferPair& buffer, se::Stream& stream,
103101
ncclComm_t comm, absl::string_view device_string,
104102
int64_t current_id) {
105-
#if XLA_ENABLE_XCCL
106103
// Determine the source IDs for this instance. The source ID is the ID for
107104
// the peer that will copy its data to this instance. If there is no source,
108105
// just memzero() the destination buffer.
@@ -116,23 +113,12 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
116113
VLOG(3) << absl::StreamFormat("%s : id = %d, source_id = %d", device_string,
117114
current_id, source_id.value_or(-1));
118115

119-
TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
120-
ToNcclDataTypeAndCountMultiplier(
121-
buffer.element_type, Thunk::kNcclCollectivePermute));
122-
ncclDataType_t dtype = dtype_and_multiplier.first;
123-
int64_t element_count = buffer.element_count * dtype_and_multiplier.second;
124-
125-
se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);
126-
127116
// Receive data from the source peer to the destination buffer.
128117
if (source_id) {
129-
VLOG(3) << absl::StreamFormat(
130-
"%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
131-
"stream=%p)",
132-
device_string, dest_addr.opaque(), element_count, *source_id,
133-
static_cast<const void*>(comm), gpu_stream);
134-
XLA_NCCL_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
135-
*source_id, comm, gpu_stream));
118+
TF_RETURN_IF_ERROR(NcclApi::Recv(
119+
dest_addr, buffer.element_type, buffer.element_count, *source_id,
120+
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
121+
136122
} else {
137123
// If there is no source peer, i.e. no sender to this instance, zero out
138124
// the destination buffer.
@@ -141,11 +127,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
141127
stream.ThenMemZero(&dest_addr, dest_addr.size());
142128
}
143129
return absl::OkStatus();
144-
#else // XLA_ENABLE_XCCL
145-
return Unimplemented(
146-
"NCCL support is not available: this binary was not built with a CUDA "
147-
"compiler, which is necessary to build the NCCL source library.");
148-
#endif // XLA_ENABLE_XCCL
149130
}
150131

151132
} // namespace gpu

third_party/xla/xla/service/gpu/nccl_send_thunk.cc

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,18 @@ limitations under the License.
1515

1616
#include "xla/service/gpu/nccl_send_thunk.h"
1717

18+
#include <cstdint>
1819
#include <optional>
1920
#include <string>
20-
#include <utility>
2121
#include <vector>
2222

2323
#include "absl/status/status.h"
2424
#include "absl/strings/string_view.h"
2525
#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
2626
#include "xla/service/collective_ops_utils.h"
27+
#include "xla/service/gpu/nccl_api.h"
2728
#include "xla/stream_executor/stream.h"
28-
29-
#if XLA_ENABLE_XCCL
30-
#include "xla/stream_executor/gpu/gpu_stream.h"
31-
#endif
29+
#include "tsl/platform/errors.h"
3230

3331
namespace xla {
3432
namespace gpu {
@@ -102,10 +100,8 @@ absl::Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
102100
DeviceBufferPair& buffer, se::Stream& stream,
103101
ncclComm_t comm, absl::string_view device_string,
104102
int64_t current_id) {
105-
#if XLA_ENABLE_XCCL
106103
// Determine the target IDs for this instance. The target ID is the ID
107104
// to which this instance will copy its data.
108-
109105
int device_ordinal = stream.parent()->device_ordinal();
110106
VLOG(3) << "Performing collective permute from device ordinal: "
111107
<< device_ordinal << "current_id " << current_id;
@@ -116,30 +112,14 @@ absl::Status RunSend(NcclP2PConfig::SourceTargetMapEntry source_target,
116112
VLOG(3) << absl::StreamFormat("%s : id = %d, target_id = %d", device_string,
117113
current_id, target_id.value_or(-1));
118114

119-
TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
120-
ToNcclDataTypeAndCountMultiplier(
121-
buffer.element_type, Thunk::kNcclCollectivePermute));
122-
ncclDataType_t dtype = dtype_and_multiplier.first;
123-
int64_t element_count = buffer.element_count * dtype_and_multiplier.second;
124-
125-
se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);
126-
127115
// Send source buffer to target peer if needed.
128116
if (target_id) {
129-
VLOG(3) << absl::StreamFormat(
130-
"%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
131-
"comm=%p, stream=%p)",
132-
device_string, src_addr.opaque(), element_count, *target_id,
133-
static_cast<const void*>(comm), gpu_stream);
134-
XLA_NCCL_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
135-
*target_id, comm, gpu_stream));
117+
TF_RETURN_IF_ERROR(NcclApi::Send(
118+
src_addr, buffer.element_type, buffer.element_count, *target_id,
119+
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
136120
}
121+
137122
return absl::OkStatus();
138-
#else // XLA_ENABLE_XCCL
139-
return Unimplemented(
140-
"NCCL support is not available: this binary was not built with a CUDA "
141-
"compiler, which is necessary to build the NCCL source library.");
142-
#endif // XLA_ENABLE_XCCL
143123
}
144124

145125
} // namespace gpu

0 commit comments

Comments
 (0)