diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index fc6d837ecf..f1a3b4c929 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -73,6 +73,19 @@
 #include "xla/pjrt/pjrt_c_api_client.h"
 #include "xla/pjrt/pjrt_executable.h"
 
+// CPU collectives
+#include "xla/backends/cpu/collectives/mpi_collectives.h"
+#if defined(__linux__)
+#include "gloo/transport/tcp/attr.h"
+#include "gloo/transport/tcp/device.h"
+#include "xla/backends/cpu/collectives/gloo_collectives.h"
+#include "xla/backends/cpu/collectives/gloo_kv_store.h"
+#elif defined(__APPLE__)
+#include "gloo/transport/uv/device.h"
+#include "xla/backends/cpu/collectives/gloo_collectives.h"
+#include "xla/backends/cpu/collectives/gloo_kv_store.h"
+#endif // defined(__linux__)
+
 // shardy
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/integrations/c/attributes.h"
@@ -83,7 +96,6 @@
 // IFRT
 #include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/attribute_map.h"
-#include "xla/python/ifrt/basic_device_list.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/compiler.h"
 #include "xla/python/ifrt/device.h"
@@ -113,6 +125,7 @@
 #include "xla/python/pjrt_ifrt/pjrt_memory.h"
 #include "xla/python/pjrt_ifrt/pjrt_topology.h"
 #include "xla/python/pjrt_ifrt/pjrt_tuple.h"
+#include "xla/python/pjrt_ifrt/xla_compiler.h"
 #include "xla/python/pjrt_ifrt/xla_sharding.h"
 
 // IFRT - Proxy (RPC)
@@ -339,18 +352,26 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
   delete server;
 }
 
-extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
-                                     int num_nodes) {
+PjRtClient *MakeCPUClientInternal(
+    uint8_t asynchronous, int node_id,
+    std::optional<std::shared_ptr<xla::cpu::CpuCollectives>> collectives) {
   CpuClientOptions options;
-  // options.kv_store = "etcd";
+
   options.process_id = node_id;
-  // options.num_nodes = num_nodes;
-  // options.collectives = num_nodes;
   options.asynchronous = asynchronous != 0;
+
+  if (collectives.has_value())
+    options.collectives = collectives.value();
+
   auto client = MyValueOrThrow(GetTfrtCpuClient(options));
   return client.release();
 }
 
+extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
+  std::optional<std::shared_ptr<xla::cpu::CpuCollectives>> collectives;
+  return MakeCPUClientInternal(asynchronous, node_id, collectives);
+}
+
 // xla/python/xla.cc 390
 extern "C" PjRtClient *
 MakeGPUClient(int node_id, int num_nodes, int *allowed_devices,
@@ -1165,28 +1186,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
   return wrap(entryFn);
 }
 
-extern "C" HeldPjRtClient *
-pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) {
-  PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes);
-  return reactant::capture(std::shared_ptr<PjRtClient>(client));
-}
-
-extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
-    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
-    double memory_fraction, bool preallocate, const char *platform_name,
-    const char **error, void *distributed_runtime_client) {
-  PjRtClient *client = MakeGPUClient(
-      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
-      preallocate, platform_name, error, distributed_runtime_client);
-  return reactant::capture(std::shared_ptr<PjRtClient>(client));
-}
-
-extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
-                                                       const char **error) {
-  PjRtClient *client = MakeTPUClient(tpu_path, error);
-  return reactant::capture(std::shared_ptr<PjRtClient>(client));
-}
-
 extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; }
 
extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) { @@ -1263,11 +1262,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) { std::shared_ptr(buffer->ptr()->client())); } -extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) { - xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; - return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); -} - extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; } // generic version, but IFRT-PjRt backend only supports SingleDeviceSharding @@ -1302,21 +1296,21 @@ extern "C" HeldIfrtArray *ifrt_client_make_single_shard_array_from_host_buffer( } // all arrays are assumed to have same DType +// each process only provides arrays for its own addressable devices extern "C" HeldIfrtArray *ifrt_client_assemble_array_from_single_shards( - ifrt::Client *client, int ndims, const int64_t *c_shape, - HeldValue> *sharding, int narrays, - HeldIfrtArray **c_arrays, int c_semantics) { - auto shape = ifrt::Shape(absl::Span(c_shape, ndims)); + ifrt::Client *client, int32_t ndims, const int64_t *c_shape, + HeldValue> *sharding, int32_t narrays, + HeldIfrtArray **c_arrays, int32_t c_semantics) { + ifrt::Shape shape = ifrt::Shape(absl::Span(c_shape, ndims)); std::vector> arrays; for (int i = 0; i < narrays; i++) { arrays.emplace_back(c_arrays[i]->obj()); } - auto semantics = static_cast(c_semantics); return reactant::capture( MyValueOrThrow(client->AssembleArrayFromSingleDeviceArrays( - shape, sharding->obj(), - static_cast>>(arrays), - semantics))); + shape, sharding->obj(), absl::MakeSpan(arrays), + static_cast(c_semantics), + ifrt::SingleDeviceShardSemantics::kAddressableShards))); } // we should deprecate this because is IFRT-PjRt specific @@ -1328,48 +1322,16 @@ ifrt_pjrt_array_create(ifrt::PjRtClient *client, MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj())))); } -// TODO how do we compile for other backends? 
-extern "C" xla::ifrt::LoadedExecutable * -ifrt_pjrt_compile(ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id, - bool is_sharded, const int64_t *mesh_ids, - int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir) { - CompileOptions options = GenerateCompileOptions( - device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir); - - mlir::ModuleOp cmod_op = cast(*unwrap(cmod)); - if (is_sharded) { - // https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460 - auto status = xla::ExportShardyForHloRoundTrip(cmod_op); - if (!status.ok()) { - ReactantThrowError(status.ToString().c_str()); - } - } - - // TODO can't create LoadedExecutable from mlir::ModuleOp on IFRT-proxy - // backend - auto exec = MyValueOrThrow(xla::ifrt::PjRtLoadedExecutable::Create( - client, cmod_op, options, - std::vector>())); - return exec.release(); -} - // we might me interested in the `Compiler::Compile` method variant that accepts // `Topology` extern "C" xla::ifrt::LoadedExecutable * ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id, bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir) { - // TODO we need a `xla::ifrt::CompileOptions` but this is - // `xla::CompileOptions` auto options = std::make_unique( - // GenerateCompileOptions( - // device_id, - // is_sharded, - // mesh_ids, - // num_mesh_ids, - // xla_gpu_cuda_data_dir - // ) - // ); - auto options = std::make_unique(); + xla::CompileOptions compile_options = GenerateCompileOptions( + device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir); + auto options = std::make_unique( + xla::ifrt::XlaCompileOptions(compile_options)); mlir::ModuleOp cmod_op = cast(*unwrap(cmod)); if (is_sharded) { @@ -1381,7 +1343,7 @@ ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id, } auto program = - std::make_unique(xla::ifrt::IfrtIRProgram(cmod_op)); + std::make_unique(xla::ifrt::HloProgram(cmod_op)); auto compiler = client->GetDefaultCompiler(); return MyValueOrThrow( @@ -1396,45 +1358,6 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) { extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; } -extern "C" void ifrt_loaded_executable_execute( - ifrt::LoadedExecutable *exec, int num_args, - HeldValue> **op_args, - uint8_t *is_arg_donatable, int num_results, - HeldValue> **op_results, uint8_t *futures, - FutureType **status) { - std::vector> args; - for (int i = 0; i < num_args; i++) { - args.emplace_back(op_args[i]->obj()); - } - - ifrt::ExecuteOptions options; - for (size_t i = 0; i < num_args; i++) { - if (!is_arg_donatable[i]) { - options.non_donatable_input_indices.insert(static_cast(i)); - } - } - options.fill_status = true; - - auto result = MyValueOrThrow(exec->Execute( - static_cast>>(args), - options, /* devices */ std::nullopt)); - - if (result.outputs.size() != num_results) { - llvm::errs() << "Error: results.size()=" << result.outputs.size() - << " does not match num_results=" << num_results << "\n"; - std::abort(); // Terminate if the number of results is incorrect. 
-  }
-
-  // there is only 1 status and is valid because we set `options.fill_status =
-  // true`
-  *futures = true;
-  *status = new FutureType(result.status);
-
-  for (int i = 0; i < num_results; i++) {
-    op_results[i] = reactant::capture(result.outputs[i]);
-  }
-}
-
 // in principle, use ArrayCopySemantics::kAlwaysCopy (=0)
 extern "C" FutureType *
 ifrt_CopyArrayToHostBuffer(HeldIfrtArray *array, void *data,
@@ -1465,43 +1388,65 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
 
 #pragma region IfRtClient
 
-// right now only making it available for TPU
-// in the future, we would like this for CPU and GPU PjRt backends too
-extern "C" ifrt::proxy::GrpcServer *
-ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
-    const char *c_address, const char *tpu_path, const char **error) {
-  std::string address = c_address;
-
-  // taken from `MakeTPUClient`
-  std::string tpu_library_path;
-  if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) {
-    tpu_library_path = *path;
-  } else if (tpu_path) {
-    tpu_library_path = std::string(tpu_path);
-  } else {
-    *error = "Could not find TPU path";
-    return nullptr;
-  }
+// XXX: Bring back with the correct API
+// extern "C" ifrt::proxy::GrpcServer *
+// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
+//     const char *c_address, uint8_t asynchronous, int node_id) {
+//   std::string address = c_address;
+
+//   return MyValueOrThrow(
+//       ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
+//           address,
+//           [asynchronous,
+//            node_id]() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>>
+//           {
+//             auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
+//                 MakeCPUClient(asynchronous, node_id));
+//             return std::shared_ptr<xla::ifrt::Client>(
+//                 xla::ifrt::PjRtClient::Create(pjrt_client).release());
+//           }))
+//       .release();
+// }
 
-  const PJRT_Api *pluginLoad =
-      LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
-  if (pluginLoad == nullptr)
-    return nullptr;
-  auto tpu_status = InitializePjrtPlugin("tpu", error);
-  if (tpu_status)
-    return nullptr;
+// extern "C" ifrt::proxy::GrpcServer *
+// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
+//     int node_id, int num_nodes, int *allowed_devices, int
+//     num_allowed_devices, double memory_fraction, bool preallocate, const char
+//     *platform_name, const char **error) {
+//   return MyValueOrThrow(
+//       ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
+//           std::string(),
+//           [node_id, num_nodes, allowed_devices, num_allowed_devices,
+//            memory_fraction, preallocate, platform_name,
+//            error]() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
+//             auto pjrt_client =
+//                 std::shared_ptr<xla::PjRtClient>(MakeGPUClient(
+//                     node_id, num_nodes, allowed_devices,
+//                     num_allowed_devices, memory_fraction, preallocate,
+//                     platform_name, error));
+//             return std::shared_ptr<xla::ifrt::Client>(
+//                 xla::ifrt::PjRtClient::Create(pjrt_client).release());
+//           }))
+//       .release();
+// }
 
-  return MyValueOrThrow(
-      xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
-          address,
-          [](xla::ifrt::AttributeMap initialization_data)
-              -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
-            auto pjrt_client =
-                std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
-            return xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
-          }))
-      .release();
-}
+// extern "C" ifrt::proxy::GrpcServer *
+// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
+//     const char *c_address, const char *tpu_path, const char **error) {
+//   std::string address = c_address;
+//
+//   return MyValueOrThrow(
+//       xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
+//           address,
+//           [](xla::ifrt::AttributeMap initialization_data) ->
+//               absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
+//             auto pjrt_client =
+//                 std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
+//             return
+//                 xla::ifrt::PjRtClient::Create(std::move(pjrt_client));
+//           }))
+//       .release();
+// }
 
 extern "C" void ifrt_proxy_grpc_server_dtor(ifrt::proxy::GrpcServer *server) {
   delete server;
@@ -1530,29 +1475,143 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
       nullptr, // callback `on_connection_update`
   };
   return MyValueOrThrow(
-             ifrt::proxy::CreateClient(c_proxy_server_address, options))
+             ifrt::proxy::CreateClient(proxy_server_address, options))
      .release();
 }
 
-extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id,
-                                              int num_nodes) {
-  return ifrt_pjrt_make_client(
-      pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
+extern "C" ifrt::Client *ifrt_pjrt_make_client(
+    HeldPjRtClient *pjrt_client, int node_id, int num_nodes,
+    void *distributed_runtime_client, const char **error,
+    std::string key_prefix,
+    std::optional<std::shared_ptr<xla::KeyValueStoreInterface>> kv_store) {
+  ifrt::PjRtClient::CreateOptions options;
+  options.pjrt_client = pjrt_client->obj();
+
+  if (num_nodes > 1) {
+    if (distributed_runtime_client == nullptr) {
+      *error =
+          "`distributed_runtime_client` must be non-null if `num_nodes` > 1";
+      return nullptr;
+    }
+    if (kv_store.has_value()) {
+      options.kv_store = kv_store.value();
+    } else {
+      auto typed_distributed_runtime_client = static_cast<
+          HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
+          distributed_runtime_client);
+      options.kv_store = GetDistributedKeyValueStore(
+          typed_distributed_runtime_client->obj(), key_prefix);
+    }
+  }
+
+  options.process_id = node_id;
+  options.num_processes = num_nodes;
+
+  return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
+}
+
+extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(
+    uint8_t asynchronous, int node_id,
+    std::optional<std::shared_ptr<xla::cpu::CpuCollectives>> collectives) {
+  PjRtClient *client =
+      MakeCPUClientInternal(asynchronous, node_id, collectives);
+  return reactant::capture(std::shared_ptr<PjRtClient>(client));
 }
 
+const char *const kMpiTrampolineLibEnv = "MPITRAMPOLINE_LIB";
+
 extern "C" ifrt::Client *
-ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices,
-                     int num_allowed_devices, double memory_fraction,
-                     bool preallocate, const char *platform_name,
-                     const char **error, void *distributed_runtime_client) {
-  return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
+ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
+                          void *distributed_runtime_client,
+                          const char **error) {
+  std::optional<std::shared_ptr<xla::cpu::CpuCollectives>> collectives;
+  std::optional<std::shared_ptr<xla::KeyValueStoreInterface>> kv_store;
+
+  if (distributed_runtime_client != nullptr) {
+    auto mpi_trampoline_path = llvm::sys::Process::GetEnv(kMpiTrampolineLibEnv);
+    if (mpi_trampoline_path) {
+      // Use MPI
+      // TODO: How do we Finalize??
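+      // MPITRAMPOLINE_LIB is set, so use MPI-backed collectives: the
+      // MpiCollectives instance created below is Init()-ed here and handed to
+      // CpuClientOptions::collectives via pjrt_make_cpu_client_shared /
+      // MakeCPUClientInternal.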
+      auto mpi_collectives = std::make_shared<xla::cpu::MpiCollectives>();
+      collectives = mpi_collectives;
+      static_cast<xla::cpu::MpiCollectives *>(mpi_collectives.get())->Init();
+    } else {
+      // Use Gloo
+      auto typed_distributed_runtime_client = static_cast<
+          HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
+          distributed_runtime_client);
+      kv_store =
+          GetDistributedKeyValueStore(typed_distributed_runtime_client->obj(),
+                                      /*key_prefix=*/"cpu:");
+      auto gloo_kv_store =
+          std::make_unique<xla::cpu::GlooKeyValueStore>(kv_store.value());
+#if defined(__linux__)
+      auto tcp_attrs = gloo::transport::tcp::attr();
+      auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs);
+      collectives = std::make_shared<xla::cpu::GlooCollectives>(
+          std::move(gloo_kv_store), std::move(tcp_device));
+#elif defined(__APPLE__)
+      auto uv_attrs = gloo::transport::uv::attr();
+      auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs);
+      collectives = std::make_shared<xla::cpu::GlooCollectives>(
+          std::move(gloo_kv_store), std::move(uv_device));
+#else
+      ReactantThrowError(
+          "Gloo TCP Collectives only implemented for linux and macos");
+#endif
+    }
+  }
+
+  HeldPjRtClient *pjrt_client =
+      pjrt_make_cpu_client_shared(asynchronous, node_id, collectives);
+  if (pjrt_client == nullptr)
+    return nullptr;
+  return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
+                               distributed_runtime_client, error, "cpu",
+                               kv_store);
+}
+
+extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
+    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
+    double memory_fraction, bool preallocate, const char *platform_name,
+    const char **error, void *distributed_runtime_client) {
+  PjRtClient *client = MakeGPUClient(
       node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
-      preallocate, platform_name, error, distributed_runtime_client));
+      preallocate, platform_name, error, distributed_runtime_client);
+  return reactant::capture(std::shared_ptr<PjRtClient>(client));
 }
 
-extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path,
-                                              const char **error) {
-  return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
+extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
+    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
+    double memory_fraction, bool preallocate, const char *platform_name,
+    const char **error, void *distributed_runtime_client) {
+  HeldPjRtClient *pjrt_client = pjrt_make_gpu_client_shared(
+      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
+      preallocate, platform_name, error, distributed_runtime_client);
+  if (pjrt_client == nullptr)
+    return nullptr;
+  std::optional<std::shared_ptr<xla::KeyValueStoreInterface>> kv_store;
+  return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
+                               distributed_runtime_client, error, "gpu",
+                               kv_store);
+}
+
+extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
+                                                       const char **error) {
+  PjRtClient *client = MakeTPUClient(tpu_path, error);
+  return reactant::capture(std::shared_ptr<PjRtClient>(client));
+}
+
+extern "C" ifrt::Client *
+ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id,
+                          int num_nodes, void *distributed_runtime_client) {
+  HeldPjRtClient *pjrt_client = pjrt_make_tpu_client_shared(tpu_path, error);
+  if (pjrt_client == nullptr)
+    return nullptr;
+  std::optional<std::shared_ptr<xla::KeyValueStoreInterface>> kv_store;
+  return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
+                               distributed_runtime_client, error, "tpu",
+                               kv_store);
 }
 
 extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
 
@@ -1633,26 +1692,14 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
   return device->client();
 }
 
-extern "C" HeldValue<tsl::RCReference<ifrt::BasicDeviceList>> *
-ifrt_CreateBasicDeviceListFromDevices(ifrt::Device **device_list,
-                                      int32_t num_devices) {
-  absl::Span<ifrt::Device *const> devices(device_list, num_devices);
-  return reactant::capture(ifrt::BasicDeviceList::Create(devices));
-}
-
-extern "C" const char *ifrt_BasicDeviceListToString(
-    HeldValue<tsl::RCReference<ifrt::BasicDeviceList>> *device_list) {
-  return cstr_from_string(device_list->obj()->DebugString());
+extern "C" bool ifrt_DeviceIsAddressable(ifrt::Device *device) {
+  return device->IsAddressable();
 }
 
-extern "C" int ifrt_BasicDeviceListSize(
-    HeldValue<tsl::RCReference<ifrt::BasicDeviceList>> *device_list) {
-  return device_list->obj()->size();
-}
-
-extern "C" ifrt::Device *const ifrt_BasicDeviceListGetDevice(
-    HeldValue<tsl::RCReference<ifrt::BasicDeviceList>> *device_list, int32_t index) {
-  return device_list->obj()->devices()[index];
+tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices(
+    ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
+  absl::Span<ifrt::Device *const> devices(device_list, num_devices);
+  return client->MakeDeviceList(devices);
 }
 
 extern "C" ifrt::Memory *ifrt_DeviceGetDefaultMemory(ifrt::Device *device) {
@@ -1848,11 +1895,16 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
   return cstr_from_string(hlo_sharding->ToString(true));
 }
 
+extern "C" ifrt::MemoryKind *ifrt_memory_kind_from_string(const char *c_str) {
+  return new ifrt::MemoryKind(std::string(c_str));
+}
+
 extern "C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding(
-    HeldValue<tsl::RCReference<ifrt::BasicDeviceList>> *device_list,
+    ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices,
     ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
-  return ifrt::HloSharding::Create(device_list->obj(), *memory_kind,
-                                   *xla_hlo_sharding)
+  return ifrt::HloSharding::Create(
+             ifrt_CreateDeviceListFromDevices(client, device_list, num_devices),
+             *memory_kind, *xla_hlo_sharding)
       .release();
 }
 
@@ -1868,9 +1920,133 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
   return cstr_from_string(hlo_sharding->DebugString());
 }
 
+extern "C" ifrt::HloSharding *ifrt_sharding_to_ifrt_hlo_sharding(
+    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding) {
+  const ifrt::Sharding *val = sharding->obj().get();
+  if (!llvm::isa<ifrt::HloSharding>(val))
+    ReactantThrowError("Expected a HloSharding");
+  return new ifrt::HloSharding(*llvm::dyn_cast<ifrt::HloSharding>(val));
+}
+
+extern "C" void
+free_ifrt_sharding(HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding) {
+  delete sharding;
+}
+
+extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
+ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
+  return reactant::capture(std::shared_ptr<const ifrt::Sharding>(hlo_sharding));
+}
+
+extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
+ifrt_sharding_from_hlo_sharding(ifrt::Client *client,
+                                ifrt::Device **device_list, int32_t num_devices,
+                                ifrt::MemoryKind *memory_kind,
+                                xla::HloSharding *xla_hlo_sharding) {
+  return ifrt_sharding_from_ifrt_hlo_sharding(
+      ifrt_hlo_sharding_from_xla_hlo_sharding(client, device_list, num_devices,
+                                              memory_kind, xla_hlo_sharding));
+}
+
+extern "C" bool ifrt_sharding_is_single_device_sharding(
+    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding) {
+  return llvm::isa<ifrt::SingleDeviceSharding>(sharding->obj().get());
+}
+
+extern "C" bool ifrt_sharding_is_fully_replicated(
+    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding) {
+  return sharding->obj()->IsFullyReplicated();
+}
+
+extern "C" const char *ifrt_sharding_to_string(
+    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding) {
+  return cstr_from_string(sharding->obj()->DebugString());
+}
+
+extern "C" int32_t ifrt_sharding_devices_size(
+    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding) {
+  return sharding->obj()->devices()->size();
+}
+
+extern "C" void ifrt_sharding_to_device_list(
+    HeldValue<std::shared_ptr<const ifrt::Sharding>> *sharding,
+    ifrt::Device **devices) {
+  auto device_list = sharding->obj()->devices()->devices();
+  for (int i = 0; i < device_list.size(); i++) {
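+    // `devices` is a caller-allocated buffer that must hold at least
+    // ifrt_sharding_devices_size(sharding) entries; the raw device pointers
+    // are copied into it here.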
+    devices[i] = device_list[i];
+  }
+}
+
+#pragma endregion
+
+typedef ifrt::Future<> IfRtFutureType;
+
+extern "C" void ifrt_free_future(IfRtFutureType *Future) { delete Future; }
+
+extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) {
+  return Future->IsReady();
+}
+
+extern "C" void ifrt_future_await(IfRtFutureType *Future) { Future->Await(); }
+
+#pragma region IfRtArray
+
+extern "C" void ifrt_free_array(HeldIfrtArray *array) { delete array; }
+
+extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) {
+  auto dims =
+      static_cast<absl::Span<const int64_t>>(array->obj()->shape().dims());
+  int64_t *dims_ptr = new int64_t[dims.size()];
+  std::copy(dims.begin(), dims.end(), dims_ptr);
+  return dims_ptr;
+}
+
+extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) {
+  return array->obj()->shape().dims().size();
+}
+
+extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) {
+  return array->obj()->dtype();
+}
+
+extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) {
+  return array->obj()->client();
+}
+
+extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
+ifrt_array_to_sharding(HeldIfrtArray *array) {
+  return reactant::capture(array->obj()->shared_ptr_sharding());
+}
+
+extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
+                                               void *data) {
+  std::optional<absl::Span<const int64_t>> byte_strides;
+  auto future = array->obj()->CopyToHostBuffer(
+      data, byte_strides, static_cast<ifrt::ArrayCopySemantics>(0));
+  future.Await();
+  return;
+}
+
+extern "C" HeldIfrtArray **ifrt_array_disassemble_into_single_device_arrays(
+    HeldIfrtArray *array, int32_t c_semantics,
+    int32_t c_single_device_shard_semantics, int32_t *narrays) {
+  std::vector<tsl::RCReference<ifrt::Array>> single_device_arrays =
+      MyValueOrThrow(array->obj()->DisassembleIntoSingleDeviceArrays(
+          static_cast<ifrt::ArrayCopySemantics>(c_semantics),
+          static_cast<ifrt::SingleDeviceShardSemantics>(
+              c_single_device_shard_semantics)));
+
+  *narrays = single_device_arrays.size();
+  HeldIfrtArray **arrays = new HeldIfrtArray *[single_device_arrays.size()];
+  for (int i = 0; i < single_device_arrays.size(); i++) {
+    arrays[i] = reactant::capture(std::move(single_device_arrays[i]));
+  }
+  return arrays;
+}
+
 #pragma endregion
 
-#pragma region PjRtDistributed
+#pragma region xla::Distributed
 
 extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
 GetDistributedRuntimeClient(char *c_address, int32_t node_id,
@@ -1985,3 +2161,118 @@ hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding,
 }
 
 #pragma endregion
+
+#pragma region ifrt::LoadedExecutable
+
+extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) {
+  delete exec;
+}
+
+extern "C" void ifrt_loaded_executable_execute(
+    ifrt::LoadedExecutable *exec, int num_args,
+    HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
+    uint8_t *is_arg_donatable, int num_results,
+    HeldValue<tsl::RCReference<ifrt::Array>> **op_results, uint8_t *futures,
+    FutureType **status) {
+  std::vector<tsl::RCReference<ifrt::Array>> args;
+  for (int i = 0; i < num_args; i++) {
+    args.emplace_back(op_args[i]->obj());
+  }
+
+  ifrt::ExecuteOptions options;
+  for (size_t i = 0; i < num_args; i++) {
+    if (!is_arg_donatable[i]) {
+      options.non_donatable_input_indices.insert(static_cast<int>(i));
+    }
+  }
+  options.fill_status = true;
+
+  auto result = MyValueOrThrow(exec->Execute(
+      static_cast<absl::Span<tsl::RCReference<ifrt::Array>>>(args),
+      options, /* devices */ std::nullopt));
+
+  if (result.outputs.size() != num_results) {
+    llvm::errs() << "Error: results.size()=" << result.outputs.size()
+                 << " does not match num_results=" << num_results << "\n";
+    std::abort(); // Terminate if the number of results is incorrect.
+  }
+
+  // there is only 1 status and is valid because we set `options.fill_status =
+  // true`
+  *futures = true;
+  *status = new FutureType(result.status);
+
+  for (int i = 0; i < num_results; i++) {
+    op_results[i] = reactant::capture(result.outputs[i]);
+  }
+}
+
+extern "C" ifrt::Client *
+ifrt_loaded_executable_client(ifrt::LoadedExecutable *exec) {
+  return exec->client();
+}
+
+extern "C" void
+ifrt_loaded_executable_get_parameter_shardings(ifrt::LoadedExecutable *exec,
+                                               xla::OpSharding **op_shardings,
+                                               int32_t num_op_shardings) {
+  std::optional<std::vector<xla::OpSharding>> shardings =
+      exec->GetParameterShardings();
+  if (!shardings.has_value()) {
+    ReactantThrowError(
+        "No sharding found for the parameters of the loaded executable");
+  }
+
+  std::vector<xla::OpSharding> hlo_op_shardings = shardings.value();
+  if (num_op_shardings != hlo_op_shardings.size()) {
+    ReactantThrowError(("Expected " + std::to_string(num_op_shardings) +
+                        " shardings, got " +
+                        std::to_string(hlo_op_shardings.size()))
+                           .c_str());
+  }
+
+  for (int32_t i = 0; i < num_op_shardings; i++) {
+    op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]);
+  }
+}
+
+extern "C" void
+ifrt_loaded_executable_get_output_shardings(ifrt::LoadedExecutable *exec,
+                                            xla::OpSharding **op_shardings,
+                                            int32_t num_op_shardings) {
+  std::optional<std::vector<xla::OpSharding>> shardings =
+      exec->GetOutputShardings();
+  if (!shardings.has_value()) {
+    ReactantThrowError(
+        "No sharding found for the output of the loaded executable");
+  }
+
+  std::vector<xla::OpSharding> hlo_op_shardings = shardings.value();
+  if (num_op_shardings != hlo_op_shardings.size()) {
+    ReactantThrowError(("Expected " + std::to_string(num_op_shardings) +
+                        " shardings, got " +
+                        std::to_string(hlo_op_shardings.size()))
+                           .c_str());
+  }
+
+  for (int32_t i = 0; i < num_op_shardings; i++) {
+    op_shardings[i] = new xla::OpSharding(hlo_op_shardings[i]);
+  }
+}
+
+extern "C" void
+ifrt_loaded_executable_get_hlo_modules(ifrt::LoadedExecutable *exec,
+                                       void **hlo_modules, int32_t *nmodules) {
+  auto hlo_modules_vec = MyValueOrThrow(exec->GetHloModules());
+  *nmodules = hlo_modules_vec.size();
+  for (int32_t i = 0; i < *nmodules; i++) {
+    hlo_modules[i] = reactant::capture(hlo_modules_vec[i]);
+  }
+}
+
+extern "C" int32_t
+ifrt_loaded_executable_num_devices(ifrt::LoadedExecutable *exec) {
+  return static_cast<int32_t>(exec->num_devices());
+}
+
+#pragma endregion
diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 50f6395096..4d661f165f 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -602,6 +602,7 @@ cc_library(
         "@xla//xla/backends/profiler/cpu:metadata_utils",
         "@xla//xla/backends/profiler/tpu:tpu_tracer",
         "@xla//xla/python:profiler_utils",
+        "@xla//xla/backends/cpu/collectives:mpi_collectives",
 
         "@tsl//tsl/platform:env_impl",
         "@xla//xla/stream_executor:stream_executor_impl",
@@ -632,7 +633,20 @@ cc_library(
     }) + if_rocm([
         "@xla//xla/service/gpu:amdgpu_compiler",
         "@xla//xla/backends/profiler/gpu:device_tracer",
-    ]),
+    ]) + select({
+        # gloo tcp transport only builds on linux
+        "@xla//xla/tsl:macos": [
+            "@xla//xla/backends/cpu/collectives:gloo_collectives",
+            "@xla//xla/backends/cpu/collectives:gloo_kv_store",
+            "@gloo//:transport_uv",
+        ],
+        "@xla//xla/tsl:windows": [],
+        "//conditions:default": [
+            "@xla//xla/backends/cpu/collectives:gloo_collectives",
+            "@xla//xla/backends/cpu/collectives:gloo_kv_store",
+            "@gloo//:transport_tcp",
+        ],
+    }),
 )
 
 # cc_shared_library(