@@ -342,14 +342,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
342342 delete server;
343343}
344344
345- extern " C" PjRtClient *MakeCPUClient (uint8_t asynchronous, int node_id,
346- int num_nodes) {
345+ extern " C" PjRtClient *MakeCPUClient (uint8_t asynchronous, int node_id) {
347346 CpuClientOptions options;
348- // options.kv_store = "etcd";
347+
349348 options.process_id = node_id;
350- // options.num_nodes = num_nodes;
351- // options.collectives = num_nodes;
352349 options.asynchronous = asynchronous != 0 ;
350+
353351 auto client = MyValueOrThrow (GetTfrtCpuClient (options));
354352 return client.release ();
355353}
@@ -1271,28 +1269,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
12711269 return wrap (entryFn);
12721270}
12731271
1274- extern " C" HeldPjRtClient *
1275- pjrt_make_cpu_client_shared (uint8_t asynchronous, int node_id, int num_nodes) {
1276- PjRtClient *client = MakeCPUClient (asynchronous, node_id, num_nodes);
1277- return reactant::capture (std::shared_ptr<PjRtClient>(client));
1278- }
1279-
1280- extern " C" HeldPjRtClient *pjrt_make_gpu_client_shared (
1281- int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1282- double memory_fraction, bool preallocate, const char *platform_name,
1283- const char **error, void *distributed_runtime_client) {
1284- PjRtClient *client = MakeGPUClient (
1285- node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1286- preallocate, platform_name, error, distributed_runtime_client);
1287- return reactant::capture (std::shared_ptr<PjRtClient>(client));
1288- }
1289-
1290- extern " C" HeldPjRtClient *pjrt_make_tpu_client_shared (const char *tpu_path,
1291- const char **error) {
1292- PjRtClient *client = MakeTPUClient (tpu_path, error);
1293- return reactant::capture (std::shared_ptr<PjRtClient>(client));
1294- }
1295-
12961272extern " C" void pjrt_client_dtor (HeldPjRtClient *client) { delete client; }
12971273
12981274extern " C" int pjrt_client_num_devices (HeldPjRtClient *client) {
@@ -1369,11 +1345,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
13691345 std::shared_ptr<PjRtClient>(buffer->ptr ()->client ()));
13701346}
13711347
1372- extern " C" ifrt::Client *ifrt_pjrt_make_client (HeldPjRtClient *pjrt_client) {
1373- xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj ()};
1374- return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1375- }
1376-
13771348extern " C" void ifrt_client_dtor (ifrt::Client *client) { delete client; }
13781349
13791350// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
@@ -1577,16 +1548,16 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15771548
15781549extern " C" ifrt::proxy::GrpcServer *
15791550ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu (
1580- const char *c_address, uint8_t asynchronous, int node_id, int num_nodes ) {
1551+ const char *c_address, uint8_t asynchronous, int node_id) {
15811552 std::string address = c_address;
15821553
15831554 return MyValueOrThrow (
15841555 ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory (
15851556 address,
1586- [asynchronous, node_id, num_nodes]()
1587- -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1557+ [asynchronous,
1558+ node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
15881559 auto pjrt_client = std::shared_ptr<PjRtClient>(
1589- MakeCPUClient (asynchronous, node_id, num_nodes ));
1560+ MakeCPUClient (asynchronous, node_id));
15901561 return std::shared_ptr<ifrt::Client>(
15911562 xla::ifrt::PjRtClient::Create (pjrt_client).release ());
15921563 }))
@@ -1662,24 +1633,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16621633 .release ();
16631634}
16641635
1665- extern " C" ifrt::Client *ifrt_make_pjrt_cpu_client (uint8_t asynchronous,
1666- int node_id, int num_nodes) {
1667- return ifrt_pjrt_make_client (
1668- pjrt_make_cpu_client_shared (asynchronous, node_id, num_nodes));
1636+ extern " C" ifrt::Client *ifrt_pjrt_make_client (HeldPjRtClient *pjrt_client,
1637+ int node_id, int num_nodes,
1638+ void *distributed_runtime_client,
1639+ const char **error,
1640+ std::string key_prefix) {
1641+ ifrt::PjRtClient::CreateOptions options;
1642+ options.pjrt_client = pjrt_client->obj ();
1643+
1644+ if (num_nodes > 1 ) {
1645+ if (distributed_runtime_client == nullptr ) {
1646+ *error =
1647+ " `distributed_runtime_client` must be non-null if `num_nodes` > 1" ;
1648+ return nullptr ;
1649+ }
1650+ auto typed_distributed_runtime_client = static_cast <
1651+ HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
1652+ distributed_runtime_client);
1653+ options.kv_store = GetDistributedKeyValueStore (
1654+ typed_distributed_runtime_client->obj (), key_prefix);
1655+ }
1656+
1657+ options.process_id = node_id;
1658+ options.num_processes = num_nodes;
1659+
1660+ return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1661+ }
1662+
1663+ extern " C" HeldPjRtClient *pjrt_make_cpu_client_shared (uint8_t asynchronous,
1664+ int node_id) {
1665+ PjRtClient *client = MakeCPUClient (asynchronous, node_id);
1666+ return reactant::capture (std::shared_ptr<PjRtClient>(client));
1667+ }
1668+
1669+ extern " C" ifrt::Client *
1670+ ifrt_make_pjrt_cpu_client (uint8_t asynchronous, int node_id, int num_nodes,
1671+ void *distributed_runtime_client,
1672+ const char **error) {
1673+ HeldPjRtClient *pjrt_client =
1674+ pjrt_make_cpu_client_shared (asynchronous, node_id);
1675+ if (pjrt_client == nullptr )
1676+ return nullptr ;
1677+ return ifrt_pjrt_make_client (pjrt_client, node_id, num_nodes,
1678+ distributed_runtime_client, error, " cpu" );
1679+ }
1680+
1681+ extern " C" HeldPjRtClient *pjrt_make_gpu_client_shared (
1682+ int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1683+ double memory_fraction, bool preallocate, const char *platform_name,
1684+ const char **error, void *distributed_runtime_client) {
1685+ PjRtClient *client = MakeGPUClient (
1686+ node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1687+ preallocate, platform_name, error, distributed_runtime_client);
1688+ return reactant::capture (std::shared_ptr<PjRtClient>(client));
16691689}
16701690
16711691extern " C" ifrt::Client *ifrt_make_pjrt_gpu_client (
16721692 int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
16731693 double memory_fraction, bool preallocate, const char *platform_name,
16741694 const char **error, void *distributed_runtime_client) {
1675- return ifrt_pjrt_make_client ( pjrt_make_gpu_client_shared (
1695+ HeldPjRtClient *pjrt_client = pjrt_make_gpu_client_shared (
16761696 node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1677- preallocate, platform_name, error, distributed_runtime_client));
1697+ preallocate, platform_name, error, distributed_runtime_client);
1698+ if (pjrt_client == nullptr )
1699+ return nullptr ;
1700+ return ifrt_pjrt_make_client (pjrt_client, node_id, num_nodes,
1701+ distributed_runtime_client, error, " gpu" );
16781702}
16791703
1680- extern " C" ifrt::Client *ifrt_make_pjrt_tpu_client (const char *tpu_path,
1681- const char **error) {
1682- return ifrt_pjrt_make_client (pjrt_make_tpu_client_shared (tpu_path, error));
1704+ extern " C" HeldPjRtClient *pjrt_make_tpu_client_shared (const char *tpu_path,
1705+ const char **error) {
1706+ PjRtClient *client = MakeTPUClient (tpu_path, error);
1707+ return reactant::capture (std::shared_ptr<PjRtClient>(client));
1708+ }
1709+
1710+ extern " C" ifrt::Client *
1711+ ifrt_make_pjrt_tpu_client (const char *tpu_path, const char **error, int node_id,
1712+ int num_nodes, void *distributed_runtime_client) {
1713+ HeldPjRtClient *pjrt_client = pjrt_make_tpu_client_shared (tpu_path, error);
1714+ if (pjrt_client == nullptr )
1715+ return nullptr ;
1716+ return ifrt_pjrt_make_client (pjrt_client, node_id, num_nodes,
1717+ distributed_runtime_client, error, " tpu" );
16831718}
16841719
16851720extern " C" void ifrt_FreeClient (ifrt::Client *client) { delete client; }
@@ -1943,7 +1978,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
19431978
19441979#pragma endregion
19451980
1946- #pragma region PjRtDistributed
1981+ #pragma region xla::Distributed
19471982
19481983extern " C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
19491984GetDistributedRuntimeClient (char *c_address, int32_t node_id,
0 commit comments