Skip to content

Commit 1984135

Browse files
Roll back NVSHMEM allocator usage due to test failures
Reverts e514c1f. PiperOrigin-RevId: 747891808
1 parent 7245544 commit 1984135

File tree

11 files changed

+26
-425
lines changed

11 files changed

+26
-425
lines changed

third_party/xla/xla/backends/gpu/collectives/BUILD

Lines changed: 10 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,13 @@ package_group(
1818
],
1919
)
2020

21-
config_setting(
22-
name = "arm_build",
23-
values = {"cpu": "arm"},
24-
)
25-
2621
# Build target that registers all available GPU collectives implementations with the collectives
2722
# registry at link time.
2823
cc_library(
2924
name = "gpu_collectives_plugin",
3025
deps = [
3126
":gpu_collectives_stub",
32-
] + if_nccl([":nccl_collectives"]) + select({
33-
# TODO(b/409709288): Fix nvshmem ARM issues and remove this condition.
34-
":arm_build": [],
35-
"//conditions:default": [":nvshmem_collectives"],
36-
}),
27+
] + if_nccl([":nccl_collectives"]),
3728
)
3829

3930
cc_library(
@@ -231,7 +222,6 @@ cc_library(
231222
"@com_google_absl//absl/synchronization",
232223
"@com_google_absl//absl/types:span",
233224
"@local_tsl//tsl/platform:casts",
234-
"@local_tsl//tsl/platform:numbers",
235225
] + if_cuda_is_configured([
236226
"@local_config_nccl//:nccl",
237227
]) + if_rocm_is_configured([
@@ -281,11 +271,14 @@ cc_library(
281271

282272
cc_library(
283273
name = "nvshmem_collectives",
284-
srcs = if_cuda_is_configured(["nvshmem_collectives.cc"]),
285-
hdrs = if_cuda_is_configured(["nvshmem_collectives.h"]),
274+
srcs = ["nvshmem_collectives.cc"],
275+
hdrs = ["nvshmem_collectives.h"],
276+
tags = [
277+
"cuda-only",
278+
"gpu",
279+
],
286280
visibility = ["//visibility:private"],
287281
deps = [
288-
":gpu_collectives",
289282
"//xla/core/collectives",
290283
"//xla/core/collectives:clique_id",
291284
"//xla/core/collectives:clique_key",
@@ -306,8 +299,9 @@ cc_library(
306299
"@com_google_absl//absl/types:span",
307300
"@local_tsl//tsl/platform:casts",
308301
"@local_tsl//tsl/platform:numbers",
309-
] + if_cuda_is_configured(["@nvshmem//:nvshmem_lib"]),
310-
alwayslink = True,
302+
"@nvshmem//:nvshmem_lib",
303+
],
304+
alwayslink = True, # registers collectives implementation
311305
)
312306

313307
xla_cc_test(

third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc

Lines changed: 0 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,6 @@ limitations under the License.
4444
#include "xla/core/collectives/collectives_registry.h"
4545
#include "xla/core/collectives/communicator.h"
4646
#include "xla/core/collectives/rank_id.h"
47-
#include "xla/debug_options_flags.h"
4847
#include "xla/pjrt/distributed/key_value_store_interface.h"
4948
#include "xla/service/global_device_id.h"
5049
#include "xla/service/gpu/gpu_executable_run_options.h"
@@ -54,7 +53,6 @@ limitations under the License.
5453
#include "xla/tsl/platform/statusor.h"
5554
#include "xla/util.h"
5655
#include "tsl/platform/casts.h"
57-
#include "tsl/platform/numbers.h"
5856

5957
#if TENSORFLOW_USE_ROCM
6058
#include "rocm/rocm_config.h"
@@ -237,24 +235,7 @@ absl::Status NcclCollectives::GroupEnd() {
237235
return XLA_NCCL_STATUS(ncclGroupEnd());
238236
}
239237

240-
static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
241-
TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
242-
xla::CollectivesRegistry::Get("gpu", "nvshmem"));
243-
xla::gpu::GpuCollectives* nvshmem_collectives =
244-
tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
245-
if (nvshmem_collectives == nullptr) {
246-
return absl::InternalError("Failed to get NVSHMEM collectives");
247-
}
248-
249-
return nvshmem_collectives;
250-
}
251-
252238
absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
253-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
254-
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
255-
return nvshmem_collectives->Allocate(bytes);
256-
}
257-
258239
void* ptr = nullptr;
259240
ncclResult_t res = ncclMemAlloc(&ptr, bytes);
260241
if (res != ncclSuccess) {
@@ -270,11 +251,6 @@ absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
270251
}
271252

272253
absl::Status NcclCollectives::Deallocate(void* location) {
273-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
274-
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
275-
return nvshmem_collectives->Deallocate(location);
276-
}
277-
278254
ncclResult_t res = ncclMemFree(location);
279255
if (res != ncclSuccess) {
280256
return absl::InternalError(absl::StrFormat(
@@ -342,11 +318,6 @@ class NcclIdStore {
342318

343319
absl::Status NcclCollectives::InitializeTopology(
344320
NcclCollectives::Topology topology) {
345-
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
346-
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
347-
TF_RETURN_IF_ERROR(nvshmem_collectives->InitializeTopology(topology));
348-
}
349-
350321
if (topology.num_nodes > 1) {
351322
auto nccl_id_store = std::make_shared<NcclIdStore>(
352323
topology.node_id, topology.device_id_to_node_id,

third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -57,12 +57,6 @@ NvshmemCollectives* NvshmemCollectives::Default() {
5757
LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM";
5858
}
5959

60-
absl::Status NvshmemCollectives::InitializeTopology(Topology topology) {
61-
SetEnvInfo(topology.node_id, topology.num_nodes,
62-
topology.device_count_per_process, topology.kv_store);
63-
return absl::OkStatus();
64-
}
65-
6660
void NvshmemCollectives::SetEnvInfo(
6761
int process_id, size_t num_processes, size_t device_count_per_process,
6862
std::weak_ptr<KeyValueStoreInterface> kv_store) {

third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.h

Lines changed: 5 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ limitations under the License.
2525
#include "absl/status/status.h"
2626
#include "absl/status/statusor.h"
2727
#include "absl/types/span.h"
28-
#include "xla/backends/gpu/collectives/gpu_collectives.h"
2928
#include "xla/core/collectives/clique_id.h"
3029
#include "xla/core/collectives/clique_key.h"
3130
#include "xla/core/collectives/collectives.h"
@@ -36,7 +35,7 @@ limitations under the License.
3635
namespace xla::gpu {
3736

3837
// NVIDIA NVSHMEM library
39-
class NvshmemCollectives : public GpuCollectives {
38+
class NvshmemCollectives : public Collectives {
4039
public:
4140
~NvshmemCollectives() override;
4241

@@ -46,46 +45,28 @@ class NvshmemCollectives : public GpuCollectives {
4645
size_t device_count_per_process,
4746
std::weak_ptr<KeyValueStoreInterface> kv_store);
4847

49-
absl::StatusOr<void*> Allocate(uint64_t bytes) final;
48+
absl::StatusOr<void*> Allocate(uint64_t bytes);
5049

51-
absl::Status Deallocate(void* buffer) final;
50+
absl::Status Deallocate(void* buffer);
5251

5352
absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
5453
return absl::UnimplementedError("Not implemented.");
5554
}
5655

57-
absl::Status GroupStart() final {
58-
return absl::UnimplementedError("Not implemented.");
59-
}
60-
absl::Status GroupEnd() final {
61-
return absl::UnimplementedError("Not implemented.");
62-
}
63-
64-
bool IsImplemented() const final { return true; }
65-
66-
bool IsGlobalConfig() const final { return false; }
67-
68-
absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
69-
const CliqueIdCallback* clique_id_callback, bool is_local) final {
70-
return absl::UnimplementedError("Not implemented.");
71-
}
72-
7356
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
7457
CreateCommunicators(const CliqueKey& clique_key,
7558
const std::optional<CliqueIds>& clique_ids,
7659
absl::Span<const DeviceRank> ranks,
77-
const Collectives::Config& config) {
60+
const Config& config) final {
7861
return absl::UnimplementedError("Not implemented.");
7962
}
8063

8164
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
8265
absl::Span<const Communicator* const> comms, int32_t color,
83-
absl::Span<const RankId> keys, const Collectives::Config& config) final {
66+
absl::Span<const RankId> keys, const Config& config) final {
8467
return absl::UnimplementedError("Not implemented.");
8568
}
8669

87-
absl::Status InitializeTopology(Topology topology) final;
88-
8970
private:
9071
absl::Status InitializeOnce();
9172

third_party/xla/xla/debug_options_flags.cc

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
167167
opts.set_xla_gpu_nccl_termination_timeout_seconds(-1);
168168
opts.set_xla_gpu_enable_shared_constants(true);
169169
opts.set_xla_gpu_enable_nccl_user_buffers(false);
170-
opts.set_xla_gpu_experimental_enable_nvshmem(false);
171170
opts.set_xla_gpu_enable_nccl_comm_splitting(true);
172171
opts.set_xla_gpu_nccl_init_max_rank_per_root_ratio(0);
173172

@@ -1582,11 +1581,6 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
15821581
"Enables NCCL User Buffer Registration. collective_memory_size in the "
15831582
"allocator config must also be set to a non-zero value that is large "
15841583
"enough to meet peak collective memory usage."));
1585-
flag_list->push_back(tsl::Flag(
1586-
"xla_gpu_experimental_enable_nvshmem",
1587-
bool_setter_for(&DebugOptions::set_xla_gpu_experimental_enable_nvshmem),
1588-
debug_options->xla_gpu_experimental_enable_nvshmem(),
1589-
"Enables NVSHMEM."));
15901584
flag_list->push_back(tsl::Flag(
15911585
"xla_gpu_temp_buffer_use_separate_color",
15921586
bool_setter_for(

third_party/xla/xla/pjrt/gpu/BUILD

Lines changed: 0 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -237,56 +237,6 @@ xla_test(
237237
],
238238
)
239239

240-
# TODO(b/409713313): Move this test to collectives directory.
241-
xla_test(
242-
name = "se_gpu_pjrt_client_nvshmem_test",
243-
srcs = ["se_gpu_pjrt_client_nvshmem_test.cc"],
244-
backend_tags = {"gpu": [
245-
"multi_gpu_h100",
246-
"no_oss",
247-
"noasan",
248-
"notap", # TODO(b/399931591): Re-enable once flakiness is resolved.
249-
"nomsan",
250-
]},
251-
backends = ["gpu"],
252-
env = {
253-
"XLA_FLAGS": "--xla_gpu_experimental_enable_nvshmem=true",
254-
},
255-
deps = [
256-
":gpu_topology_proto_cc",
257-
":se_gpu_pjrt_client",
258-
"//xla:shape_util",
259-
"//xla:util",
260-
"//xla:xla_data_proto_cc",
261-
"//xla:xla_proto_cc",
262-
"//xla/backends/gpu/collectives:gpu_collectives",
263-
"//xla/ffi",
264-
"//xla/ffi:ffi_api",
265-
"//xla/hlo/builder:xla_computation",
266-
"//xla/hlo/parser:hlo_parser",
267-
"//xla/hlo/testlib:test",
268-
"//xla/hlo/utils:hlo_query",
269-
"//xla/pjrt:pjrt_client",
270-
"//xla/pjrt:pjrt_compiler",
271-
"//xla/pjrt:pjrt_executable",
272-
"//xla/pjrt:raw_buffer",
273-
"//xla/pjrt/distributed",
274-
"//xla/pjrt/distributed:client",
275-
"//xla/pjrt/distributed:in_memory_key_value_store",
276-
"//xla/pjrt/distributed:service",
277-
"//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options",
278-
"//xla/service:platform_util",
279-
"//xla/tests:literal_test_util",
280-
"//xla/tsl/lib/core:status_test_util",
281-
"//xla/tsl/platform:statusor",
282-
"@com_google_absl//absl/log:check",
283-
"@com_google_absl//absl/status:statusor",
284-
"@com_google_absl//absl/strings",
285-
"@com_google_absl//absl/time",
286-
"@com_google_absl//absl/types:span",
287-
],
288-
)
289-
290240
xla_test(
291241
name = "pjrt_client_test_se_gpu",
292242
srcs = ["pjrt_client_test_se_gpu.cc"],

0 commit comments

Comments
 (0)