Skip to content

Commit 5a9082e

Browse files
trevor-m (Google-ML-Automation)
authored and committed
PR #21683: [XLA:GPU] NVSHMEM allocation
Imported from GitHub PR #21683 Requires #20395 which adds the NVSHMEM library dependency. This PR adds the following: 1. Nvshmem flag to enable nvshmem 2. Set nvshmem initialization issue when GPU PJRT client is created. The first time NVSHMEM is used, it will be initialized. 3. Uses the user buffer memory pool for nvshmem. If nvshmem is enabled, it will be allocated using `nvshmem_malloc`. This same memory can be used by user buffers if nccl user buffers is also enabled. 4. Update the `CollectiveColorer` so that mosaic_gpu custom calls use the nvshmem memory space. Copybara import of the project: -- aee3379 by Trevor Morris <tmorris@nvidia.com>: Add nvshmem flag, memory allocation, and memory space assignment Set Nvshmem env info during client creation Rename flag and use absl::string_view -- f8fca39 by Trevor Morris <tmorris@nvidia.com>: Use explicit types in test -- e41faa3 by Trevor Morris <tmorris@nvidia.com>: Add user buffer allgather and allreduce tests with and without nvshmem alloc Set nvshmem in XLA_FLAGS test fixes formatting -- cf0c368 by Trevor Morris <tmorris@nvidia.com>: Fixes -- 3b4d111 by Trevor Morris <tmorris@nvidia.com>: Remove early dso check -- 359f2b2 by Trevor Morris <tmorris@nvidia.com>: Add flag comment -- fd15a7c by Trevor Morris <tmorris@nvidia.com>: Also assign memory space for mosaic_gpu_v2 Merging this change closes #21683 FUTURE_COPYBARA_INTEGRATE_REVIEW=#21683 from trevor-m:nvshmem-upstream-2 fd15a7c PiperOrigin-RevId: 740701134
1 parent ad59fdf commit 5a9082e

24 files changed

+631
-478
lines changed

xla/backends/gpu/collectives/BUILD

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,23 @@ package_group(
1818
],
1919
)
2020

21+
config_setting(
22+
name = "arm_build",
23+
values = {"cpu": "arm"},
24+
)
25+
2126
# Build target that registers all available GPU collectives implementations with the collectives
2227
# registry at link time.
2328
cc_library(
2429
name = "gpu_collectives_plugin",
2530
deps = [
2631
":gpu_collectives_stub",
2732
":nccl_collectives",
28-
],
33+
] + select({
34+
# TODO(b/409709288): Fix nvshmem ARM issues and remove this condition.
35+
":arm_build": [],
36+
"//conditions:default": [":nvshmem_collectives"],
37+
}),
2938
)
3039

