1- /* Copyright 2024 The OpenXLA Authors.
1+ /* Copyright 2025 The OpenXLA Authors.
22
33Licensed under the Apache License, Version 2.0 (the "License");
44you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include " xla/service /gpu/runtime/nvshmem_api .h"
16+ #include " xla/backends /gpu/collectives/nvshmem_collectives .h"
1717
1818#include " absl/strings/str_format.h"
1919#include " tsl/platform/logging.h"
@@ -22,6 +22,9 @@ limitations under the License.
2222#include " tsl/platform/statusor.h"
2323#include " third_party/nvshmem/nvshmem.h"
2424#include " third_party/nvshmem/nvshmemx.h"
25+ #include " xla/core/collectives/collectives_registry.h"
26+
27+ #include < cuda.h>
2528
2629namespace xla ::gpu {
2730
@@ -60,20 +63,24 @@ static absl::Status NvshmemToStatus(int s, const char* file, int64_t line,
6063
6164#define XLA_NVSHMEM_CHECK (expr ) CHECK(XLA_NVSHMEM_STATUS(expr).ok())
6265
63- int NvshmemApi::process_id_ = -1 ;
64- size_t NvshmemApi::num_processes_ = 0 ;
65- size_t NvshmemApi::device_count_per_process_ = 0 ;
66- std::function<absl::StatusOr<std::string>(std::string_view)>
67- NvshmemApi::kv_store_get_ = nullptr ;
68- std::function<absl::Status(std::string_view, std::string_view)>
69- NvshmemApi::kv_store_set_ = nullptr ;
70-
71- NvshmemApi& NvshmemApi::Default () {
72- static NvshmemApi instance;
73- return instance;
66+ NvshmemCollectives::~NvshmemCollectives () {
67+ if (initialized_) Finalize ();
68+ }
69+
70+ NvshmemCollectives* NvshmemCollectives::Default () {
71+ absl::StatusOr<Collectives*> collectives =
72+ CollectivesRegistry::Get (" gpu" , " nvshmem" );
73+ CHECK_OK (collectives) << " Failed to get NVSHMEM collectives" ; // Crash OK
74+
75+ if (auto * nvshmem_collectives =
76+ tsl::down_cast<NvshmemCollectives*>(*collectives)) {
77+ return nvshmem_collectives;
78+ }
79+
80+ LOG (FATAL) << " Unsupported collectives implementation for NVSHMEM" ;
7481}
7582
76- void NvshmemApi ::SetEnvInfo (
83+ void NvshmemCollectives ::SetEnvInfo (
7784 int process_id, size_t num_processes, size_t device_count_per_process,
7885 std::function<absl::StatusOr<std::string>(std::string_view)> kv_store_get,
7986 std::function<absl::Status(std::string_view, std::string_view)>
@@ -85,27 +92,14 @@ void NvshmemApi::SetEnvInfo(
8592 kv_store_set_ = kv_store_set;
8693}
8794
88- NvshmemApi::NvshmemApi () {
89- // Initialize NVSHMEM here since code path
90- // is already protected by singleton pattern
95+ absl::Status NvshmemCollectives::Initialize () {
9196 if (process_id_ == -1 ) {
92- LOG (FATAL)
93- << " NvshmemApi::SetEnvInfo was not called before using NVSHMEM API" ;
97+ LOG (FATAL) << " NvshmemCollectives::SetEnvInfo was not called before using "
98+ " NVSHMEM API" ;
9499 }
95100 if (device_count_per_process_ != 1 ) {
96101 LOG (FATAL) << " NVSHMEM API is only supported with one device per process" ;
97102 }
98- CHECK (Initialize ().ok ());
99- }
100-
101- NvshmemApi::~NvshmemApi () {
102- VLOG (3 ) << absl::StreamFormat (
103- " Finilizing NVSHMEM on process %d; num_processes=%llu" , process_id_,
104- num_processes_);
105- nvshmemx_hostlib_finalize ();
106- }
107-
108- absl::Status NvshmemApi::Initialize () {
109103 nvshmemx_init_attr_t nvshmem_init_attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
110104 nvshmemx_uniqueid_t nvshmem_id = NVSHMEMX_UNIQUEID_INITIALIZER;
111105
@@ -132,7 +126,25 @@ absl::Status NvshmemApi::Initialize() {
132126 return absl::OkStatus ();
133127}
134128
135- absl::StatusOr<void *> NvshmemApi::Allocate (uint64_t bytes) {
129+ absl::Status NvshmemCollectives::InitializeOnce () {
130+ static absl::once_flag once_flag;
131+ absl::Status status = absl::OkStatus ();
132+ absl::call_once (once_flag, [&]() {
133+ status = Initialize ();
134+ initialized_ = true ;
135+ });
136+ return status;
137+ }
138+
139+ void NvshmemCollectives::Finalize () {
140+ VLOG (3 ) << absl::StreamFormat (
141+ " Finilizing NVSHMEM on process %d; num_processes=%llu" , process_id_,
142+ num_processes_);
143+ nvshmemx_hostlib_finalize ();
144+ }
145+
146+ absl::StatusOr<void *> NvshmemCollectives::Allocate (uint64_t bytes) {
147+ TF_RETURN_IF_ERROR (InitializeOnce ());
136148 VLOG (3 ) << absl::StreamFormat (
137149 " Start allocation of %s (%llu bytes) for NVSHMEM" ,
138150 tsl::strings::HumanReadableNumBytes (bytes), bytes);
@@ -145,11 +157,15 @@ absl::StatusOr<void*> NvshmemApi::Allocate(uint64_t bytes) {
145157 return buffer;
146158}
147159
148- absl::Status NvshmemApi::Deallocate (void * buffer) {
160+ absl::Status NvshmemCollectives::Deallocate (void * buffer) {
161+ TF_RETURN_IF_ERROR (InitializeOnce ());
149162 VLOG (3 ) << absl::StreamFormat (" Start de-allocation for NVSHMEM buffer: %p" ,
150163 buffer);
151164 nvshmem_free (buffer);
152165 return absl::OkStatus ();
153166}
154167
155168} // namespace xla::gpu
169+
// Registers an NvshmemCollectives instance under the same (" gpu", " nvshmem")
// key that NvshmemCollectives::Default() passes to CollectivesRegistry::Get().
// NOTE(review): the third argument (2) is presumably a registration priority —
// confirm against the XLA_COLLECTIVES_REGISTER macro definition.
170+ XLA_COLLECTIVES_REGISTER (" gpu" , " nvshmem" , 2 ,
171+ std::make_unique<xla::gpu::NvshmemCollectives>());
0 commit comments