Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
102 changes: 102 additions & 0 deletions third_party/tsl/third_party/nvshmem/nvshmem.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# NVSHMEM

load("@bazel_skylib//rules:expand_template.bzl", "expand_template")
load("@bazel_skylib//rules:write_file.bzl", "write_file")

# Substitutions resolving the CMake `#cmakedefine` placeholders in
# nvshmem_build_options.h.in. Transports/features we do not build (MPI, UCX,
# PMIx, IBGDA, GDRCopy, NCCL, ...) expand to `/* #undef ... */`; enabled
# options expand to plain `#define` lines.
# (Renamed from the misspelled `options_substitions`.)
options_substitutions = {
    "#cmakedefine NVSHMEM_COMPLEX_SUPPORT": "/* #undef NVSHMEM_COMPLEX_SUPPORT */",
    "#cmakedefine NVSHMEM_DEBUG": "/* #undef NVSHMEM_DEBUG */",
    "#cmakedefine NVSHMEM_DEVEL": "/* #undef NVSHMEM_DEVEL */",
    "#cmakedefine NVSHMEM_TRACE": "/* #undef NVSHMEM_TRACE */",
    "#cmakedefine NVSHMEM_DEFAULT_PMI2": "/* #undef NVSHMEM_DEFAULT_PMI2 */",
    "#cmakedefine NVSHMEM_DEFAULT_PMIX": "/* #undef NVSHMEM_DEFAULT_PMIX */",
    "#cmakedefine NVSHMEM_DEFAULT_UCX": "/* #undef NVSHMEM_DEFAULT_UCX */",
    "#cmakedefine NVSHMEM_DISABLE_COLL_POLL": "#define NVSHMEM_DISABLE_COLL_POLL",
    "#cmakedefine NVSHMEM_GPU_COLL_USE_LDST": "/* #undef NVSHMEM_GPU_COLL_USE_LDST */",
    "#cmakedefine NVSHMEM_IBDEVX_SUPPORT": "/* #undef NVSHMEM_IBDEVX_SUPPORT */",
    "#cmakedefine NVSHMEM_IBRC_SUPPORT": "#define NVSHMEM_IBRC_SUPPORT",
    "#cmakedefine NVSHMEM_LIBFABRIC_SUPPORT": "/* #undef NVSHMEM_LIBFABRIC_SUPPORT */",
    "#cmakedefine NVSHMEM_MPI_SUPPORT": "/* #undef NVSHMEM_MPI_SUPPORT */",
    "#cmakedefine NVSHMEM_NVTX": "#define NVSHMEM_NVTX",
    "#cmakedefine NVSHMEM_PMIX_SUPPORT": "/* #undef NVSHMEM_PMIX_SUPPORT */",
    "#cmakedefine NVSHMEM_SHMEM_SUPPORT": "/* #undef NVSHMEM_SHMEM_SUPPORT */",
    "#cmakedefine NVSHMEM_TEST_STATIC_LIB": "/* #undef NVSHMEM_TEST_STATIC_LIB */",
    "#cmakedefine NVSHMEM_TIMEOUT_DEVICE_POLLING": "/* #undef NVSHMEM_TIMEOUT_DEVICE_POLLING */",
    "#cmakedefine NVSHMEM_UCX_SUPPORT": "/* #undef NVSHMEM_UCX_SUPPORT */",
    "#cmakedefine NVSHMEM_USE_DLMALLOC": "/* #undef NVSHMEM_USE_DLMALLOC */",
    "#cmakedefine NVSHMEM_USE_NCCL": "/* #undef NVSHMEM_USE_NCCL */",
    "#cmakedefine NVSHMEM_USE_GDRCOPY": "/* #undef NVSHMEM_USE_GDRCOPY */",
    "#cmakedefine NVSHMEM_VERBOSE": "/* #undef NVSHMEM_VERBOSE */",
    "#cmakedefine NVSHMEM_BUILD_TESTS": "#define NVSHMEM_BUILD_TESTS",
    "#cmakedefine NVSHMEM_BUILD_EXAMPLES": "#define NVSHMEM_BUILD_EXAMPLES",
    "#cmakedefine NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY": "/* #undef NVSHMEM_IBGDA_SUPPORT_GPUMEM_ONLY */",
    "#cmakedefine NVSHMEM_IBGDA_SUPPORT": "/* #undef NVSHMEM_IBGDA_SUPPORT */",
    "#cmakedefine NVSHMEM_ENABLE_ALL_DEVICE_INLINING": "/* #undef NVSHMEM_ENABLE_ALL_DEVICE_INLINING */",
}

