Skip to content

Commit 5c67d0a

Browse files
avik-palgiordano
andauthored
feat: the big jll PR (#653)
* chore: bigggg jll * fix: restore the old API * fix: add shardy c headers * fix: section for import * Update deps/ReactantExtra/BUILD Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> * Update deps/ReactantExtra/BUILD Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> --------- Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com>
1 parent b5f9ecd commit 5c67d0a

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 119 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#include "mlir/InitAllPasses.h"
3434
#include "mlir/Pass/PassRegistry.h"
3535
#include "mlir/Transforms/Passes.h"
36-
#include "shardy/dialect/sdy/ir/dialect.h"
3736
#include "src/enzyme_ad/jax/Dialect/Dialect.h"
3837
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
3938
#include "src/enzyme_ad/jax/Passes/Passes.h"
@@ -69,6 +68,10 @@
6968

7069
#include "llvm-c/TargetMachine.h"
7170

71+
// shardy
72+
#include "shardy/dialect/sdy/ir/dialect.h"
73+
#include "shardy/integrations/c/attributes.h"
74+
7275
// IFRT
7376
#include "xla/python/ifrt/array.h"
7477
#include "xla/python/ifrt/client.h"
@@ -530,6 +533,18 @@ extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) {
530533

531534
extern "C" void FreeClient(PjRtClient *client) { delete client; }
532535

536+
extern "C" int64_t PjRtDeviceGetLocalDeviceId(PjRtDevice *device) {
537+
return device->local_device_id().value();
538+
}
539+
540+
extern "C" int64_t PjRtDeviceGetGlobalDeviceId(PjRtDevice *device) {
541+
return device->global_device_id().value();
542+
}
543+
544+
extern "C" int64_t PjRtDeviceGetLocalHardwareId(PjRtDevice *device) {
545+
return device->local_hardware_id().value();
546+
}
547+
533548
#include "xla/service/custom_call_target_registry.h"
534549
extern "C" void RegisterCustomCallTarget(const char *name, void *address,
535550
const char *platform) {
@@ -579,22 +594,30 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
579594
}
580595

581596
/* Note that this */
582-
extern "C" xla::PjRtLoadedExecutable *
583-
ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal,
584-
int num_replicas, int num_partitions,
585-
bool use_shardy_partitioner) {
597+
extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
598+
MlirModule cmod,
599+
int *global_ordinals,
600+
int num_global_ordinals) {
586601
auto program =
587602
std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));
588603

589604
CompileOptions options;
590605

