Skip to content

Commit f2d2f61

Browse files
committed
Move to NvshmemCollectives
1 parent 58f842d commit f2d2f61

File tree

10 files changed

+193
-128
lines changed

10 files changed

+193
-128
lines changed

xla/backends/gpu/collectives/BUILD

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,30 @@ cc_library(
260260
"@local_config_rocm//rocm:rccl",
261261
]),
262262
)
263+
264+
cc_library(
265+
name = "nvshmem_collectives",
266+
srcs = if_gpu_is_configured(["nvshmem_collectives.cc"]),
267+
hdrs = if_gpu_is_configured(["nvshmem_collectives.h"]),
268+
local_defines = if_cuda_is_configured([
269+
"GOOGLE_CUDA=1",
270+
]) + if_rocm_is_configured([
271+
"TENSORFLOW_USE_ROCM=1",
272+
]),
273+
deps = [
274+
"//xla/core/collectives",
275+
"//xla/core/collectives:collectives_registry",
276+
"@com_google_absl//absl/status",
277+
"@com_google_absl//absl/status:statusor",
278+
"@com_google_absl//absl/strings:str_format",
279+
"@com_googlesource_code_re2//:re2",
280+
"@tsl//tsl/platform:errors",
281+
"@tsl//tsl/platform:numbers",
282+
"@tsl//tsl/platform:logging",
283+
"@tsl//tsl/platform:statusor",
284+
]+ if_cuda_is_configured([
285+
"@local_config_cuda//cuda:cuda_headers",
286+
"@nvshmem//:nvshmem",
287+
]),
288+
alwayslink = True, # registers collectives implementation
289+
)