3140
cc_library(
@@ -132,6 +141,7 @@ cc_library(
132141
srcs = ["gpu_collectives.cc"],
133142
hdrs = ["gpu_collectives.h"],
134143
deps = [
144+
"//xla:executable_run_options",
135145
"//xla:shape_util",
136146
"//xla:util",
137147
"//xla:xla_data_proto_cc",
@@ -140,10 +150,12 @@ cc_library(
140150
"//xla/core/collectives:clique_key",
141151
"//xla/core/collectives:collectives_registry",
142152
"//xla/core/collectives:communicator",
153+
"//xla/pjrt/distributed:key_value_store_interface",
154+
"//xla/service:global_device_id",
143155
"//xla/stream_executor:device_memory",
144156
"//xla/stream_executor:stream",
145157
"//xla/stream_executor:stream_executor_h",
146-
"//xla/tsl/platform:logging",
158+
"@com_google_absl//absl/container:flat_hash_map",
147159
"@com_google_absl//absl/log",
148160
"@com_google_absl//absl/log:check",
149161
"@com_google_absl//absl/status",
@@ -194,9 +206,11 @@ cc_library(
194206
]),
195207
visibility = ["//visibility:private"],
196208
deps = [
209+
":gpu_clique_key",
197210
":gpu_collectives",
198211
":nccl_communicator",
199212
":nccl_errors",
213+
"//xla:debug_options_flags",
200214
"//xla:status_macros",
201215
"//xla:util",
202216
"//xla/core/collectives",
@@ -205,17 +219,24 @@ cc_library(
205219
"//xla/core/collectives:collectives_registry",
206220
"//xla/core/collectives:communicator",
207221
"//xla/core/collectives:rank_id",
222+
"//xla/pjrt/distributed:key_value_store_interface",
223+
"//xla/service:global_device_id",
224+
"//xla/service/gpu:gpu_executable_run_options",
208225
"//xla/tsl/platform:errors",
209226
"//xla/tsl/platform:logging",
210227
"//xla/tsl/platform:statusor",
211228
"@com_google_absl//absl/algorithm:container",
229+
"@com_google_absl//absl/base:core_headers",
230+
"@com_google_absl//absl/container:flat_hash_map",
212231
"@com_google_absl//absl/status",
213232
"@com_google_absl//absl/status:statusor",
214233
"@com_google_absl//absl/strings",
215234
"@com_google_absl//absl/strings:str_format",
216235
"@com_google_absl//absl/strings:string_view",
236+
"@com_google_absl//absl/synchronization",
217237
"@com_google_absl//absl/types:span",
218238
"@tsl//tsl/platform:casts",
239+
"@tsl//tsl/platform:numbers",
219240
] + if_cuda_is_configured([
220241
"@local_config_nccl//:nccl",
221242
]) + if_rocm_is_configured([
@@ -265,14 +286,15 @@ cc_library(
265286

266287
cc_library(
267288
name = "nvshmem_collectives",
268-
srcs = ["nvshmem_collectives.cc"],
269-
hdrs = ["nvshmem_collectives.h"],
289+
srcs = if_cuda_is_configured(["nvshmem_collectives.cc"]),
290+
hdrs = if_cuda_is_configured(["nvshmem_collectives.h"]),
270291
tags = [
271292
"cuda-only",
272293
"gpu",
273294
],
274295
visibility = ["//visibility:private"],
275296
deps = [
297+
":gpu_collectives",
276298
"//xla/core/collectives",
277299
"//xla/core/collectives:clique_id",
278300
"//xla/core/collectives:clique_key",
@@ -295,7 +317,7 @@ cc_library(
295317
"@tsl//tsl/platform:casts",
296318
"@tsl//tsl/platform:numbers",
297319
],
298-
alwayslink = True, # registers collectives implementation
320+
alwayslink = True,
299321
)
300322

301323
xla_cc_test(

xla/backends/gpu/collectives/gpu_collectives.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@ limitations under the License.
1919
#include <cstddef>
2020
#include <cstdint>
2121
#include <functional>
22+
#include <memory>
2223

24+
#include "absl/container/flat_hash_map.h"
2325
#include "absl/status/status.h"
2426
#include "absl/status/statusor.h"
2527
#include "xla/core/collectives/clique_id.h"
2628
#include "xla/core/collectives/clique_key.h"
2729
#include "xla/core/collectives/collectives.h"
2830
#include "xla/core/collectives/communicator.h"
31+
#include "xla/executable_run_options.h"
32+
#include "xla/pjrt/distributed/key_value_store_interface.h"
33+
#include "xla/service/global_device_id.h"
2934
#include "xla/stream_executor/device_memory.h"
3035
#include "xla/stream_executor/stream.h"
3136
#include "xla/stream_executor/stream_executor.h"
@@ -103,6 +108,23 @@ class GpuCollectives : public Collectives {
103108
// Tries to cast a Collectives::Config to a GpuCollectives::Config.
104109
static absl::StatusOr<const Config*> TryCast(
105110
const Collectives::Config* config);
111+
112+
// TODO(patrios): Use smart wrapper instead of void*.
113+
virtual absl::StatusOr<void*> Allocate(uint64_t bytes) = 0;
114+
115+
virtual absl::Status Deallocate(void* buffer) = 0;
116+
117+
struct Topology {
118+
int32_t node_id;
119+
int32_t num_nodes;
120+
size_t device_count_per_process;
121+
std::shared_ptr<KeyValueStoreInterface> kv_store;
122+
absl::flat_hash_map<GlobalDeviceId, int32_t> device_id_to_node_id;
123+
gpu::GpuExecutableRunOptions* gpu_executable_run_options;
124+
};
125+
126+
// Initializes the topology information for the collectives backend.
127+
virtual absl::Status InitializeTopology(Topology topology) = 0;
106128
};
107129

108130
} // namespace xla::gpu

xla/backends/gpu/collectives/gpu_collectives_stub.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ class GpuCollectivesStub : public GpuCollectives {
6464

6565
absl::Status GroupStart() final { return UnimplementedError(); }
6666
absl::Status GroupEnd() final { return UnimplementedError(); }
67+
absl::StatusOr<void*> Allocate(uint64_t bytes) final {
68+
return UnimplementedError();
69+
}
70+
71+
absl::Status Deallocate(void* buffer) final { return UnimplementedError(); }
72+
73+
absl::Status InitializeTopology(Topology topology) final {
74+
return UnimplementedError();
75+
}
6776

6877
protected:
6978
static absl::Status UnimplementedError() {

xla/backends/gpu/collectives/nccl_collectives.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,21 @@ limitations under the License.
2020
#include <memory>
2121
#include <optional>
2222
#include <string>
23+
#include <utility>
2324
#include <vector>
2425

2526
#include "absl/algorithm/container.h"
27+
#include "absl/base/thread_annotations.h"
28+
#include "absl/container/flat_hash_map.h"
2629
#include "absl/status/status.h"
2730
#include "absl/status/statusor.h"
2831
#include "absl/strings/str_cat.h"
2932
#include "absl/strings/str_format.h"
3033
#include "absl/strings/str_join.h"
3134
#include "absl/strings/string_view.h"
35+
#include "absl/synchronization/mutex.h"
3236
#include "absl/types/span.h"
37+
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
3338
#include "xla/backends/gpu/collectives/gpu_collectives.h"
3439
#include "xla/backends/gpu/collectives/nccl_communicator.h"
3540
#include "xla/backends/gpu/collectives/nccl_errors.h"
@@ -39,12 +44,17 @@ limitations under the License.
3944
#include "xla/core/collectives/collectives_registry.h"
4045
#include "xla/core/collectives/communicator.h"
4146
#include "xla/core/collectives/rank_id.h"
47+
#include "xla/debug_options_flags.h"
48+
#include "xla/pjrt/distributed/key_value_store_interface.h"
49+
#include "xla/service/global_device_id.h"
50+
#include "xla/service/gpu/gpu_executable_run_options.h"
4251
#include "xla/status_macros.h"
4352
#include "xla/tsl/platform/errors.h"
4453
#include "xla/tsl/platform/logging.h"
4554
#include "xla/tsl/platform/statusor.h"
4655
#include "xla/util.h"
4756
#include "tsl/platform/casts.h"
57+
#include "tsl/platform/numbers.h"
4858

4959
#if TENSORFLOW_USE_ROCM
5060
#include "rocm/rocm_config.h"
@@ -227,6 +237,128 @@ absl::Status NcclCollectives::GroupEnd() {
227237
return XLA_NCCL_STATUS(ncclGroupEnd());
228238
}
229239

240+
static absl::StatusOr<xla::gpu::GpuCollectives*> GetNvshmemCollectives() {
241+
TF_ASSIGN_OR_RETURN(xla::Collectives * collectives,
242+
xla::CollectivesRegistry::Get("gpu", "nvshmem"));
243+
xla::gpu::GpuCollectives* nvshmem_collectives =
244+
tsl::down_cast<xla::gpu::GpuCollectives*>(collectives);
245+
if (nvshmem_collectives == nullptr) {
246+
return absl::InternalError("Failed to get NVSHMEM collectives");
247+
}
248+
249+
return nvshmem_collectives;
250+
}
251+
252+
absl::StatusOr<void*> NcclCollectives::Allocate(uint64_t bytes) {
253+
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
254+
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
255+
return nvshmem_collectives->Allocate(bytes);
256+
}
257+
258+
void* ptr = nullptr;
259+
ncclResult_t res = ncclMemAlloc(&ptr, bytes);
260+
if (res != ncclSuccess) {
261+
return absl::InternalError(absl::StrFormat(
262+
"failed to allocate %s (%llu bytes) from device collective memory: %s, "
263+
"Last NCCL warning(error) log entry (may be unrelated): %s",
264+
tsl::strings::HumanReadableNumBytes(bytes), bytes,
265+
ncclGetErrorString(res), ncclGetLastError(nullptr)));
266+
}
267+
VLOG(2) << "Allocated collective memory " << ptr << " of " << bytes
268+
<< " bytes";
269+
return ptr;
270+
}
271+
272+
absl::Status NcclCollectives::Deallocate(void* location) {
273+
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
274+
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
275+
return nvshmem_collectives->Deallocate(location);
276+
}
277+
278+
ncclResult_t res = ncclMemFree(location);
279+
if (res != ncclSuccess) {
280+
return absl::InternalError(absl::StrFormat(
281+
"failed to free device collective memory at %p; result: %s, Last NCCL "
282+
"warning(error) log entry (may be unrelated): %s",
283+
location, ncclGetErrorString(res), ncclGetLastError(nullptr)));
284+
}
285+
286+
VLOG(2) << "Deallocated collective memory " << location;
287+
return absl::OkStatus();
288+
}
289+
290+
class NcclIdStore {
291+
public:
292+
NcclIdStore(int node_id,
293+
absl::flat_hash_map<GlobalDeviceId, int> device_to_node,
294+
std::shared_ptr<KeyValueStoreInterface> kv_store)
295+
: node_id_(node_id),
296+
device_to_node_(std::move(device_to_node)),
297+
kv_store_(std::move(kv_store)) {}
298+
299+
absl::StatusOr<CliqueId> GetNcclUniqueId(const CliqueKey& key) {
300+
auto* gpu_key = tsl::down_cast<const gpu::GpuCliqueKey*>(&key);
301+
if (gpu_key == nullptr) {
302+
return InvalidArgument("Expected GPU clique key");
303+
}
304+
305+
// The caller must ensure that threads calling this method concurrently have
306+
// unique keys, otherwise the global key-value store may hold the wrong
307+
// value.
308+
{
309+
absl::MutexLock lock(&mu_);
310+
auto it = cache_.find(*gpu_key);
311+
if (it != cache_.end()) {
312+
return it->second;
313+
}
314+
}
315+
CliqueId clique_id;
316+
int primary_node_id = device_to_node_.at(gpu_key->root_device());
317+
if (node_id_ == primary_node_id) {
318+
TF_ASSIGN_OR_RETURN(
319+
clique_id, gpu::GpuCollectives::Default()->CreateUniqueCliqueId());
320+
TF_RETURN_IF_ERROR(
321+
kv_store_->Set(gpu_key->ToString(), clique_id.ToString()));
322+
} else {
323+
TF_ASSIGN_OR_RETURN(
324+
std::string id_str,
325+
kv_store_->Get(gpu_key->ToString(), absl::Minutes(10)));
326+
clique_id = CliqueId(id_str);
327+
}
328+
absl::MutexLock lock(&mu_);
329+
auto result = cache_.emplace(*gpu_key, std::move(clique_id));
330+
TF_RET_CHECK(result.second) << "Unique ID already in cache.";
331+
return result.first->second;
332+
}
333+
334+
private:
335+
const int node_id_;
336+
const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
337+
const std::shared_ptr<KeyValueStoreInterface> kv_store_;
338+
339+
absl::Mutex mu_;
340+
absl::flat_hash_map<gpu::GpuCliqueKey, CliqueId> cache_ ABSL_GUARDED_BY(mu_);
341+
};
342+
343+
absl::Status NcclCollectives::InitializeTopology(
344+
NcclCollectives::Topology topology) {
345+
if (xla::GetDebugOptionsFromFlags().xla_gpu_experimental_enable_nvshmem()) {
346+
TF_ASSIGN_OR_RETURN(auto* nvshmem_collectives, GetNvshmemCollectives());
347+
TF_RETURN_IF_ERROR(nvshmem_collectives->InitializeTopology(topology));
348+
}
349+
350+
if (topology.num_nodes > 1) {
351+
auto nccl_id_store = std::make_shared<NcclIdStore>(
352+
topology.node_id, topology.device_id_to_node_id,
353+
std::move(topology.kv_store));
354+
topology.gpu_executable_run_options->set_clique_id_callback(
355+
[nccl_id_store](const CliqueKey& key) {
356+
return nccl_id_store->GetNcclUniqueId(key);
357+
});
358+
}
359+
return absl::OkStatus();
360+
}
361+
230362
} // namespace xla::gpu
231363

232364
XLA_COLLECTIVES_REGISTER("gpu", "nccl", 1,

xla/backends/gpu/collectives/nccl_collectives.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ class NcclCollectives : public GpuCollectives {
5757
absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
5858
absl::Span<const Communicator* const> comms, int32_t color,
5959
absl::Span<const RankId> keys, const Collectives::Config& config) final;
60+
61+
absl::StatusOr<void*> Allocate(uint64_t bytes) final;
62+
63+
absl::Status Deallocate(void* location) final;
64+
65+
absl::Status InitializeTopology(Topology topology) final;
6066
};
6167

6268
} // namespace xla::gpu

xla/backends/gpu/collectives/nvshmem_collectives.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ NvshmemCollectives* NvshmemCollectives::Default() {
5757
LOG(FATAL) << "Unsupported collectives implementation for NVSHMEM";
5858
}
5959

60+
absl::Status NvshmemCollectives::InitializeTopology(Topology topology) {
61+
SetEnvInfo(topology.node_id, topology.num_nodes,
62+
topology.device_count_per_process, topology.kv_store);
63+
return absl::OkStatus();
64+
}
65+
6066
void NvshmemCollectives::SetEnvInfo(
6167
int process_id, size_t num_processes, size_t device_count_per_process,
6268
std::weak_ptr<KeyValueStoreInterface> kv_store) {

0 commit comments

Comments (0)