591-
if (device_ordinal >= 0) {
592-
options.executable_build_options.set_device_ordinal(device_ordinal);
606+
// https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
607+
int device_count = client->addressable_device_count();
608+
609+
options.executable_build_options.set_num_replicas(device_count);
610+
options.executable_build_options.set_num_partitions(1);
611+
612+
xla::DeviceAssignment device_assignment(device_count, 1);
613+
for (int64_t device_id = 0; device_id < num_global_ordinals; ++device_id) {
614+
int ordinal = global_ordinals[device_id];
615+
if (ordinal < 0) {
616+
continue;
617+
}
618+
device_assignment(ordinal, 0) = device_id;
593619
}
594-
options.executable_build_options.set_num_replicas(num_replicas);
595-
options.executable_build_options.set_num_partitions(num_partitions);
596-
options.executable_build_options.set_use_shardy_partitioner(
597-
use_shardy_partitioner);
620+
options.executable_build_options.set_device_assignment(device_assignment);
598621

599622
auto addressable_devices = client->addressable_devices();
600623
if (!addressable_devices.empty()) {
@@ -605,8 +628,7 @@ ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal,
605628
assert(device_ordinal < addressable_devices.size());
606629
auto stats = addressable_devices[device_ordinal]->GetAllocatorStats();
607630
if (stats.ok() && stats->bytes_limit) {
608-
options.executable_build_options.set_device_memory_size(
609-
*stats->bytes_limit);
631+
options.executable_build_options.set_device_memory_size(*stats->bytes_limit);
610632
}
611633
}
612634
auto exec =
@@ -623,12 +645,72 @@ extern "C" uint8_t FutureIsReady(FutureType *Future) {
623645

624646
extern "C" void FutureAwait(FutureType *Future) { Future->Await(); }
625647

648+
extern "C" void XLAExecuteSharded(xla::PjRtLoadedExecutable *exec, int num_args,
649+
PjRtBuffer **op_args, PjRtDevice *device,
650+
uint8_t *is_arg_donatable, int num_results,
651+
PjRtBuffer **op_results, uint8_t *futures,
652+
FutureType **future_results) {
653+
// Create a vector of PjRtBuffer* from the input array.
654+
std::vector<PjRtBuffer *> argument_handles(op_args, op_args + num_args);
655+
656+
// Set up execution options.
657+
ExecuteOptions options;
658+
for (size_t i = 0; i < num_args; i++) {
659+
if (!is_arg_donatable[i]) {
660+
options.non_donatable_input_indices.insert(static_cast<int>(i));
661+
}
662+
}
663+
options.untuple_result = true;
664+
665+
// Optional future to hold asynchronous execution results.
666+
std::optional<PjRtFuture<>> returned_future;
667+
668+
auto results = MyValueOrThrow(
669+
exec->ExecuteSharded(argument_handles,
670+
device, options, returned_future, /*fill_future=*/true));
671+
672+
// Validate the number of results.
673+
if (results.size() != num_results) {
674+
llvm::errs() << "Error: results.size()=" << results.size()
675+
<< " does not match num_results=" << num_results << "\n";
676+
std::abort(); // Terminate if the number of results is incorrect.
677+
}
678+
679+
// Handle futures if they are returned.
680+
if (returned_future.has_value()) {
681+
*futures = true;
682+
for (size_t i = 0; i < num_results; i++) {
683+
future_results[i] = new FutureType(*returned_future);
684+
}
685+
} else {
686+
*futures = false;
687+
}
688+
689+
// Release the results into the output array.
690+
for (size_t i = 0; i < num_results; i++) {
691+
op_results[i] = results[i].release();
692+
}
693+
}
694+
626695
extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
627696
PjRtBuffer **op_args, uint8_t *is_arg_donatable,
628697
int num_results, PjRtBuffer **op_results,
629698
uint8_t *futures, FutureType **future_results) {
630-
std::vector<std::vector<PjRtBuffer *>> argument_handles;
631-
argument_handles.emplace_back(op_args, op_args + num_args);
699+
auto client = exec->client();
700+
int num_devices = client->addressable_device_count();
701+
702+
// Ensure argument_handles is structured as num_devices x num_args
703+
std::vector<std::vector<PjRtBuffer *>> argument_handles(num_devices);
704+
705+
// Distribute arguments across devices
706+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
707+
argument_handles[device_idx].reserve(num_args);
708+
for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
709+
// Assuming op_args is a flat array of size num_devices * num_args
710+
// where arguments for each device are contiguous
711+
argument_handles[device_idx].push_back(op_args[device_idx * num_args + arg_idx]);
712+
}
713+
}
632714

633715
ExecuteOptions options;
634716

@@ -637,31 +719,43 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
637719
options.non_donatable_input_indices.insert((int)i);
638720
}
639721
options.untuple_result = true;
722+
640723
std::optional<std::vector<FutureType>> returned_futures;
641724
auto results = MyValueOrThrow(
642725
exec->Execute(static_cast<absl::Span<const std::vector<PjRtBuffer *>>>(
643726
argument_handles),
644727
options, returned_futures));
645728

646-
assert(results.size() == 1);
729+
assert(results.size() == num_devices);
647730