# Generate nvshmem_build_options.h from the CMake template above.
expand_template(
    name = "nvshmem_build_options_h",
    out = "src/include/non_abi/nvshmem_build_options.h",
    substitutions = options_substitutions,
    template = "src/include/non_abi/nvshmem_build_options.h.in",
)

# NVSHMEM release pinned by the workspace rule (3.1.7); keep these constants
# in sync with third_party/nvshmem/workspace.bzl. Previously minor/patch were
# hard-coded string literals while major was a named constant.
NVSHMEM_MAJOR = 3

NVSHMEM_MINOR = 1

NVSHMEM_PATCH = 7

# Substitutions for the @...@ placeholders in nvshmem_version.h.in.
# (Renamed from the misspelled `version_substitions`.)
version_substitutions = {
    "@PROJECT_VERSION_MAJOR@": str(NVSHMEM_MAJOR),
    "@PROJECT_VERSION_MINOR@": str(NVSHMEM_MINOR),
    "@PROJECT_VERSION_PATCH@": str(NVSHMEM_PATCH),
    "@PROJECT_VERSION_TWEAK@": "0",
    "@TRANSPORT_VERSION_MAJOR@": "3",
    "@TRANSPORT_VERSION_MINOR@": "0",
    "@TRANSPORT_VERSION_PATCH@": "0",
    "@BOOTSTRAP_VERSION_MAJOR@": "3",
    "@BOOTSTRAP_VERSION_MINOR@": "0",
    "@BOOTSTRAP_VERSION_PATCH@": "0",
    "@INTERLIB_VERSION_MAJOR@": "3",
    "@INTERLIB_VERSION_MINOR@": "0",
    "@INTERLIB_VERSION_PATCH@": "0",
    # No extra build variables are recorded in this build.
    "@INFO_BUILD_VARS@": "",
}

# Generate nvshmem_version.h from the CMake template above.
expand_template(
    name = "nvshmem_version_h",
    out = "src/include/non_abi/nvshmem_version.h",
    substitutions = version_substitutions,
    template = "src/include/non_abi/nvshmem_version.h.in",
)

# Header-only target for the NVSHMEM API. Generated configuration headers are
# added on top of the vendored sources; headers are remapped so consumers
# include them as "third_party/nvshmem/...". Runtime symbols come from the
# stub dependency below.
cc_library(
    name = "nvshmem",
    hdrs = glob([
        "src/include/**",
    ]) + [
        # Generated by the expand_template rules above.
        ":nvshmem_build_options_h",
        ":nvshmem_version_h",
    ],
    includes = ["src/include"],
    include_prefix = "third_party/nvshmem",
    strip_include_prefix = "src/include",
    visibility = ["//visibility:public"],
    deps = [
        "@xla//xla/tsl/cuda:nvshmem_stub",
    ],
)

# This additional header allows us to determine the configured NVSHMEM version
# without including the rest of NVSHMEM.
write_file(
    name = "nvshmem_config_header",
    out = "nvshmem_config.h",
    content = [
        # Only the major version is recorded in this macro.
        "#define TF_NVSHMEM_VERSION \"{}\"".format(NVSHMEM_MAJOR),
    ],
)

# Minimal library exposing only the generated nvshmem_config.h, so dependents
# can query the configured NVSHMEM version without pulling in all headers.
cc_library(
    name = "nvshmem_config",
    hdrs = ["nvshmem_config.h"],
    include_prefix = "third_party/nvshmem",
    visibility = ["//visibility:public"],
)
13 changes: 13 additions & 0 deletions third_party/tsl/third_party/nvshmem/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""NVSHMEM - NVIDIA Shared Memory"""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Imports the NVSHMEM 3.1.7 source archive as the `@nvshmem` repository.

    The archive is built with the BUILD file vendored at
    //third_party/nvshmem:nvshmem.BUILD.
    """
    tf_http_archive(
        name = "nvshmem",
        strip_prefix = "nvshmem_src",
        sha256 = "2146ff231d9aadd2b11f324c142582f89e3804775877735dc507b4dfd70c788b",
        urls = tf_mirror_urls("https://developer.download.nvidia.com/compute/redist/nvshmem/3.1.7/source/nvshmem_src_3.1.7-1.txz"),
        build_file = "//third_party/nvshmem:nvshmem.BUILD",
        # The .txz extension is not auto-detected, so the type is explicit.
        type = "tar",
    )
