3333#include " mlir/InitAllPasses.h"
3434#include " mlir/Pass/PassRegistry.h"
3535#include " mlir/Transforms/Passes.h"
36- #include " shardy/dialect/sdy/ir/dialect.h"
3736#include " src/enzyme_ad/jax/Dialect/Dialect.h"
3837#include " src/enzyme_ad/jax/Implementations/XLADerivatives.h"
3938#include " src/enzyme_ad/jax/Passes/Passes.h"
6968
7069#include " llvm-c/TargetMachine.h"
7170
71+ // shardy
72+ #include " shardy/dialect/sdy/ir/dialect.h"
73+ #include " shardy/integrations/c/attributes.h"
74+
7275// IFRT
7376#include " xla/python/ifrt/array.h"
7477#include " xla/python/ifrt/client.h"
@@ -530,6 +533,18 @@ extern "C" void BufferToHost(PjRtBuffer *buffer, void *data) {
530533
531534extern " C" void FreeClient (PjRtClient *client) { delete client; }
532535
536+ extern " C" int64_t PjRtDeviceGetLocalDeviceId (PjRtDevice *device) {
537+ return device->local_device_id ().value ();
538+ }
539+
540+ extern " C" int64_t PjRtDeviceGetGlobalDeviceId (PjRtDevice *device) {
541+ return device->global_device_id ().value ();
542+ }
543+
544+ extern " C" int64_t PjRtDeviceGetLocalHardwareId (PjRtDevice *device) {
545+ return device->local_hardware_id ().value ();
546+ }
547+
533548#include " xla/service/custom_call_target_registry.h"
534549extern " C" void RegisterCustomCallTarget (const char *name, void *address,
535550 const char *platform) {
@@ -579,22 +594,30 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
579594}
580595
581596/* Note that this */
582- extern " C" xla::PjRtLoadedExecutable *
583- ClientCompile (PjRtClient *client, MlirModule cmod, int device_ordinal ,
584- int num_replicas, int num_partitions ,
585- bool use_shardy_partitioner ) {
597+ extern " C" xla::PjRtLoadedExecutable *ClientCompile (PjRtClient *client,
598+ MlirModule cmod ,
599+ int *global_ordinals ,
600+ int num_global_ordinals ) {
586601 auto program =
587602 std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap (cmod)));
588603
589604 CompileOptions options;
590605
591- if (device_ordinal >= 0 ) {
592- options.executable_build_options .set_device_ordinal (device_ordinal);
606+ // https://github.com/pytorch/xla/blob/8b2414094578e829b99a8383877c86d357eeb682/torch_xla/csrc/runtime/pjrt_computation_client.cc#L601
607+ int device_count = client->addressable_device_count ();
608+
609+ options.executable_build_options .set_num_replicas (device_count);
610+ options.executable_build_options .set_num_partitions (1 );
611+
612+ xla::DeviceAssignment device_assignment (device_count, 1 );
613+ for (int64_t device_id = 0 ; device_id < num_global_ordinals; ++device_id) {
614+ int ordinal = global_ordinals[device_id];
615+ if (ordinal < 0 ) {
616+ continue ;
617+ }
618+ device_assignment (ordinal, 0 ) = device_id;
593619 }
594- options.executable_build_options .set_num_replicas (num_replicas);
595- options.executable_build_options .set_num_partitions (num_partitions);
596- options.executable_build_options .set_use_shardy_partitioner (
597- use_shardy_partitioner);
620+ options.executable_build_options .set_device_assignment (device_assignment);
598621
599622 auto addressable_devices = client->addressable_devices ();
600623 if (!addressable_devices.empty ()) {
@@ -605,8 +628,7 @@ ClientCompile(PjRtClient *client, MlirModule cmod, int device_ordinal,
605628 assert (device_ordinal < addressable_devices.size ());
606629 auto stats = addressable_devices[device_ordinal]->GetAllocatorStats ();
607630 if (stats.ok () && stats->bytes_limit ) {
608- options.executable_build_options .set_device_memory_size (
609- *stats->bytes_limit );
631+ options.executable_build_options .set_device_memory_size (*stats->bytes_limit );
610632 }
611633 }
612634 auto exec =
@@ -623,12 +645,72 @@ extern "C" uint8_t FutureIsReady(FutureType *Future) {
623645
624646extern " C" void FutureAwait (FutureType *Future) { Future->Await (); }
625647
648+ extern " C" void XLAExecuteSharded (xla::PjRtLoadedExecutable *exec, int num_args,
649+ PjRtBuffer **op_args, PjRtDevice *device,
650+ uint8_t *is_arg_donatable, int num_results,
651+ PjRtBuffer **op_results, uint8_t *futures,
652+ FutureType **future_results) {
653+ // Create a vector of PjRtBuffer* from the input array.
654+ std::vector<PjRtBuffer *> argument_handles (op_args, op_args + num_args);
655+
656+ // Set up execution options.
657+ ExecuteOptions options;
658+ for (size_t i = 0 ; i < num_args; i++) {
659+ if (!is_arg_donatable[i]) {
660+ options.non_donatable_input_indices .insert (static_cast <int >(i));
661+ }
662+ }
663+ options.untuple_result = true ;
664+
665+ // Optional future to hold asynchronous execution results.
666+ std::optional<PjRtFuture<>> returned_future;
667+
668+ auto results = MyValueOrThrow (
669+ exec->ExecuteSharded (argument_handles,
670+ device, options, returned_future, /* fill_future=*/ true ));
671+
672+ // Validate the number of results.
673+ if (results.size () != num_results) {
674+ llvm::errs () << " Error: results.size()=" << results.size ()
675+ << " does not match num_results=" << num_results << " \n " ;
676+ std::abort (); // Terminate if the number of results is incorrect.
677+ }
678+
679+ // Handle futures if they are returned.
680+ if (returned_future.has_value ()) {
681+ *futures = true ;
682+ for (size_t i = 0 ; i < num_results; i++) {
683+ future_results[i] = new FutureType (*returned_future);
684+ }
685+ } else {
686+ *futures = false ;
687+ }
688+
689+ // Release the results into the output array.
690+ for (size_t i = 0 ; i < num_results; i++) {
691+ op_results[i] = results[i].release ();
692+ }
693+ }
694+
626695extern " C" void XLAExecute (xla::PjRtLoadedExecutable *exec, int num_args,
627696 PjRtBuffer **op_args, uint8_t *is_arg_donatable,
628697 int num_results, PjRtBuffer **op_results,
629698 uint8_t *futures, FutureType **future_results) {
630- std::vector<std::vector<PjRtBuffer *>> argument_handles;
631- argument_handles.emplace_back (op_args, op_args + num_args);
699+ auto client = exec->client ();
700+ int num_devices = client->addressable_device_count ();
701+
702+ // Ensure argument_handles is structured as num_devices x num_args
703+ std::vector<std::vector<PjRtBuffer *>> argument_handles (num_devices);
704+
705+ // Distribute arguments across devices
706+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
707+ argument_handles[device_idx].reserve (num_args);
708+ for (int arg_idx = 0 ; arg_idx < num_args; ++arg_idx) {
709+ // Assuming op_args is a flat array of size num_devices * num_args
710+ // where arguments for each device are contiguous
711+ argument_handles[device_idx].push_back (op_args[device_idx * num_args + arg_idx]);
712+ }
713+ }
632714
633715 ExecuteOptions options;
634716
@@ -637,31 +719,43 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int num_args,
637719 options.non_donatable_input_indices .insert ((int )i);
638720 }
639721 options.untuple_result = true ;
722+
640723 std::optional<std::vector<FutureType>> returned_futures;
641724 auto results = MyValueOrThrow (
642725 exec->Execute (static_cast <absl::Span<const std::vector<PjRtBuffer *>>>(
643726 argument_handles),
644727 options, returned_futures));
645728
646- assert (results.size () == 1 );
729+ assert (results.size () == num_devices );
647730
648- if (results[0 ].size () != num_results) {
649- llvm::errs () << " results.size()=" << results.size ()
650- << " num_results=" << num_results << " \n " ;
731+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
732+ if (results[device_idx].size () != num_results) {
733+ llvm::errs () << " results[" << device_idx << " ].size()=" << results[device_idx].size ()
734+ << " num_results=" << num_results << " \n " ;
735+ }
736+ assert (results[device_idx].size () == num_results);
651737 }
652- assert (results[0 ].size () == num_results);
738+
739+ // Handle returned futures
653740 if (returned_futures) {
654741 *futures = true ;
655- assert (returned_futures->size () == num_results);
656- for (size_t i = 0 ; i < num_results; i++) {
657- future_results[i] = new FutureType ((*returned_futures)[i]);
742+ assert (returned_futures->size () == num_devices * num_results);
743+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
744+ for (int result_idx = 0 ; result_idx < num_results; ++result_idx) {
745+ int flat_index = device_idx * num_results + result_idx;
746+ future_results[flat_index] = new FutureType ((*returned_futures)[flat_index]);
747+ }
658748 }
659749 } else {
660750 *futures = false ;
661751 }
662752
663- for (size_t i = 0 ; i < num_results; i++) {
664- op_results[i] = results[0 ][i].release ();
753+ // Copy results into the output buffers
754+ for (int device_idx = 0 ; device_idx < num_devices; ++device_idx) {
755+ for (int result_idx = 0 ; result_idx < num_results; ++result_idx) {
756+ int flat_index = device_idx * num_results + result_idx;
757+ op_results[flat_index] = results[device_idx][result_idx].release ();
758+ }
665759 }
666760}
667761
0 commit comments