
Commit e514c1f

trevor-m authored and tensorflower-gardener committed
PR tensorflow#21683: [XLA:GPU] NVSHMEM allocation
Imported from GitHub PR openxla/xla#21683

Requires openxla/xla#20395, which adds the NVSHMEM library dependency.

This PR adds the following:
1. An nvshmem flag to enable NVSHMEM.
2. NVSHMEM initialization set up when the GPU PJRT client is created; the first time NVSHMEM is used, it is initialized.
3. NVSHMEM backing for the user buffer memory pool. If NVSHMEM is enabled, the pool is allocated with `nvshmem_malloc`. The same memory can be used for user buffers if NCCL user buffers is also enabled.
4. An updated `CollectiveColorer` so that mosaic_gpu custom calls use the NVSHMEM memory space.

Copybara import of the project:

-- aee33791e16ab2149118de728dbb9e62f5e7cc31 by Trevor Morris <tmorris@nvidia.com>:
Add nvshmem flag, memory allocation, and memory space assignment
Set Nvshmem env info during client creation
Rename flag and use absl::string_view

-- f8fca39300b3915eb6320142f58fa9c0ec7a1eaa by Trevor Morris <tmorris@nvidia.com>:
Use explicit types in test

-- e41faa3f72b778fcf8ea8111d3cde59548b8f9f5 by Trevor Morris <tmorris@nvidia.com>:
Add user buffer allgather and allreduce tests with and without nvshmem alloc
Set nvshmem in XLA_FLAGS
Test fixes
Formatting

-- cf0c36865de8b8a010caaf62c3a36b64e36037bd by Trevor Morris <tmorris@nvidia.com>:
Fixes

-- 3b4d11123cdb794d0a60e65b94d22ded04b7b2b4 by Trevor Morris <tmorris@nvidia.com>:
Remove early dso check

-- 359f2b243ec97b1f8003c27f0b07dde82407ff6c by Trevor Morris <tmorris@nvidia.com>:
Add flag comment

-- fd15a7cac745adc1971bec63e148047b9b811729 by Trevor Morris <tmorris@nvidia.com>:
Also assign memory space for mosaic_gpu_v2

Merging this change closes tensorflow#21683

PiperOrigin-RevId: 747816712
1 parent 0cfdb00 commit e514c1f
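
For reference, a minimal sketch (not part of the commit) of how the new option is meant to combine with NCCL user buffers, per item 3 of the commit message. The flag and setter names are taken from the diffs below; the helper function name is hypothetical. The same effect is available at run time by passing --xla_gpu_experimental_enable_nvshmem=true in XLA_FLAGS, which is what the new test target in xla/pjrt/gpu/BUILD does.

#include "xla/debug_options_flags.h"
#include "xla/xla.pb.h"

// Hypothetical helper: build DebugOptions with NVSHMEM-backed collective
// memory. With both options on, the user-buffer pool is allocated with
// nvshmem_malloc and the same memory can serve as NCCL user buffers.
xla::DebugOptions MakeNvshmemDebugOptions() {
  xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
  opts.set_xla_gpu_experimental_enable_nvshmem(true);
  // Optional: reuse the pool for NCCL user buffers. Per the flag help text,
  // collective_memory_size in the allocator config must then be non-zero.
  opts.set_xla_gpu_enable_nccl_user_buffers(true);
  return opts;
}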

11 files changed: 425 additions, 26 deletions


third_party/xla/xla/backends/gpu/collectives/BUILD

Lines changed: 16 additions & 10 deletions
@@ -18,13 +18,22 @@ package_group(
     ],
 )
 
+config_setting(
+    name = "arm_build",
+    values = {"cpu": "arm"},
+)
+
 # Build target that registers all available GPU collectives implementations with the collectives
 # registry at link time.
 cc_library(
     name = "gpu_collectives_plugin",
     deps = [
         ":gpu_collectives_stub",
-    ] + if_nccl([":nccl_collectives"]),
+    ] + if_nccl([":nccl_collectives"]) + select({
+        # TODO(b/409709288): Fix nvshmem ARM issues and remove this condition.
+        ":arm_build": [],
+        "//conditions:default": [":nvshmem_collectives"],
+    }),
 )
 
 cc_library(
@@ -222,6 +231,7 @@ cc_library(
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:casts",
+        "@local_tsl//tsl/platform:numbers",
     ] + if_cuda_is_configured([
         "@local_config_nccl//:nccl",
     ]) + if_rocm_is_configured([
@@ -271,14 +281,11 @@ cc_library(
 
 cc_library(
     name = "nvshmem_collectives",
-    srcs = ["nvshmem_collectives.cc"],
-    hdrs = ["nvshmem_collectives.h"],
-    tags = [
-        "cuda-only",
-        "gpu",
-    ],
+    srcs = if_cuda_is_configured(["nvshmem_collectives.cc"]),
+    hdrs = if_cuda_is_configured(["nvshmem_collectives.h"]),
     visibility = ["//visibility:private"],
     deps = [
+        ":gpu_collectives",
        "//xla/core/collectives",
        "//xla/core/collectives:clique_id",
        "//xla/core/collectives:clique_key",
@@ -299,9 +306,8 @@ cc_library(
         "@com_google_absl//absl/types:span",
         "@local_tsl//tsl/platform:casts",
         "@local_tsl//tsl/platform:numbers",
-        "@nvshmem//:nvshmem_lib",
-    ],
-    alwayslink = True,  # registers collectives implementation
+    ] + if_cuda_is_configured(["@nvshmem//:nvshmem_lib"]),
+    alwayslink = True,
 )
 
 xla_cc_test(

third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc

Lines changed: 29 additions & 0 deletions
@@ -44,6 +44,7 @@ limitations under the License.
 #include "xla/core/collectives/collectives_registry.h"
 #include "xla/core/collectives/communicator.h"
 #include "xla/core/collectives/rank_id.h"
+#include "xla/debug_options_flags.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"
 #include "xla/service/global_device_id.h"
 #include "xla/service/gpu/gpu_executable_run_options.h"
@@ -53,6 +54,7 @@ limitations under the License.
 #include "xla/tsl/platform/statusor.h"
 #include "xla/util.h"
 #include "tsl/platform/casts.h"
+#include "tsl/platform/numbers.h"
 
 #if TENSORFLOW_USE_ROCM
 #include "rocm/rocm_config.h"
@@ -235,7 +237,24 @@ absl::Status NcclCollectives::GroupEnd() {
   return XLA_NCCL_STATUS(ncclGroupEnd());
 }
 
+static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
+  TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
+                      xla::CollectivesRegistry::Get("gpu", "nvshmem"));
+  xla::gpu::GpuCollectives* nvshmem_collectives =
+      tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
+  if (nvshmem_collectives == nullptr) {
+    return absl::InternalError("Failed to get NVSHMEM collectives");
+  }
+
+  return nvshmem_collectives;
+}
+
 absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
+  if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
+    TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
+    return nvshmem_collectives->Allocate(bytes);
+  }
+
   void* ptr = nullptr;
   ncclResult_t res = ncclMemAlloc(&ptr, bytes);
   if (res != ncclSuccess) {
@@ -251,6 +270,11 @@ absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
 }
 
 absl::Status NcclCollectives::Deallocate(void* location) {
+  if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
+    TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
+    return nvshmem_collectives->Deallocate(location);
+  }
+
   ncclResult_t res = ncclMemFree(location);
   if (res != ncclSuccess) {
     return absl::InternalError(absl::StrFormat(
@@ -318,6 +342,11 @@ class NcclIdStore {
 
 absl::Status NcclCollectives::InitializeTopology(
     NcclCollectives::Topology topology) {
+  if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
+    TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
+    TF_RETURN_IF_ERROR(nvshmem_collectives->InitializeTopology(topology));
+  }
+
   if (topology.num_nodes > 1) {
     auto nccl_id_store = std::make_shared<NcclIdStore>(
         topology.node_id, topology.device_id_to_node_id,
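
To make the new dispatch easier to follow, here is a self-contained sketch of the same registry lookup this file's diff performs. The wrapper function name is hypothetical; the lookup, cast, and Allocate call mirror GetNvshmemCollectives() and NcclCollectives::Allocate() above.

#include <cstdint>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/core/collectives/collectives.h"
#include "xla/core/collectives/collectives_registry.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/casts.h"

// Hypothetical helper: fetch the registered NVSHMEM implementation and
// allocate `bytes` from its symmetric heap (nvshmem_malloc under the hood).
absl::StatusOr<void*> AllocateNvshmemBuffer(uint64_t bytes) {
  TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
                      xla::CollectivesRegistry::Get("gpu", "nvshmem"));
  auto* nvshmem = tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
  if (nvshmem == nullptr) {
    return absl::InternalError("NVSHMEM collectives not registered");
  }
  return nvshmem->Allocate(bytes);
}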

third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,12 @@ NvshmemCollectives* NvshmemCollectives::Default() {
   LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM";
 }
 
+absl::Status NvshmemCollectives::InitializeTopology(Topology topology) {
+  SetEnvInfo(topology.node_id, topology.num_nodes,
+             topology.device_count_per_process, topology.kv_store);
+  return absl::OkStatus();
+}
+
 void NvshmemCollectives::SetEnvInfo(
     int process_id, size_t num_processes, size_t device_count_per_process,
     std::weak_ptr<KeyValueStoreInterface> kv_store) {

third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h

Lines changed: 24 additions & 5 deletions
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/types/span.h"
+#include "xla/backends/gpu/collectives/gpu_collectives.h"
 #include "xla/core/collectives/clique_id.h"
 #include "xla/core/collectives/clique_key.h"
 #include "xla/core/collectives/collectives.h"
@@ -35,7 +36,7 @@ limitations under the License.
 namespace xla::gpu {
 
 // NVIDIA NVSHMEM library
-class NvshmemCollectives : public Collectives {
+class NvshmemCollectives : public GpuCollectives {
  public:
   ~NvshmemCollectives() override;
 
@@ -45,28 +46,46 @@ class NvshmemCollectives : public Collectives {
                          size_t device_count_per_process,
                          std::weak_ptr<KeyValueStoreInterface> kv_store);
 
-  absl::StatusOr<void*> Allocate(uint64_t bytes);
+  absl::StatusOr<void*> Allocate(uint64_t bytes) final;
 
-  absl::Status Deallocate(void* buffer);
+  absl::Status Deallocate(void* buffer) final;
 
   absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
     return absl::UnimplementedError("Not implemented.");
   }
 
+  absl::Status GroupStart() final {
+    return absl::UnimplementedError("Not implemented.");
+  }
+  absl::Status GroupEnd() final {
+    return absl::UnimplementedError("Not implemented.");
+  }
+
+  bool IsImplemented() const final { return true; }
+
+  bool IsGlobalConfig() const final { return false; }
+
+  absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
+      const CliqueIdCallback* clique_id_callback, bool is_local) final {
+    return absl::UnimplementedError("Not implemented.");
+  }
+
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
   CreateCommunicators(const CliqueKey& clique_key,
                       const std::optional<CliqueIds>& clique_ids,
                       absl::Span<const DeviceRank> ranks,
-                      const Config& config) final {
+                      const Collectives::Config& config) {
     return absl::UnimplementedError("Not implemented.");
   }
 
   absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
       absl::Span<const Communicator* const> comms, int32_t color,
-      absl::Span<const RankId> keys, const Config& config) final {
+      absl::Span<const RankId> keys, const Collectives::Config& config) final {
     return absl::UnimplementedError("Not implemented.");
  }
 
+  absl::Status InitializeTopology(Topology topology) final;
+
  private:
   absl::Status InitializeOnce();
 

third_party/xla/xla/debug_options_flags.cc

Lines changed: 6 additions & 0 deletions
@@ -167,6 +167,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_gpu_nccl_termination_timeout_seconds(-1);
   opts.set_xla_gpu_enable_shared_constants(true);
   opts.set_xla_gpu_enable_nccl_user_buffers(false);
+  opts.set_xla_gpu_experimental_enable_nvshmem(false);
   opts.set_xla_gpu_enable_nccl_comm_splitting(true);
   opts.set_xla_gpu_nccl_init_max_rank_per_root_ratio(0);
 
@@ -1581,6 +1582,11 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "Enables NCCL User Buffer Registration. collective_memory_size in the "
       "allocator config must also be set to a non-zero value that is large "
       "enough to meet peak collective memory usage."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_experimental_enable_nvshmem",
+      bool_setter_for(&DebugOptions::set_xla_gpu_experimental_enable_nvshmem),
+      debug_options->xla_gpu_experimental_enable_nvshmem(),
+      "Enables NVSHMEM."));
   flag_list->push_back(tsl::Flag(
       "xla_gpu_temp_buffer_use_separate_color",
       bool_setter_for(

third_party/xla/xla/pjrt/gpu/BUILD

Lines changed: 50 additions & 0 deletions
@@ -237,6 +237,56 @@ xla_test(
     ],
 )
 
+# TODO(b/409713313): Move this test to collectives directory.
+xla_test(
+    name = "se_gpu_pjrt_client_nvshmem_test",
+    srcs = ["se_gpu_pjrt_client_nvshmem_test.cc"],
+    backend_tags = {"gpu": [
+        "multi_gpu_h100",
+        "no_oss",
+        "noasan",
+        "notap",  # TODO(b/399931591): Re-enable once flakiness is resolved.
+        "nomsan",
+    ]},
+    backends = ["gpu"],
+    env = {
+        "XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true",
+    },
+    deps = [
+        ":gpu_topology_proto_cc",
+        ":se_gpu_pjrt_client",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_data_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/backends/gpu/collectives:gpu_collectives",
+        "//xla/ffi",
+        "//xla/ffi:ffi_api",
+        "//xla/hlo/builder:xla_computation",
+        "//xla/hlo/parser:hlo_parser",
+        "//xla/hlo/testlib:test",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/pjrt:pjrt_client",
+        "//xla/pjrt:pjrt_compiler",
+        "//xla/pjrt:pjrt_executable",
+        "//xla/pjrt:raw_buffer",
+        "//xla/pjrt/distributed",
+        "//xla/pjrt/distributed:client",
+        "//xla/pjrt/distributed:in_memory_key_value_store",
+        "//xla/pjrt/distributed:service",
+        "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
+        "//xla/service:platform_util",
+        "//xla/tests:literal_test_util",
+        "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/time",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 xla_test(
     name = "pjrt_client_test_se_gpu",
     srcs = ["pjrt_client_test_se_gpu.cc"],
