@@ -421,6 +421,10 @@ extern "C" const char *ClientGetPlatformName(PjRtClient *client) {
421421 return cstr_from_string (client->platform_name ());
422422}
423423
424+ extern " C" const char *DeviceGetKind (PjRtDevice *device) {
425+ return cstr_from_string (device->device_kind ());
426+ }
427+
424428// To keep in sync with JLAllocatorStats in src/XLA.jl
425429struct JLAllocatorStats {
426430 int64_t num_allocs;
@@ -1258,36 +1262,6 @@ reactant_release_pjrtbuffer(HeldValue<std::shared_ptr<PjRtBuffer>> *buffer) {
12581262 delete buffer;
12591263}
12601264
1261- extern " C" ifrt::Client *
1262- ifrt_pjrt_MakeClient (HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
1263- xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj ()};
1264- return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1265- }
1266-
1267- extern " C" ifrt::Client *MakeCPUIfrtClient (uint8_t asynchronous, int node_id,
1268- int num_nodes) {
1269- return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1270- MakeCPUClient (asynchronous, node_id, num_nodes)));
1271- }
1272-
1273- extern " C" ifrt::Client *
1274- MakeGPUIfrtClient (int node_id, int num_nodes, int *allowed_devices,
1275- int num_allowed_devices, double memory_fraction,
1276- bool preallocate, const char *platform_name,
1277- const char **error) {
1278- return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1279- MakeGPUClient (node_id, num_nodes, allowed_devices, num_allowed_devices,
1280- memory_fraction, preallocate, platform_name, error)));
1281- }
1282-
1283- extern " C" ifrt::Client *MakeTPUIfrtClient (const char *tpu_path,
1284- const char **error) {
1285- return ifrt_pjrt_MakeClient (
1286- reactant_hold_pjrtclient (MakeTPUClient (tpu_path, error)));
1287- }
1288-
1289- extern " C" void ifrt_FreeClient (ifrt::Client *client) { delete client; }
1290-
12911265extern " C" xla::ifrt::LoadedExecutable *
12921266ifrt_ClientCompile (ifrt::PjRtClient *client, MlirModule cmod, int64_t device_id,
12931267 bool is_sharded, const int64_t *mesh_ids,
@@ -1399,6 +1373,8 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
13991373 delete hlo_module;
14001374}
14011375
1376+ #pragma region IfRtClient
1377+
14021378// right now only making it available for TPU
14031379// in the future, we would like this for CPU and GPU PjRt backends too
14041380extern " C" ifrt::proxy::GrpcServer *
@@ -1469,6 +1445,79 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
14691445 .release ();
14701446}
14711447
1448+ extern " C" ifrt::Client *
1449+ ifrt_pjrt_MakeClient (HeldValue<std::shared_ptr<PjRtClient>> *pjrt_client) {
1450+ xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj ()};
1451+ return MyValueOrThrow (xla::ifrt::PjRtClient::Create (options)).release ();
1452+ }
1453+
1454+ extern " C" ifrt::Client *MakeCPUIfrtClient (uint8_t asynchronous, int node_id,
1455+ int num_nodes) {
1456+ return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1457+ MakeCPUClient (asynchronous, node_id, num_nodes)));
1458+ }
1459+
1460+ extern " C" ifrt::Client *
1461+ MakeGPUIfrtClient (int node_id, int num_nodes, int *allowed_devices,
1462+ int num_allowed_devices, double memory_fraction,
1463+ bool preallocate, const char *platform_name,
1464+ const char **error) {
1465+ return ifrt_pjrt_MakeClient (reactant_hold_pjrtclient (
1466+ MakeGPUClient (node_id, num_nodes, allowed_devices, num_allowed_devices,
1467+ memory_fraction, preallocate, platform_name, error)));
1468+ }
1469+
1470+ extern " C" ifrt::Client *MakeTPUIfrtClient (const char *tpu_path,
1471+ const char **error) {
1472+ return ifrt_pjrt_MakeClient (
1473+ reactant_hold_pjrtclient (MakeTPUClient (tpu_path, error)));
1474+ }
1475+
1476+ extern " C" void ifrt_FreeClient (ifrt::Client *client) { delete client; }
1477+
1478+ extern " C" int ifrt_ClientNumDevices (ifrt::Client *client) {
1479+ return client->device_count ();
1480+ }
1481+
1482+ extern " C" int ifrt_ClientNumAddressableDevices (ifrt::Client *client) {
1483+ return client->addressable_device_count ();
1484+ }
1485+
1486+ extern " C" int ifrt_ClientProcessIndex (ifrt::Client *client) {
1487+ return client->process_index ();
1488+ }
1489+
1490+ extern " C" const char *ifrt_ClientGetPlatformName (ifrt::Client *client) {
1491+ return cstr_from_string (client->platform_name ());
1492+ }
1493+
1494+ extern " C" ifrt::Device *ifrt_ClientGetDevice (ifrt::Client *client, int idx) {
1495+ return MyValueOrThrow (client->LookupDevice (ifrt::DeviceId (idx)));
1496+ }
1497+
1498+ extern " C" ifrt::Device *ifrt_ClientGetAddressableDevice (ifrt::Client *client,
1499+ int idx) {
1500+ return MyValueOrThrow (client->LookupAddressableDevice (idx));
1501+ }
1502+
1503+ #pragma endregion
1504+
1505+ #pragma region IfRtDevice
1506+
1507+ extern " C" int64_t ifrt_DeviceGetGlobalDeviceId (ifrt::Device *device) {
1508+ return device->Id ().value ();
1509+ }
1510+
1511+ extern " C" const char *ifrt_DeviceGetKind (ifrt::Device *device) {
1512+ return cstr_from_string (device->Kind ());
1513+ }
1514+
1515+ extern " C" ifrt::Client *ifrt_DeviceToClient (ifrt::Device *device) {
1516+ return device->client ();
1517+ }
1518+
1519+ #pragma endregion
1520+
14721521#pragma region HloSharding
14731522
14741523extern " C" void free_op_sharding (xla::OpSharding *op_sharding) {
0 commit comments