Skip to content

Commit 9b3cdbe

Browse files
trevor-m authored and tensorflower-gardener committed
PR #21683: [XLA:GPU] NVSHMEM allocation
Imported from GitHub PR openxla/xla#21683. Requires openxla/xla#20395, which adds the NVSHMEM library dependency. This PR adds the following: 1. An NVSHMEM flag to enable NVSHMEM. 2. Sets up NVSHMEM initialization info when the GPU PJRT client is created; the first time NVSHMEM is used, it will be initialized. 3. Uses the user-buffer memory pool for NVSHMEM. If NVSHMEM is enabled, memory will be allocated using `nvshmem_malloc`. This same memory can be used by user buffers if NCCL user buffers are also enabled. 4. Updates the `CollectiveColorer` so that mosaic_gpu custom calls use the NVSHMEM memory space. Copybara import of the project: -- aee33791e16ab2149118de728dbb9e62f5e7cc31 by Trevor Morris <tmorris@nvidia.com>: Add nvshmem flag, memory allocation, and memory space assignment; set NVSHMEM env info during client creation; rename flag and use absl::string_view -- f8fca39300b3915eb6320142f58fa9c0ec7a1eaa by Trevor Morris <tmorris@nvidia.com>: Use explicit types in test -- e41faa3f72b778fcf8ea8111d3cde59548b8f9f5 by Trevor Morris <tmorris@nvidia.com>: Add user-buffer allgather and allreduce tests with and without NVSHMEM alloc; set nvshmem in XLA_FLAGS; test fixes; formatting -- cf0c36865de8b8a010caaf62c3a36b64e36037bd by Trevor Morris <tmorris@nvidia.com>: Fixes -- 3b4d11123cdb794d0a60e65b94d22ded04b7b2b4 by Trevor Morris <tmorris@nvidia.com>: Remove early dso check -- 359f2b243ec97b1f8003c27f0b07dde82407ff6c by Trevor Morris <tmorris@nvidia.com>: Add flag comment -- fd15a7cac745adc1971bec63e148047b9b811729 by Trevor Morris <tmorris@nvidia.com>: Also assign memory space for mosaic_gpu_v2. Merging this change closes #21683 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#21683 from trevor-m:nvshmem-upstream-2 fd15a7cac745adc1971bec63e148047b9b811729 PiperOrigin-RevId: 740701134
1 parent f99f00e commit 9b3cdbe

24 files changed

+624
-476
lines changed

third_party/xla/xla/backends/gpu/collectives/BUILD

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,23 @@ package_group(
1818
],
1919
)
2020

21+
config_setting(
22+
name = "arm_build",
23+
values = {"cpu": "arm"},
24+
)
25+
2126
# Build target that registers all available GPU collectives implementations with the collectives
2227
# registry at link time.
2328
cc_library(
2429
name = "gpu_collectives_plugin",
2530
deps = [
2631
":gpu_collectives_stub",
2732
":nccl_collectives",
28-
],
33+
] + select({
34+
# TODO(b/409709288): Fix nvshmem ARM issues and remove this condition.
35+
":arm_build": [],
36+
"//conditions:default": [":nvshmem_collectives"],
37+
}),
2938
)
3039