2 changes: 2 additions & 0 deletions tsl_workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ load("@tsl//third_party/hwloc:workspace.bzl", hwloc = "repo")
load("@tsl//third_party/implib_so:workspace.bzl", implib_so = "repo")
load("@tsl//third_party/llvm:setup.bzl", "llvm_setup")
load("@tsl//third_party/nasm:workspace.bzl", nasm = "repo")
load("@tsl//third_party/nvshmem:workspace.bzl", nvshmem = "repo")
load("@tsl//third_party/py:python_configure.bzl", "python_configure")
load("@tsl//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo")
load("@tsl//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo")
Expand Down Expand Up @@ -52,6 +53,7 @@ def _initialize_third_party():
implib_so()
ml_dtypes()
nasm()
nvshmem()
pybind11_abseil()
pybind11_bazel()
tensorrt()
Expand Down
29 changes: 29 additions & 0 deletions xla/backends/gpu/collectives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,32 @@ xla_cc_test(
"@local_config_rocm//rocm:rccl",
]),
)

# NVSHMEM-backed collectives implementation. Registered via
# XLA_COLLECTIVES_REGISTER in nvshmem_collectives.cc, hence alwayslink.
# Fixes: buildifier formatting (`] + ` spacing), alphabetized deps, and
# missing absl deps for call_once / string_view / absl::Minutes which the
# source file already uses.
cc_library(
    name = "nvshmem_collectives",
    srcs = if_gpu_is_configured(["nvshmem_collectives.cc"]),
    hdrs = if_gpu_is_configured(["nvshmem_collectives.h"]),
    local_defines = if_cuda_is_configured([
        "GOOGLE_CUDA=1",
    ]) + if_rocm_is_configured([
        "TENSORFLOW_USE_ROCM=1",
    ]),
    tags = ["requires-gpu-nvidia"],
    deps = [
        "//xla/core/collectives",
        "//xla/core/collectives:collectives_registry",
        "//xla/pjrt/distributed:key_value_store_interface",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_googlesource_code_re2//:re2",
        "@tsl//tsl/platform:errors",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:numbers",
        "@tsl//tsl/platform:statusor",
    ] + if_cuda_is_configured([
        "@local_config_cuda//cuda:cuda_headers",
        "@nvshmem//:nvshmem",
    ]),
    alwayslink = True,  # registers collectives implementation
)
148 changes: 148 additions & 0 deletions xla/backends/gpu/collectives/nvshmem_collectives.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/collectives/nvshmem_collectives.h"

#include <cuda.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "third_party/nvshmem/nvshmem.h"
#include "third_party/nvshmem/nvshmemx.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/numbers.h"
#include "tsl/platform/statusor.h"
#include "xla/core/collectives/collectives_registry.h"

