@@ -15,20 +15,18 @@ limitations under the License.
1515
1616#include " xla/service/gpu/nccl_recv_thunk.h"
1717
18+ #include < cstdint>
1819#include < optional>
1920#include < string>
20- #include < utility>
2121#include < vector>
2222
2323#include " absl/status/status.h"
2424#include " absl/strings/string_view.h"
2525#include " xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
2626#include " xla/service/collective_ops_utils.h"
27+ #include " xla/service/gpu/nccl_api.h"
2728#include " xla/stream_executor/stream.h"
28-
29- #if XLA_ENABLE_XCCL
30- #include " xla/stream_executor/gpu/gpu_stream.h"
31- #endif
29+ #include " tsl/platform/errors.h"
3230
3331namespace xla {
3432namespace gpu {
@@ -102,7 +100,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
102100 DeviceBufferPair& buffer, se::Stream& stream,
103101 ncclComm_t comm, absl::string_view device_string,
104102 int64_t current_id) {
105- #if XLA_ENABLE_XCCL
106103 // Determine the source IDs for this instance. The source ID is the ID for
107104 // the peer that will copy its data to this instance. If there is no source,
108105 // just memzero() the destination buffer.
@@ -116,23 +113,12 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
116113 VLOG (3 ) << absl::StreamFormat (" %s : id = %d, source_id = %d" , device_string,
117114 current_id, source_id.value_or (-1 ));
118115
119- TF_ASSIGN_OR_RETURN (auto dtype_and_multiplier,
120- ToNcclDataTypeAndCountMultiplier (
121- buffer.element_type , Thunk::kNcclCollectivePermute ));
122- ncclDataType_t dtype = dtype_and_multiplier.first ;
123- int64_t element_count = buffer.element_count * dtype_and_multiplier.second ;
124-
125- se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue (&stream);
126-
127116 // Receive data from the source peer to the destination buffer.
128117 if (source_id) {
129- VLOG (3 ) << absl::StreamFormat (
130- " %s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
131- " stream=%p)" ,
132- device_string, dest_addr.opaque (), element_count, *source_id,
133- static_cast <const void *>(comm), gpu_stream);
134- XLA_NCCL_RETURN_IF_ERROR (ncclRecv (dest_addr.opaque (), element_count, dtype,
135- *source_id, comm, gpu_stream));
118+ TF_RETURN_IF_ERROR (NcclApi::Recv (
119+ dest_addr, buffer.element_type , buffer.element_count , *source_id,
120+ reinterpret_cast <NcclApi::NcclCommHandle>(comm), &stream));
121+
136122 } else {
137123 // If there is no source peer, i.e. no sender to this instance, zero out
138124 // the destination buffer.
@@ -141,11 +127,6 @@ absl::Status RunRecv(NcclP2PConfig::SourceTargetMapEntry source_target,
141127 stream.ThenMemZero (&dest_addr, dest_addr.size ());
142128 }
143129 return absl::OkStatus ();
144- #else // XLA_ENABLE_XCCL
145- return Unimplemented (
146- " NCCL support is not available: this binary was not built with a CUDA "
147- " compiler, which is necessary to build the NCCL source library." );
148- #endif // XLA_ENABLE_XCCL
149130}
150131
151132} // namespace gpu
0 commit comments