8181
8282// IFRT
8383#include " xla/python/ifrt/array.h"
84- #include " xla/python/ifrt/basic_device_list.h"
8584#include " xla/python/ifrt/client.h"
8685#include " xla/python/ifrt/compiler.h"
8786#include " xla/python/ifrt/device.h"
@@ -1799,26 +1798,10 @@ extern "C" ifrt::Client *ifrt_DeviceToClient(ifrt::Device *device) {
17991798 return device->client ();
18001799}
18011800
1802- extern " C" HeldValue<tsl::RCReference<ifrt::DeviceList>> *
1803- ifrt_CreateBasicDeviceListFromDevices (ifrt::Device **device_list,
1804- int32_t num_devices) {
1801+ tsl::RCReference<ifrt::DeviceList> ifrt_CreateDeviceListFromDevices (
1802+ ifrt::Client *client, ifrt::Device **device_list, int32_t num_devices) {
18051803 absl::Span<ifrt::Device *const > devices (device_list, num_devices);
1806- return reactant::capture (ifrt::BasicDeviceList::Create (devices));
1807- }
1808-
1809- extern " C" const char *ifrt_BasicDeviceListToString (
1810- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1811- return cstr_from_string (device_list->obj ()->DebugString ());
1812- }
1813-
1814- extern " C" int ifrt_BasicDeviceListSize (
1815- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list) {
1816- return device_list->obj ()->size ();
1817- }
1818-
1819- extern " C" ifrt::Device *const ifrt_BasicDeviceListGetDevice (
1820- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list, int32_t index) {
1821- return device_list->obj ()->devices ()[index];
1804+ return client->MakeDeviceList (devices);
18221805}
18231806
18241807extern " C" ifrt::Memory *ifrt_DeviceGetDefaultMemory (ifrt::Device *device) {
@@ -1888,10 +1871,11 @@ hlo_sharding_to_string(const xla::HloSharding *hlo_sharding) {
18881871}
18891872
18901873extern " C" ifrt::HloSharding *ifrt_hlo_sharding_from_xla_hlo_sharding (
1891- HeldValue<tsl::RCReference< ifrt::DeviceList>> * device_list,
1874+ ifrt::Client *client, ifrt::Device ** device_list, int32_t num_devices ,
18921875 ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1893- return ifrt::HloSharding::Create (device_list->obj (), *memory_kind,
1894- *xla_hlo_sharding)
1876+ return ifrt::HloSharding::Create (
1877+ ifrt_CreateDeviceListFromDevices (client, device_list, num_devices),
1878+ *memory_kind, *xla_hlo_sharding)
18951879 .release ();
18961880}
18971881
@@ -1918,12 +1902,13 @@ ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
19181902}
19191903
19201904extern " C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1921- ifrt_sharding_from_hlo_sharding (
1922- HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1923- ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1905+ ifrt_sharding_from_hlo_sharding (ifrt::Client *client,
1906+ ifrt::Device **device_list, int32_t num_devices,
1907+ ifrt::MemoryKind *memory_kind,
1908+ xla::HloSharding *xla_hlo_sharding) {
19241909 return ifrt_sharding_from_ifrt_hlo_sharding (
1925- ifrt_hlo_sharding_from_xla_hlo_sharding (device_list, memory_kind ,
1926- xla_hlo_sharding));
1910+ ifrt_hlo_sharding_from_xla_hlo_sharding (client, device_list, num_devices ,
1911+ memory_kind, xla_hlo_sharding));
19271912}
19281913
19291914extern " C" const char *
0 commit comments