3140
cc_library(
@@ -132,6 +141,7 @@ cc_library(
132141
srcs = ["gpu_collectives.cc"],
133142
hdrs = ["gpu_collectives.h"],
134143
deps = [
144+
"//xla:executable_run_options",
135145
"//xla:shape_util",
136146
"//xla:util",
137147
"//xla:xla_data_proto_cc",
@@ -140,10 +150,12 @@ cc_library(
140150
"//xla/core/collectives:clique_key",
141151
"//xla/core/collectives:collectives_registry",
142152
"//xla/core/collectives:communicator",
153+
"//xla/pjrt/distributed:key_value_store_interface",
154+
"//xla/service:global_device_id",
143155
"//xla/stream_executor:device_memory",
144156
"//xla/stream_executor:stream",
145157
"//xla/stream_executor:stream_executor_h",
146-
"//xla/tsl/platform:logging",
158+
"@com_google_absl//absl/container:flat_hash_map",
147159
"@com_google_absl//absl/log",
148160
"@com_google_absl//absl/log:check",
149161
"@com_google_absl//absl/status",
@@ -194,9 +206,11 @@ cc_library(
194206
]),
195207
visibility = ["//visibility:private"],
196208
deps = [
209+
":gpu_clique_key",
197210
":gpu_collectives",
198211
":nccl_communicator",
199212
":nccl_errors",
213+
"//xla:debug_options_flags",
200214
"//xla:status_macros",
201215
"//xla:util",
202216
"//xla/core/collectives",
@@ -205,17 +219,24 @@ cc_library(
205219
"//xla/core/collectives:collectives_registry",
206220
"//xla/core/collectives:communicator",
207221
"//xla/core/collectives:rank_id",
222+
"//xla/pjrt/distributed:key_value_store_interface",
223+
"//xla/service:global_device_id",
224+
"//xla/service/gpu:gpu_executable_run_options",
208225
"//xla/tsl/platform:errors",
209226
"//xla/tsl/platform:logging",
210227
"//xla/tsl/platform:statusor",
211228
"@com_google_absl//absl/algorithm:container",
229+
"@com_google_absl//absl/base:core_headers",
230+
"@com_google_absl//absl/container:flat_hash_map",
212231
"@com_google_absl//absl/status",
213232
"@com_google_absl//absl/status:statusor",
214233
"@com_google_absl//absl/strings",
215234
"@com_google_absl//absl/strings:str_format",
216235
"@com_google_absl//absl/strings:string_view",
236+
"@com_google_absl//absl/synchronization",
217237
"@com_google_absl//absl/types:span",
218238
"@local_tsl//tsl/platform:casts",
239+
"@local_tsl//tsl/platform:numbers",
219240
] + if_cuda_is_configured([
220241
"@local_config_nccl//:nccl",
221242
]) + if_rocm_is_configured([
@@ -265,14 +286,11 @@ cc_library(
265286

266287
cc_library(
267288
name = "nvshmem_collectives",
268-
srcs = ["nvshmem_collectives.cc"],
269-
hdrs = ["nvshmem_collectives.h"],
270-
tags = [
271-
"cuda-only",
272-
"gpu",
273-
],
289+
srcs = if_cuda_is_configured(["nvshmem_collectives.cc"]),
290+
hdrs = if_cuda_is_configured(["nvshmem_collectives.h"]),
274291
visibility = ["//visibility:private"],
275292
deps = [
293+
":gpu_collectives",
276294
"//xla/core/collectives",
277295
"//xla/core/collectives:clique_id",
278296
"//xla/core/collectives:clique_key",
@@ -293,9 +311,8 @@ cc_library(
293311
"@com_google_absl//absl/types:span",
294312
"@local_tsl//tsl/platform:casts",
295313
"@local_tsl//tsl/platform:numbers",
296-
"@nvshmem//:nvshmem_lib",
297-
],
298-
alwayslink = True, # registers collectives implementation
314+
] + if_cuda_is_configured(["@nvshmem//:nvshmem_lib"]),
315+
alwayslink = True,
299316
)
300317

301318
xla_cc_test(

third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@ limitations under the License.
1919
#include <cstddef>
2020
#include <cstdint>
2121
#include <functional>
22+
#include <memory>
2223

24+
#include "absl/container/flat_hash_map.h"
2325
#include "absl/status/status.h"
2426
#include "absl/status/statusor.h"
2527
#include "xla/core/collectives/clique_id.h"
2628
#include "xla/core/collectives/clique_key.h"
2729
#include "xla/core/collectives/collectives.h"
2830
#include "xla/core/collectives/communicator.h"
31+
#include "xla/executable_run_options.h"
32+
#include "xla/pjrt/distributed/key_value_store_interface.h"
33+
#include "xla/service/global_device_id.h"
2934
#include "xla/stream_executor/device_memory.h"
3035
#include "xla/stream_executor/stream.h"
3136
#include "xla/stream_executor/stream_executor.h"
@@ -103,6 +108,23 @@ class GpuCollectives : public Collectives {
103108
// Tries to cast a Collectives::Config to a GpuCollectives::Config.
104109
static absl::StatusOr<const Config*> TryCast(
105110
const Collectives::Config* config);
111+
112+
// TODO(patrios): Use smart wrapper instead of void*.
113+
virtual absl::StatusOr<void*> Allocate(uint64_t bytes) = 0;
114+
115+
virtual absl::Status Deallocate(void* buffer) = 0;
116+
117+
struct Topology {
118+
int32_t node_id;
119+
int32_t num_nodes;
120+
size_t device_count_per_process;
121+
std::shared_ptr<KeyValueStoreInterface> kv_store;
122+
absl::flat_hash_map<GlobalDeviceId, int32_t> device_id_to_node_id;
123+
gpu::GpuExecutableRunOptions* gpu_executable_run_options;
124+
};
125+
126+
// Initializes the topology information for the collectives backend.
127+
virtual absl::Status InitializeTopology(Topology topology) = 0;
106128
};
107129

108130
} // namespace xla::gpu

third_party/xla/xla/backends/gpu/collectives/gpu_collectives_stub.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ class GpuCollectivesStub : public GpuCollectives {
6464

6565
absl::Status GroupStart() final { return UnimplementedError(); }
6666
absl::Status GroupEnd() final { return UnimplementedError(); }
67+
absl::StatusOr<void*> Allocate(uint64_t bytes) final {
68+
return UnimplementedError();
69+
}
70+
71+
absl::Status Deallocate(void* buffer) final { return UnimplementedError(); }
72+
73+
absl::Status InitializeTopology(Topology topology) final {
74+
return UnimplementedError();
75+
}
6776

6877
protected:
6978
static absl::Status UnimplementedError() {

third_party/xla/xla/backends/gpu/collectives/nccl_collectives.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,21 @@ limitations under the License.
2020
#include <memory>
2121
#include <optional>
2222
#include <string>
23+
#include <utility>
2324
#include <vector>
2425

2526
#include "absl/algorithm/container.h"
27+
#include "absl/base/thread_annotations.h"
28+
#include "absl/container/flat_hash_map.h"
2629
#include "absl/status/status.h"
2730
#include "absl/status/statusor.h"
2831
#include "absl/strings/str_cat.h"
2932
#include "absl/strings/str_format.h"
3033
#include "absl/strings/str_join.h"
3134
#include "absl/strings/string_view.h"
35+
#include "absl/synchronization/mutex.h"
3236
#include "absl/types/span.h"
37+
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
3338
#include "xla/backends/gpu/collectives/gpu_collectives.h"
3439
#include "xla/backends/gpu/collectives/nccl_communicator.h"
3540
#include "xla/backends/gpu/collectives/nccl_errors.h"
@@ -39,12 +44,17 @@ limitations under the License.
3944
#include "xla/core/collectives/collectives_registry.h"
4045
#include "xla/core/collectives/communicator.h"
4146
#include "xla/core/collectives/rank_id.h"
47+
#include "xla/debug_options_flags.h"
48+
#include "xla/pjrt/distributed/key_value_store_interface.h"
49+
#include "xla/service/global_device_id.h"
50+
#include "xla/service/gpu/gpu_executable_run_options.h"
4251
#include "xla/status_macros.h"
4352
#include "xla/tsl/platform/errors.h"
4453
#include "xla/tsl/platform/logging.h"
4554
#include "xla/tsl/platform/statusor.h"
4655
#include "xla/util.h"
4756
#include "tsl/platform/casts.h"
57+
#include "tsl/platform/numbers.h"
4858

4959
#if TENSORFLOW_USE_ROCM
5060
#include "rocm/rocm_config.h"
@@ -227,6 +237,128 @@ absl::Status NcclCollectives::GroupEnd() {
227237
return XLA_NCCL_STATUS(ncclGroupEnd());
228238
}
229239

240+
static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
241+
TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
242+
xla::CollectivesRegistry::Get("gpu", "nvshmem"));
243+
xla::gpu::GpuCollectives* nvshmem_collectives =
244+
tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
245+
if (nvshmem_collectives == nullptr) {
246+
return absl::InternalError("Failed to get NVSHMEM collectives");
247+
}
248+
249+
return nvshmem_collectives;
250+
}
251+
252+
absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
253+
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
254+
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
255+
return nvshmem_collectives->Allocate(bytes);
256+
}
257+
258+
void* ptr = nullptr;
259+
ncclResult_t res = ncclMemAlloc(&ptr, bytes);
260+
if (res != ncclSuccess) {
261+
return absl::InternalError(absl::StrFormat(
262+
"failed to allocate %s (%llu bytes) from device collective memory: %s, "
263+
"Last NCCL warning(error) log entry (may be unrelated): %s",
264+
tsl::strings::HumanReadableNumBytes(bytes), bytes,
265+
ncclGetErrorString(res), ncclGetLastError(nullptr)));
266+
}
267+
VLOG(2) << "Allocated collective memory " << ptr << " of " << bytes
268+
<< " bytes";
269+
return ptr;
270+
}
271+
272+
absl::Status NcclCollectives::Deallocate(void* location) {
273+
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
274+
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
275+
return nvshmem_collectives->Deallocate(location);
276+
}
277+
278+
ncclResult_t res = ncclMemFree(location);
279+
if (res != ncclSuccess) {
280+
return absl::InternalError(absl::StrFormat(
281+
"failed to free device collective memory at %p; result: %s, Last NCCL "
282+
"warning(error) log entry (may be unrelated): %s",
283+
location, ncclGetErrorString(res), ncclGetLastError(nullptr)));
284+
}
285+
286+
VLOG(2) << "Deallocated collective memory " << location;
287+
return absl::OkStatus();
288+
}
289+
290+
class NcclIdStore {
291+
public:
292+
NcclIdStore(int node_id,
293+
absl::flat_hash_map<GlobalDeviceId, int> device_to_node,
294+
std::shared_ptr<KeyValueStoreInterface> kv_store)
295+
: node_id_(node_id),
296+
device_to_node_(std::move(device_to_node)),
297+
kv_store_(std::move(kv_store)) {}
298+
299+
absl::StatusOr<CliqueId> GetNcclUniqueId(const CliqueKey& key) {
300+
auto* gpu_key = tsl::down_cast<const gpu::GpuCliqueKey*>(&key);
301+
if (gpu_key == nullptr) {
302+
return InvalidArgument("Expected GPU clique key");
303+
}
304+
305+
// The caller must ensure that threads calling this method concurrently have
306+
// unique keys, otherwise the global key-value store may hold the wrong
307+
// value.
308+
{
309+
absl::MutexLock lock(&mu_);
310+
auto it = cache_.find(*gpu_key);
311+
if (it != cache_.end()) {
312+
return it->second;
313+
}
314+
}
315+
CliqueId clique_id;
316+
int primary_node_id = device_to_node_.at(gpu_key->root_device());
317+
if (node_id_ == primary_node_id) {
318+
TF_ASSIGN_OR_RETURN(
319+
clique_id, gpu::GpuCollectives::Default()->CreateUniqueCliqueId());
320+
TF_RETURN_IF_ERROR(
321+
kv_store_->Set(gpu_key->ToString(), clique_id.ToString()));
322+
} else {
323+
TF_ASSIGN_OR_RETURN(
324+
std::string id_str,
325+
kv_store_->Get(gpu_key->ToString(), absl::Minutes(10)));
326+
clique_id = CliqueId(id_str);
327+
}
328+
absl::MutexLock lock(&mu_);
329+
auto result = cache_.emplace(*gpu_key, std::move(clique_id));
330+
TF_RET_CHECK(result.second) << "Unique ID already in cache.";
331+
return result.first->second;
332+
}
333+
334+
private:
335+
const int node_id_;
336+
const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
337+
const std::shared_ptr<KeyValueStoreInterface> kv_store_;
338+
339+
absl::Mutex mu_;
340+
absl::flat_hash_map<gpu::GpuCliqueKey, CliqueId> cache_ ABSL_GUARDED_BY(mu_);
341+
};
342+
343+
absl::Status NcclCollectives::InitializeTopology(
344+
NcclCollectives::Topology topology) {
345+
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
346+
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
347+
TF_RETURN_IF_ERROR(nvshmem_collectives->InitializeTopology(topology));
348+
}
349+
350+
if (topology.num_nodes > 1) {
351+
auto nccl_id_store = std::make_shared<NcclIdStore>(
352+
topology.node_id, topology.device_id_to_node_id,
353+
std::move(topology.kv_store));
354+
topology.gpu_executable_run_options->set_clique_id_callback(
355+
[nccl_id_store](const CliqueKey& key) {
356+
return nccl_id_store->GetNcclUniqueId(key);
357+
});
358+
}
359+
return absl::OkStatus();
360+
}
361+
230362
} // namespace xla::gpu
231363

232364
XLA_COLLECTIVES_REGISTER("gpu", "nccl", 1,

third_party/xla/xla/backends/gpu/collectives/nccl_collectives.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class NcclCollectives : public GpuCollectives {
5757
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
5858
absl::Span<const Communicator* const> comms, int32_t color,
5959
absl::Span<const RankId> keys, const Collectives::Config& config) final;
60+
61+
absl::StatusOr<void*> Allocate(uint64_t bytes) final;
62+
63+
absl::Status Deallocate(void* location) final;
64+
65+
absl::Status InitializeTopology(Topology topology) final;
6066
};
6167

6268
} // namespace xla::gpu

third_party/xla/xla/backends/gpu/collectives/nvshmem_collectives.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ NvshmemCollectives* NvshmemCollectives::Default() {
5757
LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM";
5858
}
5959

60+
absl::Status NvshmemCollectives::InitializeTopology(Topology topology) {
61+
SetEnvInfo(topology.node_id, topology.num_nodes,
62+
topology.device_count_per_process, topology.kv_store);
63+
return absl::OkStatus();
64+
}
65+
6066
void NvshmemCollectives::SetEnvInfo(
6167
int process_id, size_t num_processes, size_t device_count_per_process,
6268
std::weak_ptr<KeyValueStoreInterface> kv_store) {

0 commit comments

Comments (0)