648-
if (results[0].size() != num_results) {
649-
llvm::errs() << " results.size()=" << results.size()
650-
<< " num_results=" << num_results << "\n";
731+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
732+
if (results[device_idx].size() != num_results) {
733+
llvm::errs() << " results[" << device_idx << "].size()=" << results[device_idx].size()
734+
<< " num_results=" << num_results << "\n";
735+
}
736+
assert(results[device_idx].size() == num_results);
651737
}
652-
assert(results[0].size() == num_results);
738+
739+
// Handle returned futures
653740
if (returned_futures) {
654741
*futures = true;
655-
assert(returned_futures->size() == num_results);
656-
for (size_t i = 0; i < num_results; i++) {
657-
future_results[i] = new FutureType((*returned_futures)[i]);
742+
assert(returned_futures->size() == num_devices * num_results);
743+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
744+
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
745+
int flat_index = device_idx * num_results + result_idx;
746+
future_results[flat_index] = new FutureType((*returned_futures)[flat_index]);
747+
}
658748
}
659749
} else {
660750
*futures = false;
661751
}
662752

663-
for (size_t i = 0; i < num_results; i++) {
664-
op_results[i] = results[0][i].release();
753+
// Copy results into the output buffers
754+
for (int device_idx = 0; device_idx < num_devices; ++device_idx) {
755+
for (int result_idx = 0; result_idx < num_results; ++result_idx) {
756+
int flat_index = device_idx * num_results + result_idx;
757+
op_results[flat_index] = results[device_idx][result_idx].release();
758+
}
665759
}
666760
}
667761

deps/ReactantExtra/BUILD

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,9 @@ cc_library(
360360
],
361361

362362
) + [
363-
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
364-
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
363+
"@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp",
364+
"@enzyme_ad//src/enzyme_ad/jax:gpu.cc",
365+
"@enzyme_ad//src/enzyme_ad/jax:cpu.cc",
365366
# "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc",
366367
# "@xla//xla:xla.pb.cc",
367368
"@xla//xla:xla_data.pb.cc",
@@ -448,7 +449,13 @@ cc_library(
448449
"-Wl,-exported_symbol,_ProfilerActivityEnd",
449450
"-Wl,-exported_symbol,_ReactantFuncSetArgAttr",
450451
"-Wl,-exported_symbol,_ReactantCudaDriverGetVersion",
451-
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions"
452+
"-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions",
453+
"-Wl,-exported_symbol,_PjRtDeviceGetLocalDeviceId",
454+
"-Wl,-exported_symbol,_PjRtDeviceGetGlobalDeviceId",
455+
"-Wl,-exported_symbol,_PjRtDeviceGetLocalHardwareId",
456+
"-Wl,-exported_symbol,_XLAExecuteSharded",
457+
"-Wl,-exported_symbol,_ClientGetPlatformName",
458+
"-Wl,-exported_symbol,_RegisterEnzymeXLACPUHandler",
452459
]}),
453460
deps = [
454461
"@enzyme//:EnzymeMLIR",
@@ -550,13 +557,15 @@ cc_library(
550557
"@xla//xla/mlir/utils:type_util",
551558
"@stablehlo//:stablehlo_capi_objects",
552559
"@stablehlo//:chlo_capi_objects",
560+
"@shardy//shardy/integrations/c:sdy_capi_objects",
553561
"@com_google_absl//absl/hash:hash",
554562
"@com_google_absl//absl/log:initialize",
555563
"@com_google_absl//absl/log:globals",
556564
"@llvm-project//mlir:CAPIIRObjects",
557565
"@llvm-project//mlir:CAPILLVMObjects",
558566
"@jax//jaxlib/mosaic:tpu_dialect_capi_objects",
559567
"@jax//jaxlib/triton:triton_dialect_capi_objects",
568+
"@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
560569
] + select({
561570
"@xla//xla/tsl:is_cuda_enabled_and_oss":[
562571
"@xla//xla/stream_executor/cuda:all_runtime",
@@ -568,7 +577,6 @@ cc_library(
568577
"@xla//xla/backends/profiler/gpu:device_tracer",
569578
],
570579
"//conditions:default": [
571-
"@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
572580
],
573581
}) + if_rocm([
574582
"@xla//xla/service/gpu:amdgpu_compiler",

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "c38ca3f187ef11de6b2292f3cc55c5eb60530d15"
12+
ENZYMEXLA_COMMIT = "d89468ed883ca18c04346eec10f784bbe2b754fc"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

0 commit comments

Comments
 (0)