Skip to content

Commit 23413d7

Browse files
committed
Add NVSHMEM library and initialization test
Add NVSHMEM host API support Load nvshmem via stub Add initialization test fixes fixes Formatting Some fixes Finalize() and use weak_ptr for kv_store Revert "Finalize() and use weak_ptr for kv_store" This reverts commit c8620281c822c67ef6ada84776691555ce6ca683. Call finalize Use store get/set lambdas again. Move test Move test in BUILD too
1 parent 27c5a17 commit 23413d7

File tree

15 files changed

+2024
-0
lines changed

15 files changed

+2024
-0
lines changed

third_party/tsl/third_party/nvshmem/BUILD

Whitespace-only changes.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# NVSHMEM
2+
3+
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
4+
load("@bazel_skylib//rules:write_file.bzl", "write_file")
5+
6+
options_substitions = {
7+
"#cmakedefine NVSHMEM_COMPLEX_SUPPORT": "/* #undef NVSHMEM_COMPLEX_SUPPORT */",
8+
"#cmakedefine NVSHMEM_DEBUG": "/* #undef NVSHMEM_DEBUG */",
9+
"#cmakedefine NVSHMEM_DEVEL": "/* #undef NVSHMEM_DEVEL */",
10+
"#cmakedefine NVSHMEM_TRACE": "/* #undef NVSHMEM_TRACE */",
11+
"#cmakedefine NVSHMEM_DEFAULT_PMI2": "/* #undef NVSHMEM_DEFAULT_PMI2 */",
12+
"#cmakedefine NVSHMEM_DEFAULT_PMIX": "/* #undef NVSHMEM_DEFAULT_PMIX */",
13+
"#cmakedefine NVSHMEM_DEFAULT_UCX": "/* #undef NVSHMEM_DEFAULT_UCX */",
14+
"#cmakedefine NVSHMEM_DISABLE_COLL_POLL": "#define NVSHMEM_DISABLE_COLL_POLL",
15+
"#cmakedefine NVSHMEM_GPU_COLL_USE_LDST": "/* #undef NVSHMEM_GPU_COLL_USE_LDST */",
16+
"#cmakedefine NVSHMEM_IBDEVX_SUPPORT": "/* #undef NVSHMEM_IBDEVX_SUPPORT */",
17+
"#cmakedefine NVSHMEM_IBRC_SUPPORT": "#define NVSHMEM_IBRC_SUPPORT",
18+
"#cmakedefine NVSHMEM_LIBFABRIC_SUPPORT": "/* #undef NVSHMEM_LIBFABRIC_SUPPORT */",
19+
"#cmakedefine NVSHMEM_MPI_SUPPORT": "/* #undef NVSHMEM_MPI_SUPPORT */",
20+
"#cmakedefine NVSHMEM_NVTX": "#define NVSHMEM_NVTX",
21+
"#cmakedefine NVSHMEM_PMIX_SUPPORT": "/* #undef NVSHMEM_PMIX_SUPPORT */",
22+
"#cmakedefine NVSHMEM_SHMEM_SUPPORT": "/* #undef NVSHMEM_SHMEM_SUPPORT */",
23+
"#cmakedefine NVSHMEM_TEST_STATIC_LIB": "/* #undef NVSHMEM_TEST_STATIC_LIB */",
24+
"#cmakedefine NVSHMEM_TIMEOUT_DEVICE_POLLING": "/* #undef NVSHMEM_TIMEOUT_DEVICE_POLLING */",
25+
"#cmakedefine NVSHMEM_UCX_SUPPORT": "/* #undef NVSHMEM_UCX_SUPPORT */",
26+
"#cmakedefine NVSHMEM_USE_DLMALLOC": "/* #undef NVSHMEM_USE_DLMALLOC */",
27+
"#cmakedefine NVSHMEM_USE_NCCL": "/* #undef NVSHMEM_USE_NCCL */",
28+
"#cmakedefine NVSHMEM_USE_GDRCOPY": "/* #undef NVSHMEM_USE_GDRCOPY */",
29+
"#cmakedefine NVSHMEM_VERBOSE": "/* #undef NVSHMEM_VERBOSE */",
30+
"#cmakedefine NVSHMEM_BUILD_TESTS": "#define NVSHMEM_BUILD_TESTS",
31+
"#cmakedefine NVSHMEM_BUILD_EXAMPLES": "#define NVSHMEM_BUILD_EXAMPLES",
32+
"#cmakedefine NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY": "/* #undef NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY */",
33+
"#cmakedefine NVSHMEM_IBGDA_SUPPORT": "/* #undef NVSHMEM_IBGDA_SUPPORT */",
34+
"#cmakedefine NVSHMEM_ENABLE_ALL_DEVICE_INLINING": "/* #undef NVSHMEM_ENABLE_ALL_DEVICE_INLINING */",
35+
}
36+
37+
expand_template(
38+
name = "nvshmem_build_options_h",
39+
out = "src/include/non_abi/nvshmem_build_options.h",
40+
substitutions = options_substitions,
41+
template = "src/include/non_abi/nvshmem_build_options.h.in",
42+
)
43+
44+
NVSHMEM_MAJOR = 3
45+
46+
version_substitions = {
47+
"@PROJECT_VERSION_MAJOR@": str(NVSHMEM_MAJOR),
48+
"@PROJECT_VERSION_MINOR@": "0",
49+
"@PROJECT_VERSION_PATCH@": "6",
50+
"@PROJECT_VERSION_TWEAK@": "4",
51+
"@TRANSPORT_VERSION_MAJOR@": "3",
52+
"@TRANSPORT_VERSION_MINOR@": "0",
53+
"@TRANSPORT_VERSION_PATCH@": "0",
54+
"@BOOTSTRAP_VERSION_MAJOR@": "3",
55+
"@BOOTSTRAP_VERSION_MINOR@": "0",
56+
"@BOOTSTRAP_VERSION_PATCH@": "0",
57+
"@INTERLIB_VERSION_MAJOR@": "3",
58+
"@INTERLIB_VERSION_MINOR@": "0",
59+
"@INTERLIB_VERSION_PATCH@": "0",
60+
"@INFO_BUILD_VARS@": "",
61+
}
62+
63+
expand_template(
64+
name = "nvshmem_version_h",
65+
out = "src/include/non_abi/nvshmem_version.h",
66+
substitutions = version_substitions,
67+
template = "src/include/non_abi/nvshmem_version.h.in",
68+
)
69+
70+
cc_library(
71+
name = "nvshmem",
72+
hdrs = glob([
73+
"src/include/**",
74+
]) + [
75+
":nvshmem_build_options_h",
76+
":nvshmem_version_h",
77+
],
78+
includes = ["src/include"],
79+
include_prefix = "third_party/nvshmem",
80+
strip_include_prefix = "src/include",
81+
visibility = ["//visibility:public"],
82+
deps = [
83+
"@xla//xla/tsl/cuda:nvshmem_stub",
84+
],
85+
)
86+
87+
# This additional header allows us to determine the configured NVSHMEM version
88+
# without including the rest of NVSHMEM.
89+
write_file(
90+
name = "nvshmem_config_header",
91+
out = "nvshmem_config.h",
92+
content = [
93+
"#define TF_NVSHMEM_VERSION \"{}\"".format(NVSHMEM_MAJOR),
94+
],
95+
)
96+
97+
cc_library(
98+
name = "nvshmem_config",
99+
hdrs = ["nvshmem_config.h"],
100+
include_prefix = "third_party/nvshmem",
101+
visibility = ["//visibility:public"],
102+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""NVSHMEM - NVIDIA Shared Memory"""
2+
3+
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
4+
5+
def repo():
6+
tf_http_archive(
7+
name = "nvshmem",
8+
strip_prefix = "nvshmem_src_3.0.6-4",
9+
sha256 = "4f435fdee320a365dd19d24b9f74df69b69886d3902ec99b16b553d485b18871",
10+
urls = tf_mirror_urls("https://developer.download.nvidia.com/compute/redist/nvshmem/3.0.6/source/nvshmem_src_3.0.6-4.txz"),
11+
build_file = "//third_party/nvshmem:nvshmem.BUILD",
12+
)

third_party/tsl/workspace2.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ load("//third_party/hwloc:workspace.bzl", hwloc = "repo")
2323
load("//third_party/implib_so:workspace.bzl", implib_so = "repo")
2424
load("//third_party/llvm:setup.bzl", "llvm_setup")
2525
load("//third_party/nasm:workspace.bzl", nasm = "repo")
26+
load("//third_party/nvshmem:workspace.bzl", nvshmem = "repo")
2627
load("//third_party/py:python_configure.bzl", "python_configure")
2728
load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
2829
load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
@@ -50,6 +51,7 @@ def _initialize_third_party():
5051
implib_so()
5152
ml_dtypes()
5253
nasm()
54+
nvshmem()
5355
pybind11_abseil()
5456
pybind11_bazel()
5557
tensorrt()

xla/service/gpu/runtime/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,28 @@ xla_test(
188188
],
189189
)
190190

191+
#===-------------------------------------------------------------------------------------------===//
192+
# NVSHMEM Integration
193+
#===-------------------------------------------------------------------------------------------===//
194+
195+
cc_library(
196+
name = "nvshmem_api",
197+
srcs = ["nvshmem_api.cc"],
198+
hdrs = ["nvshmem_api.h"],
199+
deps = [
200+
"@com_google_absl//absl/status",
201+
"@com_google_absl//absl/status:statusor",
202+
"@com_google_absl//absl/strings:str_format",
203+
"@tsl//tsl/platform:errors",
204+
"@tsl//tsl/platform:numbers",
205+
"@tsl//tsl/platform:logging",
206+
"@tsl//tsl/platform:statusor",
207+
]+ if_cuda_is_configured([
208+
"@local_config_cuda//cuda:cuda_headers",
209+
"@nvshmem//:nvshmem",
210+
]),
211+
)
212+
191213
#===-------------------------------------------------------------------------------------------===//
192214
# XLA Thunks Runtime
193215
#===-------------------------------------------------------------------------------------------===//
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/service/gpu/runtime/nvshmem_api.h"
17+
18+
#include "absl/strings/str_format.h"
19+
#include "tsl/platform/logging.h"
20+
#include "tsl/platform/errors.h"
21+
#include "tsl/platform/numbers.h"
22+
#include "tsl/platform/statusor.h"
23+
#include "third_party/nvshmem/nvshmem.h"
24+
#include "third_party/nvshmem/nvshmemx.h"
25+
26+
namespace xla::gpu {
27+
28+
//==-----------------------------------------------------------------------===//
29+
// Macros to return or warn on NVSHMEM errors.
30+
//==-----------------------------------------------------------------------===//
31+
32+
static absl::Status NvshmemToStatus(int s, const char* file, int64_t line,
33+
const char* expr) {
34+
if (s == 0) return absl::OkStatus();
35+
36+
return absl::InternalError(
37+
absl::StrFormat("%s:%d: NVSHMEM operation %s failed."
38+
" For extra logging, rerun with 'NVSHMEM_DEBUG=INFO'.",
39+
file, line, expr));
40+
}
41+
42+
#define XLA_NVSHMEM_STATUS(expr) \
43+
xla::gpu::NvshmemToStatus(expr, __FILE__, __LINE__, #expr)
44+
45+
#define XLA_NVSHMEM_RETURN_IF_ERROR(expr) \
46+
do { \
47+
absl::Status s = XLA_NVSHMEM_STATUS(expr); \
48+
if (!s.ok()) { \
49+
return s; \
50+
} \
51+
} while (0)
52+
53+
#define XLA_NVSHMEM_LOG_IF_ERROR(expr) \
54+
do { \
55+
absl::Status s = XLA_NVSHMEM_STATUS(expr); \
56+
if (!s.ok()) { \
57+
LOG(ERROR) << s.ToString(); \
58+
} \
59+
} while (0)
60+
61+
#define XLA_NVSHMEM_CHECK(expr) CHECK(XLA_NVSHMEM_STATUS(expr).ok())
62+
63+
int NvshmemApi::process_id_ = -1;
64+
size_t NvshmemApi::num_processes_ = 0;
65+
size_t NvshmemApi::device_count_per_process_ = 0;
66+
std::function<absl::StatusOr<std::string>(std::string_view)>
67+
NvshmemApi::kv_store_get_ = nullptr;
68+
std::function<absl::Status(std::string_view, std::string_view)>
69+
NvshmemApi::kv_store_set_ = nullptr;
70+
71+
NvshmemApi& NvshmemApi::Default() {
72+
static NvshmemApi instance;
73+
return instance;
74+
}
75+
76+
void NvshmemApi::SetEnvInfo(
77+
int process_id, size_t num_processes, size_t device_count_per_process,
78+
std::function<absl::StatusOr<std::string>(std::string_view)> kv_store_get,
79+
std::function<absl::Status(std::string_view, std::string_view)>
80+
kv_store_set) {
81+
process_id_ = process_id;
82+
num_processes_ = num_processes;
83+
device_count_per_process_ = device_count_per_process;
84+
kv_store_get_ = kv_store_get;
85+
kv_store_set_ = kv_store_set;
86+
}
87+
88+
NvshmemApi::NvshmemApi() {
89+
// Initialize NVSHMEM here since code path
90+
// is already protected by singleton pattern
91+
if (process_id_ == -1) {
92+
LOG(FATAL)
93+
<< "NvshmemApi::SetEnvInfo was not called before using NVSHMEM API";
94+
}
95+
if (device_count_per_process_ != 1) {
96+
LOG(FATAL) << "NVSHMEM API is only supported with one device per process";
97+
}
98+
CHECK(Initialize().ok());
99+
}
100+
101+
NvshmemApi::~NvshmemApi() {
102+
VLOG(3) << absl::StreamFormat(
103+
"Finilizing NVSHMEM on process %d; num_processes=%llu", process_id_,
104+
num_processes_);
105+
nvshmemx_hostlib_finalize();
106+
}
107+
108+
absl::Status NvshmemApi::Initialize() {
109+
nvshmemx_init_attr_t nvshmem_init_attr;
110+
nvshmemx_uniqueid_t nvshmem_id;
111+
112+
// Initialize NVSHMEM
113+
if (process_id_ == 0) {
114+
XLA_NVSHMEM_RETURN_IF_ERROR(nvshmemx_get_uniqueid(&nvshmem_id));
115+
std::string_view nvshmem_id_str(reinterpret_cast<char*>(&nvshmem_id),
116+
sizeof(nvshmemx_uniqueid_t));
117+
TF_RETURN_IF_ERROR(kv_store_set_(kv_store_key_, nvshmem_id_str));
118+
} else {
119+
TF_ASSIGN_OR_RETURN(std::string id_str, kv_store_get_(kv_store_key_));
120+
std::copy(id_str.data(), id_str.data() + sizeof(nvshmemx_uniqueid_t),
121+
reinterpret_cast<char*>(&nvshmem_id));
122+
}
123+
124+
XLA_NVSHMEM_RETURN_IF_ERROR(nvshmemx_set_attr_uniqueid_args(
125+
process_id_, num_processes_, &nvshmem_id, &nvshmem_init_attr));
126+
XLA_NVSHMEM_RETURN_IF_ERROR(nvshmemx_hostlib_init_attr(
127+
NVSHMEMX_INIT_WITH_UNIQUEID, &nvshmem_init_attr));
128+
129+
VLOG(3) << absl::StreamFormat(
130+
"Initialized NVSHMEM on process %d; num_processes=%llu", process_id_,
131+
num_processes_);
132+
return absl::OkStatus();
133+
}
134+
135+
absl::StatusOr<void*> NvshmemApi::Allocate(uint64_t bytes) {
136+
VLOG(3) << absl::StreamFormat(
137+
"Start allocation of %s (%llu bytes) for NVSHMEM",
138+
tsl::strings::HumanReadableNumBytes(bytes), bytes);
139+
void* buffer = nvshmem_malloc(bytes);
140+
if (buffer == nullptr) {
141+
return absl::InternalError(absl::StrFormat(
142+
"Failed to allocate %s (%llu bytes) from NVSHMEM memory",
143+
tsl::strings::HumanReadableNumBytes(bytes), bytes));
144+
}
145+
return buffer;
146+
}
147+
148+
absl::Status NvshmemApi::Deallocate(void* buffer) {
149+
VLOG(3) << absl::StreamFormat("Start de-allocation for NVSHMEM buffer: %p",
150+
buffer);
151+
nvshmem_free(buffer);
152+
return absl::OkStatus();
153+
}
154+
155+
} // namespace xla::gpu
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/* Copyright 2024 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_SERVICE_GPU_RUNTIME_NVSHMEM_API_H_
17+
#define XLA_SERVICE_GPU_RUNTIME_NVSHMEM_API_H_
18+
19+
#include <functional>
20+
#include <string_view>
21+
22+
#include <cuda.h>
23+
24+
#include "absl/status/status.h"
25+
#include "absl/status/statusor.h"
26+
27+
namespace xla::gpu {
28+
29+
//===----------------------------------------------------------------------===//
30+
// NvshmemApi
31+
//===----------------------------------------------------------------------===//
32+
33+
class NvshmemApi {
34+
public:
35+
// Returns a default NvshmemApi for a current process.
36+
// NvshmemApi follows the Singleton design pattern
37+
static NvshmemApi& Default();
38+
39+
static void SetEnvInfo(
40+
int process_id, size_t num_processes, size_t device_count_per_process,
41+
std::function<absl::StatusOr<std::string>(std::string_view)> kv_store_get,
42+
std::function<absl::Status(std::string_view, std::string_view)>
43+
kv_store_set);
44+
NvshmemApi(NvshmemApi const&) = delete;
45+
void operator=(NvshmemApi const&) = delete;
46+
47+
absl::StatusOr<void*> Allocate(uint64_t bytes);
48+
absl::Status Deallocate(void* buffer);
49+
50+
private:
51+
NvshmemApi();
52+
~NvshmemApi();
53+
54+
absl::Status Initialize();
55+
56+
// Env variable
57+
static int process_id_;
58+
static size_t num_processes_;
59+
static size_t device_count_per_process_;
60+
static std::function<absl::StatusOr<std::string>(std::string_view)>
61+
kv_store_get_;
62+
static std::function<absl::Status(std::string_view, std::string_view)>
63+
kv_store_set_;
64+
static constexpr char kv_store_key_[] = "nvshmem_global_init";
65+
};
66+
67+
} // namespace xla::gpu
68+
69+
#endif // XLA_SERVICE_GPU_RUNTIME_NVSHMEM_API_H_

0 commit comments

Comments
 (0)