xla/service/gpu/runtime/nvshmem_api.cc renamed to xla/backends/gpu/collectives/nvshmem_collectives.cc

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright 2024 The OpenXLA Authors.
1+
/* Copyright 2025 The OpenXLA Authors.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "xla/service/gpu/runtime/nvshmem_api.h"
16+
#include "xla/backends/gpu/collectives/nvshmem_collectives.h"
1717

1818
#include "absl/strings/str_format.h"
1919
#include "tsl/platform/logging.h"
@@ -22,6 +22,9 @@ limitations under the License.
2222
#include "tsl/platform/statusor.h"
2323
#include "third_party/nvshmem/nvshmem.h"
2424
#include "third_party/nvshmem/nvshmemx.h"
25+
#include "xla/core/collectives/collectives_registry.h"
26+
27+
#include <cuda.h>
2528

2629
namespace xla::gpu {
2730

@@ -60,20 +63,24 @@ static absl::Status NvshmemToStatus(int s, const char* file, int64_t line,
6063

6164
#define XLA_NVSHMEM_CHECK(expr) CHECK(XLA_NVSHMEM_STATUS(expr).ok())
6265

63-
int NvshmemApi::process_id_ = -1;
64-
size_t NvshmemApi::num_processes_ = 0;
65-
size_t NvshmemApi::device_count_per_process_ = 0;
66-
std::function<absl::StatusOr<std::string>(std::string_view)>
67-
NvshmemApi::kv_store_get_ = nullptr;
68-
std::function<absl::Status(std::string_view, std::string_view)>
69-
NvshmemApi::kv_store_set_ = nullptr;
70-
71-
NvshmemApi& NvshmemApi::Default() {
72-
static NvshmemApi instance;
73-
return instance;
66+
NvshmemCollectives::~NvshmemCollectives() {
67+
if (initialized_) Finalize();
68+
}
69+
70+
NvshmemCollectives* NvshmemCollectives::Default() {
71+
absl::StatusOr<Collectives*> collectives =
72+
CollectivesRegistry::Get("gpu", "nvshmem");
73+
CHECK_OK(collectives) << "Failed to get NVSHMEM collectives"; // Crash OK
74+
75+
if (auto* nvshmem_collectives =
76+
tsl::down_cast<NvshmemCollectives*>(*collectives)) {
77+
return nvshmem_collectives;
78+
}
79+
80+
LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM";
7481
}
7582

76-
void NvshmemApi::SetEnvInfo(
83+
void NvshmemCollectives::SetEnvInfo(
7784
int process_id, size_t num_processes, size_t device_count_per_process,
7885
std::function<absl::StatusOr<std::string>(std::string_view)> kv_store_get,
7986
std::function<absl::Status(std::string_view, std::string_view)>
@@ -85,27 +92,14 @@ void NvshmemApi::SetEnvInfo(
8592
kv_store_set_ = kv_store_set;
8693
}
8794

88-
NvshmemApi::NvshmemApi() {
89-
// Initialize NVSHMEM here since code path
90-
// is already protected by singleton pattern
95+
absl::Status NvshmemCollectives::Initialize() {
9196
if (process_id_ == -1) {
92-
LOG(FATAL)
93-
<< "NvshmemApi::SetEnvInfo was not called before using NVSHMEM API";
97+
LOG(FATAL) << "NvshmemCollectives::SetEnvInfo was not called before using "
98+
"NVSHMEM API";
9499
}
95100
if (device_count_per_process_ != 1) {
96101
LOG(FATAL) << "NVSHMEM API is only supported with one device per process";
97102
}
98-
CHECK(Initialize().ok());
99-
}
100-
101-
NvshmemApi::~NvshmemApi() {
102-
VLOG(3) << absl::StreamFormat(
103-
"Finilizing NVSHMEM on process %d; num_processes=%llu", process_id_,
104-
num_processes_);
105-
nvshmemx_hostlib_finalize();
106-
}
107-
108-
absl::Status NvshmemApi::Initialize() {
109103
nvshmemx_init_attr_t nvshmem_init_attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
110104
nvshmemx_uniqueid_t nvshmem_id = NVSHMEMX_UNIQUEID_INITIALIZER;
111105

@@ -132,7 +126,25 @@ absl::Status NvshmemApi::Initialize() {
132126
return absl::OkStatus();
133127
}
134128

135-
absl::StatusOr<void*> NvshmemApi::Allocate(uint64_t bytes) {
129+
absl::Status NvshmemCollectives::InitializeOnce() {
130+
static absl::once_flag once_flag;
131+
absl::Status status = absl::OkStatus();
132+
absl::call_once(once_flag, [&]() {
133+
status = Initialize();
134+
initialized_ = true;
135+
});
136+
return status;
137+
}
138+
139+
void NvshmemCollectives::Finalize() {
140+
VLOG(3) << absl::StreamFormat(
141+
"Finilizing NVSHMEM on process %d; num_processes=%llu", process_id_,
142+
num_processes_);
143+
nvshmemx_hostlib_finalize();
144+
}
145+
146+
absl::StatusOr<void*> NvshmemCollectives::Allocate(uint64_t bytes) {
147+
TF_RETURN_IF_ERROR(InitializeOnce());
136148
VLOG(3) << absl::StreamFormat(
137149
"Start allocation of %s (%llu bytes) for NVSHMEM",
138150
tsl::strings::HumanReadableNumBytes(bytes), bytes);
@@ -145,11 +157,15 @@ absl::StatusOr<void*> NvshmemApi::Allocate(uint64_t bytes) {
145157
return buffer;
146158
}
147159

148-
absl::Status NvshmemApi::Deallocate(void* buffer) {
160+
absl::Status NvshmemCollectives::Deallocate(void* buffer) {
161+
TF_RETURN_IF_ERROR(InitializeOnce());
149162
VLOG(3) << absl::StreamFormat("Start de-allocation for NVSHMEM buffer: %p",
150163
buffer);
151164
nvshmem_free(buffer);
152165
return absl::OkStatus();
153166
}
154167

155168
} // namespace xla::gpu
169+
170+
XLA_COLLECTIVES_REGISTER("gpu", "nvshmem", 2,
171+
std::make_unique<xla::gpu::NvshmemCollectives>());
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NVSHMEM_COLLECTIVES_H_
17+
#define XLA_BACKENDS_GPU_COLLECTIVES_NVSHMEM_COLLECTIVES_H_
18+
19+
#include <functional>
20+
#include <string_view>
21+
22+
#include "absl/status/status.h"
23+
#include "absl/status/statusor.h"
24+
#include "xla/core/collectives/collectives.h"
25+
26+
namespace xla::gpu {
27+
28+
// NVIDIA NVSHMEM library
29+
class NvshmemCollectives : public Collectives {
30+
public:
31+
~NvshmemCollectives() override;
32+
33+
static NvshmemCollectives* Default();
34+
35+
void SetEnvInfo(
36+
int process_id, size_t num_processes, size_t device_count_per_process,
37+
std::function<absl::StatusOr<std::string>(std::string_view)> kv_store_get,
38+
std::function<absl::Status(std::string_view, std::string_view)>
39+
kv_store_set);
40+
41+
absl::StatusOr<void*> Allocate(uint64_t bytes);
42+
43+
absl::Status Deallocate(void* buffer);
44+
45+
absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
46+
return absl::UnimplementedError("Not implemented.");
47+
}
48+
49+
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
50+
CreateCommunicators(int32_t, const CliqueKey&, const std::optional<CliqueId>&,
51+
absl::Span<const DeviceRank>,
52+
const Collectives::Config&) final {
53+
return absl::UnimplementedError("Not implemented.");
54+
}
55+
56+
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
57+
absl::Span<const Communicator* const>, int32_t, absl::Span<const RankId>,
58+
const Collectives::Config&) final {
59+
return absl::UnimplementedError("Not implemented.");
60+
}
61+
62+
private:
63+
absl::Status Initialize();
64+
absl::Status InitializeOnce();
65+
66+
void Finalize();
67+
68+
int process_id_ = -1;
69+
size_t num_processes_ = 0;
70+
size_t device_count_per_process_ = 0;
71+
std::function<absl::StatusOr<std::string>(std::string_view)> kv_store_get_ =
72+
nullptr;
73+
std::function<absl::Status(std::string_view, std::string_view)>
74+
kv_store_set_ = nullptr;
75+
bool initialized_ = false;
76+
77+
static constexpr char kv_store_key_[] = "nvshmem_global_init";
78+
};
79+
80+
} // namespace xla::gpu
81+
82+
#endif // XLA_BACKENDS_GPU_COLLECTIVES_NVSHMEM_COLLECTIVES_H_

xla/core/collectives/collectives_registry.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,24 @@ absl::StatusOr<Collectives*> CollectivesRegistry::Default(
9999
return registry.platform_collectives[canonical_platform_name].begin()->second;
100100
}
101101

102+
absl::StatusOr<Collectives*> CollectivesRegistry::Get(
103+
absl::string_view platform_name, absl::string_view implementation_name) {
104+
TF_ASSIGN_OR_RETURN(std::string canonical_platform_name,
105+
PlatformUtil::CanonicalPlatformName(platform_name));
106+
107+
auto& registry = GetCollectivesRegistry();
108+
absl::MutexLock lock(&registry.mu);
109+
110+
for (const auto& registration : registry.collectives) {
111+
if (registration.platform_name == canonical_platform_name &&
112+
registration.name == implementation_name)
113+
return registration.collectives.get();
114+
}
115+
116+
return Internal(
117+
"No collectives registered for platform: %s (canonical name: %s) and "
118+
"implementation: %s",
119+
platform_name, canonical_platform_name, implementation_name);
120+
}
121+
102122
} // namespace xla

xla/core/collectives/collectives_registry.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ class CollectivesRegistry {
4545

4646
// Returns the default collectives implementation for the given platform.
4747
static absl::StatusOr<Collectives*> Default(absl::string_view platform_name);
48+
49+
// Return a specific collectives implementation by name for the given
50+
// platform.
51+
static absl::StatusOr<Collectives*> Get(
52+
absl::string_view platform_name, absl::string_view implementation_name);
4853
};
4954

5055
} // namespace xla

xla/python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,7 @@ tsl_pybind_extension(
12531253
"-Wl,-rpath,$$ORIGIN/../nvidia/cudnn/lib",
12541254
"-Wl,-rpath,$$ORIGIN/../nvidia/cusolver/lib",
12551255
"-Wl,-rpath,$$ORIGIN/../nvidia/nccl/lib",
1256+
"-Wl,-rpath,$$ORIGIN/../nvidia/nvshmem/lib",
12561257
],
12571258
"//conditions:default": [],
12581259
}),

xla/service/gpu/runtime/BUILD

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -187,28 +187,6 @@ xla_test(
187187
],
188188
)
189189

190-
#===-------------------------------------------------------------------------------------------===//
191-
# NVSHMEM Integration
192-
#===-------------------------------------------------------------------------------------------===//
193-
194-
cc_library(
195-
name = "nvshmem_api",
196-
srcs = ["nvshmem_api.cc"],
197-
hdrs = ["nvshmem_api.h"],
198-
deps = [
199-
"@com_google_absl//absl/status",
200-
"@com_google_absl//absl/status:statusor",
201-
"@com_google_absl//absl/strings:str_format",
202-
"@tsl//tsl/platform:errors",
203-
"@tsl//tsl/platform:numbers",
204-
"@tsl//tsl/platform:logging",
205-
"@tsl//tsl/platform:statusor",
206-
]+ if_cuda_is_configured([
207-
"@local_config_cuda//cuda:cuda_headers",
208-
"@nvshmem//:nvshmem",
209-
]),
210-
)
211-
212190
#===-------------------------------------------------------------------------------------------===//
213191
# XLA Thunks Runtime
214192
#===-------------------------------------------------------------------------------------------===//

xla/service/gpu/runtime/nvshmem_api.h

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)