namespace xla::gpu {

// Tears down the NVSHMEM host library, but only if initialization ever
// completed; destroying a never-initialized instance is a no-op.
NvshmemCollectives::~NvshmemCollectives() {
  if (initialized_) {
    Finalize();
  }
}

// Returns the process-wide NvshmemCollectives instance registered under
// ("gpu", "nvshmem") in the collectives registry. Crashes if the registry
// lookup fails or if a different implementation was registered under that
// name.
NvshmemCollectives* NvshmemCollectives::Default() {
  absl::StatusOr<Collectives*> collectives =
      CollectivesRegistry::Get("gpu", "nvshmem");
  CHECK_OK(collectives) << "Failed to get NVSHMEM collectives";  // Crash OK

  auto* nvshmem_collectives = tsl::down_cast<NvshmemCollectives*>(*collectives);
  if (nvshmem_collectives == nullptr) {
    LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM";
  }
  return nvshmem_collectives;
}

// Records the distributed-environment parameters required by
// InitializeOnce(). Must be called before any NVSHMEM API use (Allocate /
// Deallocate); process_id stays -1 otherwise and initialization aborts.
void NvshmemCollectives::SetEnvInfo(
    int process_id, size_t num_processes, size_t device_count_per_process,
    std::weak_ptr<KeyValueStoreInterface> kv_store) {
  process_id_ = process_id;
  num_processes_ = num_processes;
  device_count_per_process_ = device_count_per_process;
  // The parameter is a by-value sink; move it to avoid an extra atomic
  // ref-count operation on the control block.
  kv_store_ = std::move(kv_store);
}

// Initializes the NVSHMEM host library exactly once per process. Rank 0
// creates the NVSHMEM unique id and publishes it through the KV store; all
// other ranks block (up to 10 minutes) reading it back. Fix vs. original:
// `initialized_` was set unconditionally, so a failed initialization made the
// destructor call nvshmemx_hostlib_finalize() on a library that never came
// up; it is now set only on success. Also validates the id blob size before
// copying to avoid reading past the end of the KV-store value.
absl::Status NvshmemCollectives::InitializeOnce() {
  auto init_fn = [this]() -> absl::Status {
    if (process_id_ == -1) {
      LOG(FATAL)
          << "NvshmemCollectives::SetEnvInfo was not called before using "
             "NVSHMEM API";
    }
    if (device_count_per_process_ != 1) {
      LOG(FATAL) << "NVSHMEM API is only supported with one device per process";
    }
    nvshmemx_init_attr_t nvshmem_init_attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
    nvshmemx_uniqueid_t nvshmem_id = NVSHMEMX_UNIQUEID_INITIALIZER;

    // Exchange the NVSHMEM unique id through the distributed KV store.
    if (std::shared_ptr<KeyValueStoreInterface> kv_store = kv_store_.lock()) {
      if (process_id_ == 0) {
        if (nvshmemx_get_uniqueid(&nvshmem_id) != 0) {
          return absl::InternalError("nvshmemx_get_uniqueid failed.");
        }
        absl::string_view nvshmem_id_str(reinterpret_cast<char*>(&nvshmem_id),
                                         sizeof(nvshmemx_uniqueid_t));
        TF_RETURN_IF_ERROR(kv_store->Set(kv_store_key_, nvshmem_id_str));
      } else {
        TF_ASSIGN_OR_RETURN(std::string id_str,
                            kv_store->Get(kv_store_key_, absl::Minutes(10)));
        // Guard against a short/corrupt value before the raw copy below.
        if (id_str.size() < sizeof(nvshmemx_uniqueid_t)) {
          return absl::InternalError(
              "Retrieved NVSHMEM unique id has unexpected size.");
        }
        std::copy(id_str.data(), id_str.data() + sizeof(nvshmemx_uniqueid_t),
                  reinterpret_cast<char*>(&nvshmem_id));
      }
    } else {
      return absl::InternalError(
          "KV store is not available for nvshmem initialization.");
    }

    if (nvshmemx_set_attr_uniqueid_args(process_id_, num_processes_,
                                        &nvshmem_id, &nvshmem_init_attr) != 0) {
      return absl::InternalError("nvshmemx_set_attr_uniqueid_args failed.");
    }
    if (nvshmemx_hostlib_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID,
                                   &nvshmem_init_attr) != 0) {
      return absl::InternalError("nvshmemx_hostlib_init_attr failed.");
    }

    VLOG(3) << absl::StreamFormat(
        "Initialized NVSHMEM on process %d; num_processes=%llu", process_id_,
        num_processes_);
    return absl::OkStatus();
  };

  static absl::once_flag once_flag;
  absl::Status status = absl::OkStatus();
  absl::call_once(once_flag, [&]() {
    status = init_fn();
    // Only record success: the destructor finalizes iff initialized_ is set.
    initialized_ = status.ok();
  });
  // NOTE(review): if the first attempt fails, later calls return OkStatus()
  // without retrying because the once_flag is already consumed — confirm this
  // is the intended contract.
  return status;
}

// Shuts down the NVSHMEM host library. Called from the destructor when
// initialization previously succeeded. (Fixes "Finilizing" typo in the log.)
void NvshmemCollectives::Finalize() {
  VLOG(3) << absl::StreamFormat(
      "Finalizing NVSHMEM on process %d; num_processes=%llu", process_id_,
      num_processes_);
  nvshmemx_hostlib_finalize();
}

