diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8fec3689c2..3d1c96d7b4 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -79,19 +79,20 @@ #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/index.h" #include "xla/python/ifrt/index_domain.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/topology.h" #include "xla/python/ifrt/tuple.h" #include "xla/python/ifrt/value.h" -#include "xla/python/ifrt/ir/ifrt_ir_program.h" // IFRT - PJRT #include "xla/python/pjrt_ifrt/pjrt_array.h" @@ -104,6 +105,7 @@ #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/pjrt_tuple.h" +#include "xla/python/pjrt_ifrt/xla_sharding.h" // IFRT - Proxy (RPC) #include "xla/python/ifrt_proxy/client/registry.h" @@ -421,6 +423,10 @@ extern "C" const char *ClientGetPlatformName(PjRtClient *client) { return cstr_from_string(client->platform_name()); } +extern "C" const char *DeviceGetKind(PjRtDevice *device) { + return cstr_from_string(device->device_kind()); +} + // To keep in sync with JLAllocatorStats in src/XLA.jl struct JLAllocatorStats { int64_t num_allocs; @@ -660,9 +666,10 @@ struct JLOpSharding { bool is_shard_group; int64_t shard_group_id; int32_t shard_group_type; + const void *op_sharding; }; -void OpShardingToJLOpSharding(const xla::OpSharding &op_sharding, +void OpShardingToJLOpSharding(const xla::OpSharding op_sharding, JLOpSharding *jl_op_sharding) { jl_op_sharding->type = op_sharding.type(); jl_op_sharding->replicate_on_last_tile_dim = @@ -737,58 +744,8 @@ void OpShardingToJLOpSharding(const xla::OpSharding &op_sharding, jl_op_sharding->is_shard_group = op_sharding.is_shard_group(); jl_op_sharding->shard_group_id = op_sharding.shard_group_id(); jl_op_sharding->shard_group_type = op_sharding.shard_group_type(); -} - -xla::OpSharding JLOpShardingToOpSharding(const JLOpSharding &jl_op_sharding) { - xla::OpSharding op_sharding; - - op_sharding.set_type(static_cast(jl_op_sharding.type)); - op_sharding.set_replicate_on_last_tile_dim( - jl_op_sharding.replicate_on_last_tile_dim); - - xla::ShapeProto *mutable_shape_proto = op_sharding.mutable_tile_shape(); - - for (int i = 0; i < jl_op_sharding.n_tile_dimensions; i++) { - mutable_shape_proto->add_dimensions(jl_op_sharding.tile_dimensions[i]); - } - - if (jl_op_sharding.n_layout_minor_to_major > 0) { - auto *mutable_layout = mutable_shape_proto->mutable_layout(); - for (int i = 0; i < jl_op_sharding.n_layout_minor_to_major; i++) { - mutable_layout->add_minor_to_major( - jl_op_sharding.layout_minor_to_major[i]); - } - } - - for (int i = 0; i < jl_op_sharding.n_tile_dimensions; i++) { - op_sharding.add_last_tile_dims( - static_cast(jl_op_sharding.last_tile_dims[i])); - } - - for (int i = 0; i < jl_op_sharding.n_tile_assignment_dimensions; i++) { - op_sharding.add_tile_assignment_dimensions( - jl_op_sharding.tile_assignment_dimensions[i]); - } - - for (int i = 0; i < jl_op_sharding.n_tile_assignment_devices; i++) { - op_sharding.add_tile_assignment_devices( - jl_op_sharding.tile_assignment_devices[i]); - } - - for (int i = 0; i < 
jl_op_sharding.n_iota_reshape_dims; i++) { - op_sharding.add_iota_reshape_dims(jl_op_sharding.iota_reshape_dims[i]); - } - - for (int i = 0; i < jl_op_sharding.n_iota_transpose_perm; i++) { - op_sharding.add_iota_transpose_perm(jl_op_sharding.iota_transpose_perm[i]); - } - op_sharding.set_is_shard_group(jl_op_sharding.is_shard_group); - op_sharding.set_shard_group_id(jl_op_sharding.shard_group_id); - op_sharding.set_shard_group_type(static_cast( - jl_op_sharding.shard_group_type)); - - return op_sharding; + jl_op_sharding->op_sharding = new xla::OpSharding(std::move(op_sharding)); } typedef PjRtFuture<> FutureType; @@ -800,7 +757,6 @@ extern "C" uint8_t FutureIsReady(FutureType *Future) { extern "C" void FutureAwait(FutureType *Future) { Future->Await(); } -// This is used by both the PjRt and IFRT clients xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids, @@ -830,7 +786,7 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded, for (int64_t i = 0; i < num_mesh_ids; ++i) { int64_t mesh_id = mesh_ids[i]; assert(mesh_id >= 0); - device_assignment(0, mesh_id) = i; + device_assignment(0, i) = mesh_id; } options.executable_build_options.set_device_assignment(device_assignment); @@ -989,31 +945,22 @@ void PrintPjRtBuffer(PjRtBuffer *buffer) { } extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, - PjRtBuffer **op_args, const int64_t *mesh_ids, - int64_t num_mesh_ids, uint8_t *is_arg_donatable, + PjRtBuffer **op_args, uint8_t *is_arg_donatable, int num_results, PjRtBuffer **op_results, uint8_t *futures, FutureType **future_results) { - // Ensure argument_handles is structured as num_mesh_ids x num_args - std::vector> argument_handles(num_mesh_ids); - int num_args = op_args_len / num_mesh_ids; - - // Distribute arguments across devices - for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) { - int64_t mesh_id = mesh_ids[device_idx]; + xla::DeviceAssignment device_assignment = exec->device_assignment(); + int num_devices = device_assignment.computation_count(); - // Validate mesh_id - if (mesh_id < 0 || mesh_id >= num_mesh_ids) { - ReactantThrowError(("Invalid mesh_id " + std::to_string(mesh_id) + - " at device_idx " + std::to_string(device_idx)) - .c_str()); - } + // Ensure argument_handles is structured as num_devices x num_args + std::vector> argument_handles(num_devices); + int num_args = op_args_len / num_devices; + // Distribute arguments across devices + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { argument_handles[device_idx].reserve(num_args); for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) { - // Assuming op_args is a flat array of size num_devices * num_args - // where arguments for each device are contiguous argument_handles[device_idx].push_back( - op_args[mesh_id * num_args + arg_idx]); + op_args[device_idx * num_args + arg_idx]); } } @@ -1033,19 +980,20 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, argument_handles), options, returned_futures)); - if (results.size() != num_mesh_ids) { + if (results.size() != num_devices) { ReactantThrowError((" results.size()=" + std::to_string(results.size()) + - " num_mesh_ids=" + std::to_string(num_mesh_ids) + "\n") + " num_devices=" + std::to_string(num_devices) + "\n") .c_str()); } - for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) { - int64_t mesh_id = mesh_ids[device_idx]; - if (results[mesh_id].size() != num_results) { - 
ReactantThrowError((" results[" + std::to_string(mesh_id) + "].size()=" + - std::to_string(results[mesh_id].size()) + - " num_results=" + std::to_string(num_results) + "\n") - .c_str()); + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { + // Remove mesh_id lookup since we're using device_idx ordering + if (results[device_idx].size() != num_results) { + ReactantThrowError( + (" results[" + std::to_string(device_idx) + + "].size()=" + std::to_string(results[device_idx].size()) + + " num_results=" + std::to_string(num_results) + "\n") + .c_str()); } } @@ -1053,20 +1001,18 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable *exec, int op_args_len, auto future_val = returned_futures.has_value(); *futures = future_val; if (future_val) { - if (returned_futures->size() != num_mesh_ids) { + if (returned_futures->size() != num_devices) { ReactantThrowError((" returned_futures->size()=" + std::to_string(returned_futures->size()) + - " num_mesh_ids=" + std::to_string(num_mesh_ids) + - "\n") + " num_devices=" + std::to_string(num_devices) + "\n") .c_str()); } } // Copy results into the output buffers - for (int device_idx = 0; device_idx < num_mesh_ids; ++device_idx) { - int64_t mesh_id = mesh_ids[device_idx]; + for (int device_idx = 0; device_idx < num_devices; ++device_idx) { for (int result_idx = 0; result_idx < num_results; ++result_idx) { - int flat_index = mesh_id * num_results + result_idx; + int flat_index = device_idx * num_results + result_idx; op_results[flat_index] = results[device_idx][result_idx].release(); if (future_val) { future_results[flat_index] = @@ -1290,86 +1236,54 @@ using HeldPjRtClient = HeldValue>; using HeldPjRtBuffer = HeldValue>; using HeldIfrtArray = HeldValue>; -// deprecated -// extern "C" HeldPjRtClient * reactant_hold_pjrtclient(xla::PjRtClient *client) { -// return reactant::capture(std::shared_ptr(client)); -// } - -extern "C" HeldPjRtClient * pjrt_make_cpu_client_shared( - uint8_t asynchronous, - int node_id, - int num_nodes) -{ - PjRtClient* client = MakeCPUClient(asynchronous, node_id, num_nodes); +extern "C" HeldPjRtClient * +pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) { + PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes); return reactant::capture(std::shared_ptr(client)); } -extern "C" HeldPjRtClient* pjrt_make_gpu_client_shared( - int node_id, - int num_nodes, - int* allowed_devices, - int num_allowed_devices, - double memory_fraction, - bool preallocate, - const char* platform_name, - const char** error) -{ - PjRtClient* client = MakeGPUClient( - node_id, - num_nodes, - allowed_devices, - num_allowed_devices, - memory_fraction, - preallocate, - platform_name, - error - ); +extern "C" HeldPjRtClient * +pjrt_make_gpu_client_shared(int node_id, int num_nodes, int *allowed_devices, + int num_allowed_devices, double memory_fraction, + bool preallocate, const char *platform_name, + const char **error) { + PjRtClient *client = + MakeGPUClient(node_id, num_nodes, allowed_devices, num_allowed_devices, + memory_fraction, preallocate, platform_name, error); return reactant::capture(std::shared_ptr(client)); } -extern "C" HeldPjRtClient* pjrt_make_tpu_client_shared( - const char* tpu_path, - const char** error -) { - PjRtClient* client = MakeTPUClient(tpu_path, error); +extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path, + const char **error) { + PjRtClient *client = MakeTPUClient(tpu_path, error); return reactant::capture(std::shared_ptr(client)); } -extern "C" void 
pjrt_client_dtor(HeldPjRtClient *client) { - delete client; -} +extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; } -extern "C" int pjrt_client_num_devices(HeldPjRtClient* client) { +extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) { return client->ptr()->device_count(); } -extern "C" int pjrt_client_num_addressable_devices( - HeldPjRtClient* client -) { +extern "C" int pjrt_client_num_addressable_devices(HeldPjRtClient *client) { return client->ptr()->addressable_device_count(); } -extern "C" int pjrt_client_pid(HeldPjRtClient* client) { +extern "C" int pjrt_client_pid(HeldPjRtClient *client) { return client->ptr()->process_index(); } -extern "C" PjRtDevice* pjrt_client_get_device( - HeldPjRtClient* client, - int device_id -) { +extern "C" PjRtDevice *pjrt_client_get_device(HeldPjRtClient *client, + int device_id) { return ClientGetDevice(client->ptr(), device_id); } -extern "C" PjRtDevice* pjrt_client_get_addressable_device( - HeldPjRtClient* client, - int device_id -) { +extern "C" PjRtDevice * +pjrt_client_get_addressable_device(HeldPjRtClient *client, int device_id) { return ClientGetAddressableDevice(client->ptr(), device_id); } -extern "C" const char* pjrt_client_platform_name( - HeldPjRtClient* client -) { +extern "C" const char *pjrt_client_platform_name(HeldPjRtClient *client) { return ClientGetPlatformName(client->ptr()); } @@ -1379,71 +1293,49 @@ extern "C" const char* pjrt_client_platform_name( // return reactant::capture(std::shared_ptr(buffer)); // } -extern "C" HeldPjRtBuffer* -pjrt_buffer_from_host( - HeldPjRtClient* client, - void* data, - uint64_t ptype, - size_t dim, - int64_t* cshape, - PjRtDevice* device -) { - PjRtBuffer* buffer = ArrayFromHostBuffer( - client->ptr(), - data, - ptype, - dim, - cshape, - device - ); +extern "C" HeldPjRtBuffer *pjrt_buffer_from_host(HeldPjRtClient *client, + void *data, uint64_t ptype, + size_t dim, int64_t *cshape, + PjRtDevice *device) { + PjRtBuffer *buffer = + ArrayFromHostBuffer(client->ptr(), data, ptype, dim, cshape, device); return reactant::capture(std::shared_ptr(buffer)); } -extern "C" void pjrt_buffer_dtor(HeldPjRtBuffer *buffer) { - delete buffer; -} +extern "C" void pjrt_buffer_dtor(HeldPjRtBuffer *buffer) { delete buffer; } -extern "C" void* pjrt_buffer_unsafe_buffer_pointer( - HeldPjRtBuffer* buffer) -{ +extern "C" void *pjrt_buffer_unsafe_buffer_pointer(HeldPjRtBuffer *buffer) { return UnsafeBufferPointer(buffer->ptr()); } -extern "C" bool pjrt_buffer_is_on_cpu(HeldPjRtBuffer* buffer) { +extern "C" bool pjrt_buffer_is_on_cpu(HeldPjRtBuffer *buffer) { return buffer->ptr()->IsOnCpu(); } -extern "C" HeldPjRtBuffer* pjrt_buffer_copy_to_device( - HeldPjRtBuffer* buffer, - PjRtDevice* dst_device) -{ - PjRtBuffer* ret = CopyBufferToDevice(buffer->ptr(), dst_device); +extern "C" HeldPjRtBuffer *pjrt_buffer_copy_to_device(HeldPjRtBuffer *buffer, + PjRtDevice *dst_device) { + PjRtBuffer *ret = CopyBufferToDevice(buffer->ptr(), dst_device); return reactant::capture(std::shared_ptr(ret)); } -extern "C" void pjrt_buffer_to_host(HeldPjRtBuffer* buffer, void* data) -{ +extern "C" void pjrt_buffer_to_host(HeldPjRtBuffer *buffer, void *data) { BufferToHost(buffer->ptr(), data); } -extern "C" void pjrt_buffer_print(HeldPjRtBuffer* buffer) { +extern "C" void pjrt_buffer_print(HeldPjRtBuffer *buffer) { PrintPjRtBuffer(buffer->ptr()); } -extern "C" PjRtDevice* pjrt_buffer_get_device(HeldPjRtBuffer* buffer) { +extern "C" PjRtDevice *pjrt_buffer_get_device(HeldPjRtBuffer *buffer) { return 
buffer->ptr()->device(); } -extern "C" HeldPjRtClient* pjrt_buffer_get_client( - HeldPjRtBuffer* buffer -) { +extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) { return reactant::capture( - std::shared_ptr(buffer->ptr()->client()) - ); + std::shared_ptr(buffer->ptr()->client())); } -extern "C" ifrt::Client* ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) -{ +extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) { xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()}; return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release(); } @@ -1452,142 +1344,67 @@ extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; } // generic version, but IFRT-PjRt backend only supports SingleDeviceSharding // and FullyReplicated. use `ifrt_pjrt_array_create` if using IFRT-PjRt. -extern "C" HeldIfrtArray* ifrt_client_make_array_from_host_buffer( - ifrt::Client* client, - void* data, - int dtype_kind, // int - int ndims, - const int64_t* c_shape, - HeldValue>* sharding, - int c_semantics -) { +extern "C" HeldIfrtArray *ifrt_client_make_array_from_host_buffer( + ifrt::Client *client, void *data, + int dtype_kind, // int + int ndims, const int64_t *c_shape, + HeldValue> *sharding, + int c_semantics) { auto dtype = ifrt::DType(static_cast(dtype_kind)); auto shape = ifrt::Shape(absl::Span(c_shape, ndims)); return reactant::capture(MyValueOrThrow(client->MakeArrayFromHostBuffer( - data, - dtype, - shape, - std::nullopt, // byte_strides - sharding->obj(), - static_cast(c_semantics), - []{} // on_done_with_host_buffer - ))); -} - -extern "C" HeldIfrtArray* ifrt_client_make_single_shard_array_from_host_buffer( - ifrt::Client* client, - void* data, - int dtype_kind, // int - int ndims, - const int64_t* c_shape, - int c_semantics, - ifrt::Device* device, - const char* mem_kind -) { + data, dtype, shape, + std::nullopt, // byte_strides + sharding->obj(), + static_cast(c_semantics), + [] {} // on_done_with_host_buffer + ))); +} + +extern "C" HeldIfrtArray *ifrt_client_make_single_shard_array_from_host_buffer( + ifrt::Client *client, void *data, + int dtype_kind, // int + int ndims, const int64_t *c_shape, int c_semantics, ifrt::Device *device, + const char *mem_kind) { auto memory_kind = ifrt::MemoryKind(std::string(mem_kind)); auto sharding = reactant::capture(std::shared_ptr( - ifrt::SingleDeviceSharding::Create(device, memory_kind).release() - )); + ifrt::SingleDeviceSharding::Create(device, memory_kind).release())); return ifrt_client_make_array_from_host_buffer( - client, - data, - dtype_kind, - ndims, - c_shape, - sharding, - c_semantics - ); + client, data, dtype_kind, ndims, c_shape, sharding, c_semantics); } // all arrays are assumed to have same DType -extern "C" HeldIfrtArray* ifrt_client_assemble_array_from_single_shards( - ifrt::Client* client, - int ndims, - const int64_t* c_shape, - HeldValue>* sharding, - int narrays, - HeldIfrtArray** c_arrays, - int c_semantics -) { +extern "C" HeldIfrtArray *ifrt_client_assemble_array_from_single_shards( + ifrt::Client *client, int ndims, const int64_t *c_shape, + HeldValue> *sharding, int narrays, + HeldIfrtArray **c_arrays, int c_semantics) { auto shape = ifrt::Shape(absl::Span(c_shape, ndims)); std::vector> arrays; for (int i = 0; i < narrays; i++) { arrays.emplace_back(c_arrays[i]->obj()); } auto semantics = static_cast(c_semantics); - return reactant::capture(MyValueOrThrow( - client->AssembleArrayFromSingleDeviceArrays( - shape, - sharding->obj(), - static_cast>>(arrays), - 
semantics - ) - )); -} - -extern "C" int ifrt_client_device_count(ifrt::Client* client) { - return client->device_count(); -} - -extern "C" int ifrt_client_addressable_device_count(ifrt::Client* client) { - return client->addressable_device_count(); -} - -extern "C" void ifrt_client_devices( - ifrt::Client* client, ifrt::Device** out_devices -) { - auto span = client->devices(); - for (int i = 0; i < span.size(); i++) { - out_devices[i] = span[i]; - } -} - -extern "C" void ifrt_client_addressable_devices( - ifrt::Client* client, ifrt::Device** out_devices -) { - auto span = client->addressable_devices(); - for (int i = 0; i < span.size(); i++) { - out_devices[i] = span[i]; - } -} - -extern "C" void ifrt_client_all_devices( - ifrt::Client* client, ifrt::Device** out_devices -) { - auto span = client->GetAllDevices(); - for (int i = 0; i < span.size(); i++) { - out_devices[i] = span[i]; - } -} - -extern "C" ifrt::Device* ifrt_client_lookup_device( - ifrt::Client* client, int dev_id -) { - return MyValueOrThrow(client->LookupDevice(static_cast(dev_id))); -} - -extern "C" ifrt::Device* ifrt_client_lookup_addressable_device( - ifrt::Client* client, int local_hw_id -) { - return MyValueOrThrow(client->LookupAddressableDevice(local_hw_id)); + return reactant::capture( + MyValueOrThrow(client->AssembleArrayFromSingleDeviceArrays( + shape, sharding->obj(), + static_cast>>(arrays), + semantics))); } // we should deprecate this because is IFRT-PjRt specific // try use `ifrt_client_make_single_shard_array_from_host_buffer` instead -extern "C" HeldIfrtArray* ifrt_pjrt_array_create( - ifrt::PjRtClient *client, - HeldValue> *buffer -) { +extern "C" HeldIfrtArray * +ifrt_pjrt_array_create(ifrt::PjRtClient *client, + HeldValue> *buffer) { return reactant::capture(tsl::RCReference( MyValueOrThrow(xla::ifrt::PjRtArray::Create(client, buffer->obj())))); } // TODO how do we compile for other backends? 
-extern "C" xla::ifrt::LoadedExecutable* ifrt_pjrt_compile( - ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id, - bool is_sharded, const int64_t *mesh_ids, - int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir -) { +extern "C" xla::ifrt::LoadedExecutable * +ifrt_pjrt_compile(ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id, + bool is_sharded, const int64_t *mesh_ids, + int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir) { CompileOptions options = GenerateCompileOptions( device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir); @@ -1608,14 +1425,14 @@ extern "C" xla::ifrt::LoadedExecutable* ifrt_pjrt_compile( return exec.release(); } -// we might me interested in the `Compiler::Compile` method variant that accepts `Topology` -extern "C" xla::ifrt::LoadedExecutable* ifrt_compile( - ifrt::Client *client, MlirModule cmod, int64_t device_id, - bool is_sharded, const int64_t *mesh_ids, - int64_t num_mesh_ids, const char *xla_gpu_cuda_data_dir -) { - // TODO we need a `xla::ifrt::CompileOptions` but this is `xla::CompileOptions` - // auto options = std::make_unique( +// we might me interested in the `Compiler::Compile` method variant that accepts +// `Topology` +extern "C" xla::ifrt::LoadedExecutable * +ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id, + bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids, + const char *xla_gpu_cuda_data_dir) { + // TODO we need a `xla::ifrt::CompileOptions` but this is + // `xla::CompileOptions` auto options = std::make_unique( // GenerateCompileOptions( // device_id, // is_sharded, @@ -1635,10 +1452,13 @@ extern "C" xla::ifrt::LoadedExecutable* ifrt_compile( } } - auto program = std::make_unique(xla::ifrt::IfrtIRProgram(cmod_op)); + auto program = + std::make_unique(xla::ifrt::IfrtIRProgram(cmod_op)); auto compiler = client->GetDefaultCompiler(); - return MyValueOrThrow(compiler->Compile(std::move(program), std::move(options))).release(); + return MyValueOrThrow( + compiler->Compile(std::move(program), std::move(options))) + .release(); } extern "C" void @@ -1648,13 +1468,12 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) { extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; } -extern "C" void -ifrt_loaded_executable_execute(ifrt::LoadedExecutable *exec, int num_args, - HeldValue> **op_args, - uint8_t *is_arg_donatable, int num_results, - HeldValue> **op_results, - uint8_t *futures, FutureType **status) -{ +extern "C" void ifrt_loaded_executable_execute( + ifrt::LoadedExecutable *exec, int num_args, + HeldValue> **op_args, + uint8_t *is_arg_donatable, int num_results, + HeldValue> **op_results, uint8_t *futures, + FutureType **status) { std::vector> args; for (int i = 0; i < num_args; i++) { args.emplace_back(op_args[i]->obj()); @@ -1690,8 +1509,8 @@ ifrt_loaded_executable_execute(ifrt::LoadedExecutable *exec, int num_args, // in principle, use ArrayCopySemantics::kAlwaysCopy (=0) extern "C" FutureType * -ifrt_CopyArrayToHostBuffer(HeldIfrtArray *array, - void *data, ifrt::ArrayCopySemantics semantics) { +ifrt_CopyArrayToHostBuffer(HeldIfrtArray *array, void *data, + ifrt::ArrayCopySemantics semantics) { return new FutureType( (*array)->CopyToHostBuffer(data, std::nullopt, semantics)); } @@ -1716,6 +1535,8 @@ FreeHloModule(HeldValue> *hlo_module) { delete hlo_module; } +#pragma region IfRtClient + // right now only making it available for TPU // in the future, we would like this for CPU and GPU PjRt backends too extern "C" 
ifrt::proxy::GrpcServer * @@ -1785,3 +1606,212 @@ ifrt_proxy_create_client(const char *c_proxy_server_address, ifrt::proxy::CreateClient(c_proxy_server_address, options)) .release(); } + +extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id, + int num_nodes) { + return ifrt_pjrt_make_client( + pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes)); +} + +extern "C" ifrt::Client * +ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices, + int num_allowed_devices, double memory_fraction, + bool preallocate, const char *platform_name, + const char **error) { + return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared( + node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction, + preallocate, platform_name, error)); +} + +extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path, + const char **error) { + return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error)); +} + +extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; } + +extern "C" int ifrt_client_device_count(ifrt::Client *client) { + return client->device_count(); +} + +extern "C" int ifrt_client_addressable_device_count(ifrt::Client *client) { + return client->addressable_device_count(); +} + +extern "C" void ifrt_client_devices(ifrt::Client *client, + ifrt::Device **out_devices) { + auto span = client->devices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } +} + +extern "C" void ifrt_client_addressable_devices(ifrt::Client *client, + ifrt::Device **out_devices) { + auto span = client->addressable_devices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } +} + +extern "C" void ifrt_client_all_devices(ifrt::Client *client, + ifrt::Device **out_devices) { + auto span = client->GetAllDevices(); + for (int i = 0; i < span.size(); i++) { + out_devices[i] = span[i]; + } +} + +extern "C" ifrt::Device *ifrt_client_lookup_device(ifrt::Client *client, + int dev_id) { + return MyValueOrThrow( + client->LookupDevice(static_cast(dev_id))); +} + +extern "C" ifrt::Device * +ifrt_client_lookup_addressable_device(ifrt::Client *client, int local_hw_id) { + return MyValueOrThrow(client->LookupAddressableDevice(local_hw_id)); +} + +extern "C" int ifrt_ClientProcessIndex(ifrt::Client *client) { + return client->process_index(); +} + +extern "C" const char *ifrt_ClientGetPlatformName(ifrt::Client *client) { + return cstr_from_string(client->platform_name()); +} + +extern "C" ifrt::Device *ifrt_ClientGetDevice(ifrt::Client *client, int idx) { + return MyValueOrThrow(client->LookupDevice(ifrt::DeviceId(idx))); +} + +extern "C" ifrt::Device *ifrt_ClientGetAddressableDevice(ifrt::Client *client, + int idx) { + return MyValueOrThrow(client->LookupAddressableDevice(idx)); +} + +#pragma endregion + +#pragma region IfRtDevice + +extern "C" int64_t ifrt_DeviceGetGlobalDeviceId(ifrt::Device *device) { + return device->Id().value(); +} + +extern "C" const char *ifrt_DeviceGetKind(ifrt::Device *device) { + return cstr_from_string(device->Kind()); +} + +extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) { + return device->client(); +} + +extern "C" HeldValue> * +ifrt_CreateBasicDeviceListFromDevices(ifrt::Device **device_list, + int32_t num_devices) { + absl::Span devices(device_list, num_devices); + return reactant::capture(ifrt::BasicDeviceList::Create(devices)); +} + +extern "C" const char *ifrt_BasicDeviceListToString( + HeldValue> *device_list) { + return 
cstr_from_string(device_list->obj()->DebugString()); +} + +extern "C" int ifrt_BasicDeviceListSize( + HeldValue> *device_list) { + return device_list->obj()->size(); +} + +extern "C" ifrt::Device *const ifrt_BasicDeviceListGetDevice( + HeldValue> *device_list, int32_t index) { + return device_list->obj()->devices()[index]; +} + +extern "C" ifrt::Memory *ifrt_DeviceGetDefaultMemory(ifrt::Device *device) { + return MyValueOrThrow(device->DefaultMemory()); +} + +extern "C" ifrt::Memory **ifrt_DeviceGetMemories(ifrt::Device *device, + int32_t *size) { + auto memory_list = device->Memories(); + *size = memory_list.size(); + return const_cast(memory_list.data()); +} + +extern "C" ifrt::MemoryKind *ifrt_MemoryGetMemoryKind(ifrt::Memory *memory) { + ifrt::MemoryKind *memory_kind = new ifrt::MemoryKind(memory->Kind()); + return memory_kind; +} + +extern "C" const char *ifrt_MemoryToString(ifrt::Memory *memory) { + return cstr_from_string(memory->ToString()); +} + +extern "C" const char *ifrt_MemoryKindToString(ifrt::MemoryKind *memory_kind) { + auto memkind = memory_kind->memory_kind(); + if (!memkind.has_value()) + return ""; + return cstr_from_string(memkind.value()); +} + +extern "C" bool ifrt_MemoryKindsAreEqual(ifrt::MemoryKind *a, + ifrt::MemoryKind *b) { + return *a == *b; +} + +#pragma endregion + +#pragma region HloSharding + +extern "C" void free_op_sharding(xla::OpSharding *op_sharding) { + delete op_sharding; +} + +extern "C" void free_hlo_sharding(xla::HloSharding *hlo_sharding) { + delete hlo_sharding; +} + +extern "C" void free_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) { + delete hlo_sharding; +} + +extern "C" xla::HloSharding * +hlo_sharding_from_op_sharding(xla::OpSharding *op_sharding) { + xla::HloSharding *hlo_sharding = new xla::HloSharding( + MyValueOrThrow(xla::HloSharding::FromProto(*op_sharding))); + return hlo_sharding; +} + +extern "C" xla::OpSharding * +hlo_sharding_to_op_sharding(xla::HloSharding *hlo_sharding) { + xla::OpSharding *op_sharding = new xla::OpSharding(hlo_sharding->ToProto()); + return op_sharding; +} + +extern "C" const char * +hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) { + return cstr_from_string(hlo_sharding->ToString(true)); +} + +extern "C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding( + HeldValue> *device_list, + ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) { + return ifrt::HloSharding::Create(device_list->obj(), *memory_kind, + *xla_hlo_sharding) + .release(); +} + +extern "C" xla::HloSharding * +ifrt_hlo_sharding_to_xla_hlo_sharding(ifrt::HloSharding *hlo_sharding) { + xla::HloSharding *xla_hlo_sharding = + new xla::HloSharding(hlo_sharding->xla_hlo_sharding()); + return xla_hlo_sharding; +} + +extern "C" const char * +ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) { + return cstr_from_string(hlo_sharding->DebugString()); +} + +#pragma endregion diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index d0bf199c0e..3751810c74 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -474,6 +474,13 @@ cc_library( "-Wl,-exported_symbol,_PjRtLoadedExecutableNumPartitions", "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_reactant_*", +"-Wl,-exported_symbol,_free_op_sharding", +"-Wl,-exported_symbol,_free_hlo_sharding", +"-Wl,-exported_symbol,_free_ifrt_hlo_sharding", +"-Wl,-exported_symbol,_hlo_sharding_from_op_sharding", +"-Wl,-exported_symbol,_hlo_sharding_to_op_sharding", +"-Wl,-exported_symbol,_hlo_sharding_to_string", 
+"-Wl,-exported_symbol,_DeviceGetKind", ]}), deps = [ "@enzyme//:EnzymeMLIR",