// Allocates `bytes` bytes from the NVSHMEM symmetric heap, lazily
// initializing the library on first use. Returns InternalError when the
// symmetric heap cannot satisfy the request.
absl::StatusOr<void*> NvshmemCollectives::Allocate(uint64_t bytes) {
  TF_RETURN_IF_ERROR(InitializeOnce());
  VLOG(3) << absl::StreamFormat(
      "Start allocation of %s (%llu bytes) for NVSHMEM",
      tsl::strings::HumanReadableNumBytes(bytes), bytes);
  void* symmetric_mem = nvshmem_malloc(bytes);
  if (symmetric_mem != nullptr) {
    return symmetric_mem;
  }
  return absl::InternalError(absl::StrFormat(
      "Failed to allocate %s (%llu bytes) from NVSHMEM memory",
      tsl::strings::HumanReadableNumBytes(bytes), bytes));
}

// Returns `buffer` to the NVSHMEM symmetric heap. A null buffer is a no-op;
// the early return also avoids forcing NVSHMEM initialization when there is
// nothing to free.
absl::Status NvshmemCollectives::Deallocate(void* buffer) {
  if (buffer == nullptr) {
    return absl::OkStatus();
  }
  TF_RETURN_IF_ERROR(InitializeOnce());
  VLOG(3) << absl::StreamFormat("Start de-allocation for NVSHMEM buffer: %p",
                                buffer);
  nvshmem_free(buffer);
  return absl::OkStatus();
}

} // namespace xla::gpu

// NvshmemCollectives currently does not implement GpuCollectives, so it cannot
// be used as a host-side collectives library. Therefore, set priority to -100
// so the registry prefers any full GPU collectives implementation over it.
XLA_COLLECTIVES_REGISTER("gpu", "nvshmem", -100,
                         std::make_unique<xla::gpu::NvshmemCollectives>());
77 changes: 77 additions & 0 deletions xla/backends/gpu/collectives/nvshmem_collectives.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NVSHMEM_COLLECTIVES_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_NVSHMEM_COLLECTIVES_H_

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/core/collectives/collectives.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"

namespace xla::gpu {

// Collectives implementation backed by the NVIDIA NVSHMEM library. Provides
// symmetric-heap allocation (Allocate/Deallocate); the clique/communicator
// APIs inherited from Collectives are intentionally unimplemented.
class NvshmemCollectives : public Collectives {
 public:
  // Finalizes the NVSHMEM host library if it was initialized.
  ~NvshmemCollectives() override;

  // Returns the instance registered under ("gpu", "nvshmem"); crashes if the
  // registry lookup fails.
  static NvshmemCollectives* Default();

  // Records distributed-environment parameters; must be called before any
  // method that triggers NVSHMEM initialization.
  void SetEnvInfo(int process_id, size_t num_processes,
                  size_t device_count_per_process,
                  std::weak_ptr<KeyValueStoreInterface> kv_store);

  // Allocates `bytes` bytes from the NVSHMEM symmetric heap (lazy init).
  absl::StatusOr<void*> Allocate(uint64_t bytes);

  // Frees a buffer previously returned by Allocate().
  absl::Status Deallocate(void* buffer);

  // Unimplemented: NVSHMEM does not use XLA clique ids.
  absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
    return absl::UnimplementedError("Not implemented.");
  }

  // Unimplemented: communicator creation is not supported by this backend.
  absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
  CreateCommunicators(const CliqueKey& clique_key,
                      const std::optional<CliqueIds>& clique_ids,
                      absl::Span<const DeviceRank> ranks,
                      const Config& config) final {
    return absl::UnimplementedError("Not implemented.");
  }

  // Unimplemented: communicator splitting is not supported by this backend.
  absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
      absl::Span<const Communicator* const> comms, int32_t color,
      absl::Span<const RankId> keys, const Config& config) final {
    return absl::UnimplementedError("Not implemented.");
  }

 private:
  // One-time process-wide NVSHMEM initialization; exchanges the unique id
  // through the KV store supplied via SetEnvInfo().
  absl::Status InitializeOnce();

  // Shuts down the NVSHMEM host library.
  void Finalize();

  // Rank of this process; -1 until SetEnvInfo() is called.
  int process_id_ = -1;
  size_t num_processes_ = 0;
  size_t device_count_per_process_ = 0;
  // Weak reference: the KV store's lifetime is owned by the caller.
  std::weak_ptr<KeyValueStoreInterface> kv_store_;
  // True after successful initialization; checked by the destructor.
  bool initialized_ = false;

  // Key under which rank 0 publishes the NVSHMEM unique id.
  static constexpr char kv_store_key_[] = "nvshmem_global_init";
};

} // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_NVSHMEM_COLLECTIVES_H_
Loading
Loading