Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6da5b9d

Browse files
committedFeb 21, 2025·
feat: initial low-level IFRT API
fix: ifrt HloSharding refactor: split up into IFRT/PJRT feat: IFRT Client APIs feat: IFRT Device API fix: remove global_ordinals feat: add devices list abstraction feat: wrap memory and memory kinds feat: ifrt::HloSharding now working fix: use new ABI chore: run formatter fix: no finalizer feat: initial draft of IFRT.Array interface (#774) * feat: initial draft of IFRT.Array interface * feat: Base.Array to ifrt::Array * feat: buffer to host chore: run formatter fix: bad rebase feat: more proxy servers feat: add ConcreteIFRTArray feat: add ConcreteIFRTNumber refactor: rename ConcreteRNumber to ConcretePJRTNumber revert: concreteifrtarray implementation chore: run formatter feat: ifrt loaded executable
1 parent 83a2c1d commit 6da5b9d

File tree

17 files changed

+744
-105
lines changed

17 files changed

+744
-105
lines changed
 

‎deps/ReactantExtra/API.cpp

Lines changed: 129 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@
7676

7777
// IFRT
7878
#include "xla/python/ifrt/array.h"
79+
#include "xla/python/ifrt/basic_device_list.h"
7980
#include "xla/python/ifrt/client.h"
8081
#include "xla/python/ifrt/compiler.h"
8182
#include "xla/python/ifrt/device.h"
8283
#include "xla/python/ifrt/device_list.h"
83-
#include "xla/python/ifrt/basic_device_list.h"
8484
#include "xla/python/ifrt/dtype.h"
8585
#include "xla/python/ifrt/executable.h"
8686
#include "xla/python/ifrt/hlo/hlo_program.h"
@@ -1469,6 +1469,10 @@ ifrt_pjrt_loaded_executable_dtor(xla::ifrt::PjRtLoadedExecutable *exec) {
14691469

14701470
extern "C" void ifrt_array_dtor(HeldIfrtArray *array) { delete array; }
14711471

1472+
extern "C" void ifrt_loaded_executable_dtor(ifrt::LoadedExecutable *exec) {
1473+
delete exec;
1474+
}
1475+
14721476
extern "C" void ifrt_loaded_executable_execute(
14731477
ifrt::LoadedExecutable *exec, int num_args,
14741478
HeldValue<tsl::RCReference<ifrt::Array>> **op_args,
@@ -1538,38 +1542,56 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15381542

15391543
#pragma region IfRtClient
15401544

1541-
// right now only making it available for TPU
1542-
// in the future, we would like this for CPU and GPU PjRt backends too
15431545
extern "C" ifrt::proxy::GrpcServer *
1544-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1545-
const char *c_address, const char *tpu_path, const char **error) {
1546+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1547+
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
15461548
std::string address = c_address;
15471549

1548-
// taken from `MakeTPUClient`
1549-
std::string tpu_library_path;
1550-
if (auto path = llvm::sys::Process::GetEnv(kEnvTpuLibraryPath)) {
1551-
tpu_library_path = *path;
1552-
} else if (tpu_path) {
1553-
tpu_library_path = std::string(tpu_path);
1554-
} else {
1555-
*error = "Could not find TPU path";
1556-
return nullptr;
1557-
}
1550+
return MyValueOrThrow(
1551+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1552+
address,
1553+
[asynchronous, node_id, num_nodes]()
1554+
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1555+
auto pjrt_client = std::shared_ptr<PjRtClient>(
1556+
MakeCPUClient(asynchronous, node_id, num_nodes));
1557+
return std::shared_ptr<ifrt::Client>(
1558+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1559+
}))
1560+
.release();
1561+
}
15581562

1559-
const PJRT_Api *pluginLoad =
1560-
LoadPjrtPlugin("tpu", tpu_library_path.c_str(), error);
1561-
if (pluginLoad == nullptr)
1562-
return nullptr;
1563-
auto tpu_status = InitializePjrtPlugin("tpu", error);
1564-
if (tpu_status)
1565-
return nullptr;
1563+
extern "C" ifrt::proxy::GrpcServer *
1564+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1565+
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1566+
double memory_fraction, bool preallocate, const char *platform_name,
1567+
const char **error) {
1568+
return MyValueOrThrow(
1569+
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1570+
std::string(),
1571+
[node_id, num_nodes, allowed_devices, num_allowed_devices,
1572+
memory_fraction, preallocate, platform_name,
1573+
error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1574+
auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1575+
node_id, num_nodes, allowed_devices, num_allowed_devices,
1576+
memory_fraction, preallocate, platform_name, error));
1577+
return std::shared_ptr<ifrt::Client>(
1578+
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1579+
}))
1580+
.release();
1581+
}
1582+
1583+
extern "C" ifrt::proxy::GrpcServer *
1584+
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1585+
const char *c_address, const char *tpu_path, const char **error) {
1586+
std::string address = c_address;
15661587

15671588
return MyValueOrThrow(
15681589
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
15691590
address,
1570-
[]() -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1571-
auto pjrt_client =
1572-
std::shared_ptr<xla::PjRtClient>(GetCApiClient("TPU"));
1591+
[tpu_path, error]()
1592+
-> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1593+
auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1594+
MakeTPUClient(tpu_path, error));
15731595
return std::shared_ptr<xla::ifrt::Client>(
15741596
xla::ifrt::PjRtClient::Create(pjrt_client).release());
15751597
}))
@@ -1604,28 +1626,28 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16041626
nullptr, // callback `on_connection_update`
16051627
};
16061628
return MyValueOrThrow(
1607-
ifrt::proxy::CreateClient(c_proxy_server_address, options))
1629+
ifrt::proxy::CreateClient(proxy_server_address, options))
16081630
.release();
16091631
}
16101632

1611-
extern "C" ifrt::Client *ifrt_make_cpu_client(uint8_t asynchronous, int node_id,
1612-
int num_nodes) {
1633+
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1634+
int node_id, int num_nodes) {
16131635
return ifrt_pjrt_make_client(
16141636
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
16151637
}
16161638

16171639
extern "C" ifrt::Client *
1618-
ifrt_make_gpu_client(int node_id, int num_nodes, int *allowed_devices,
1619-
int num_allowed_devices, double memory_fraction,
1620-
bool preallocate, const char *platform_name,
1621-
const char **error) {
1640+
ifrt_make_pjrt_gpu_client(int node_id, int num_nodes, int *allowed_devices,
1641+
int num_allowed_devices, double memory_fraction,
1642+
bool preallocate, const char *platform_name,
1643+
const char **error) {
16221644
return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
16231645
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
16241646
preallocate, platform_name, error));
16251647
}
16261648

1627-
extern "C" ifrt::Client *ifrt_make_tpu_client(const char *tpu_path,
1628-
const char **error) {
1649+
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1650+
const char **error) {
16291651
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
16301652
}
16311653

@@ -1815,4 +1837,77 @@ ifrt_hlo_sharding_to_string(ifrt::HloSharding *hlo_sharding) {
18151837
return cstr_from_string(hlo_sharding->DebugString());
18161838
}
18171839

1840+
extern "C" void
1841+
free_ifrt_sharding(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1842+
delete sharding;
1843+
}
1844+
1845+
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1846+
ifrt_sharding_from_ifrt_hlo_sharding(ifrt::HloSharding *hlo_sharding) {
1847+
return reactant::capture(std::shared_ptr<ifrt::Sharding>(hlo_sharding));
1848+
}
1849+
1850+
extern "C" HeldValue<std::shared_ptr<ifrt::Sharding>> *
1851+
ifrt_sharding_from_hlo_sharding(
1852+
HeldValue<tsl::RCReference<ifrt::DeviceList>> *device_list,
1853+
ifrt::MemoryKind *memory_kind, xla::HloSharding *xla_hlo_sharding) {
1854+
return ifrt_sharding_from_ifrt_hlo_sharding(
1855+
ifrt_hlo_sharding_from_xla_hlo_sharding(device_list, memory_kind,
1856+
xla_hlo_sharding));
1857+
}
1858+
1859+
extern "C" const char *
1860+
ifrt_sharding_to_string(HeldValue<std::shared_ptr<ifrt::Sharding>> *sharding) {
1861+
return cstr_from_string(sharding->obj()->DebugString());
1862+
}
1863+
1864+
#pragma endregion
1865+
1866+
typedef ifrt::Future<> IfRtFutureType;
1867+
1868+
extern "C" void ifrt_free_future(IfRtFutureType *Future) { delete Future; }
1869+
1870+
extern "C" uint8_t ifrt_future_is_ready(IfRtFutureType *Future) {
1871+
return Future->IsReady();
1872+
}
1873+
1874+
extern "C" void ifrt_future_await(IfRtFutureType *Future) { Future->Await(); }
1875+
1876+
#pragma region IfRtArray
1877+
1878+
extern "C" void ifrt_free_array(HeldIfrtArray *array) { delete array; }
1879+
1880+
extern "C" int64_t *ifrt_array_shape(HeldIfrtArray *array) {
1881+
absl::Span<const long> dims = array->obj()->shape().dims();
1882+
int64_t *dims_ptr = new int64_t[dims.size()];
1883+
std::copy(dims.begin(), dims.end(), dims_ptr);
1884+
return dims_ptr;
1885+
}
1886+
1887+
extern "C" int64_t ifrt_array_ndims(HeldIfrtArray *array) {
1888+
return array->obj()->shape().dims().size();
1889+
}
1890+
1891+
extern "C" ifrt::DType ifrt_array_eltype(HeldIfrtArray *array) {
1892+
return array->obj()->dtype();
1893+
}
1894+
1895+
extern "C" ifrt::Client *ifrt_array_to_client(HeldIfrtArray *array) {
1896+
return array->obj()->client();
1897+
}
1898+
1899+
extern "C" HeldValue<std::shared_ptr<const ifrt::Sharding>> *
1900+
ifrt_array_to_sharding(HeldIfrtArray *array) {
1901+
return reactant::capture(array->obj()->shared_ptr_sharding());
1902+
}
1903+
1904+
extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
1905+
void *data) {
1906+
std::optional<absl::Span<const int64_t>> byte_strides;
1907+
auto future = array->obj()->CopyToHostBuffer(
1908+
data, byte_strides, static_cast<ifrt::ArrayCopySemantics>(0));
1909+
future.Await();
1910+
return;
1911+
}
1912+
18181913
#pragma endregion

‎src/xla/Buffer.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ function buffer_on_cpu end
55
function to_host end
66
function unsafe_buffer_pointer end
77
function copy_buffer_to_device end
8+
function sharding end
9+
10+
Base.convert(::Type{Array}, buffer::AbstractBuffer) = convert(Array{eltype(buffer)}, buffer)
11+
12+
function Base.convert(::Type{<:Array{T}}, buffer::AbstractBuffer) where {T}
13+
arr = zeros(T, reverse(size(buffer))...)
14+
XLA.to_host(buffer, arr)
15+
return arr
16+
end
817

918
@inline function client(
1019
buffers::Union{Array{<:AbstractBuffer},NTuple{<:Any,AbstractBuffer}}
@@ -19,3 +28,48 @@ end
1928
)
2029
return map(synced_buffer, buffers)
2130
end
31+
32+
function Base.show(io::IO, mime::MIME"text/plain", buffer::B) where {B<:AbstractBuffer}
33+
print(io, "$(B) storing ")
34+
show(io, mime, convert(Array, buffer))
35+
return nothing
36+
end
37+
38+
# Async Buffers
39+
abstract type AbstractAsyncBuffer <: AbstractBuffer end
40+
41+
Base.isempty(buffer::AbstractAsyncBuffer) = buffer.buffer.buffer == C_NULL
42+
43+
function Base.convert(T::Type{Array}, buffer::AbstractAsyncBuffer)
44+
XLA.await(buffer)
45+
return convert(T, buffer.buffer)
46+
end
47+
48+
function Base.convert(T::Type{<:Array{T1}}, buffer::AbstractAsyncBuffer) where {T1}
49+
XLA.await(buffer)
50+
return convert(T, buffer.buffer)
51+
end
52+
53+
for op in (:(Base.ndims), :(Base.size), :(Base.eltype), :device, :client, :sharding)
54+
@eval $op(buffer::AbstractAsyncBuffer) = $op(buffer.buffer)
55+
end
56+
57+
function XLA.synced_buffer(buffer::AbstractAsyncBuffer)
58+
XLA.await(buffer)
59+
return buffer.buffer
60+
end
61+
62+
function XLA.await(buffer::AbstractAsyncBuffer)
63+
buffer.future === nothing && return nothing
64+
future = buffer.future
65+
buffer.future = nothing
66+
XLA.await(future)
67+
return nothing
68+
end
69+
70+
function XLA.is_ready(buffer::AbstractAsyncBuffer)
71+
buffer.future === nothing && return true
72+
return XLA.is_ready(buffer.future)
73+
end
74+
75+
XLA.buffer_on_cpu(buffer::AbstractAsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer)

‎src/xla/IFRT/Array.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
mutable struct Array <: XLA.AbstractBuffer
2+
buffer::Ptr{Cvoid}
3+
4+
function Array(buffer::Ptr{Cvoid})
5+
return finalizer(free_ifrt_array, new(buffer))
6+
end
7+
end
8+
9+
function Array(client::Client, array::Base.Array{T,N}, device::Device) where {T,N}
10+
sizear = collect(Int64, reverse(size(array)))
11+
buffer = GC.@preserve array sizear begin
12+
@ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer(
13+
client.client::Ptr{Cvoid},
14+
pointer(array)::Ptr{T},
15+
XLA.primitive_type(T)::UInt64,
16+
N::Csize_t,
17+
pointer(sizear)::Ptr{Int64},
18+
0::Cint, # kAlwaysCopy
19+
device.device::Ptr{Cvoid},
20+
string(convert(MemoryKind, XLA.default_memory(device)))::Cstring,
21+
)::Ptr{Cvoid}
22+
end
23+
return Array(buffer)
24+
end
25+
26+
function Array(client::Client, array::Base.Array{T,N}, sharding::HloSharding) where {T,N}
27+
return Array(client, array, convert(Sharding, sharding))
28+
end
29+
30+
function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where {T,N}
31+
sizear = collect(Int64, reverse(size(array)))
32+
buffer = GC.@preserve array sizear begin
33+
@ccall MLIR.API.mlir_c.ifrt_client_make_array_from_host_buffer(
34+
client.client::Ptr{Cvoid},
35+
pointer(array)::Ptr{T},
36+
XLA.primitive_type(T)::Cint,
37+
N::Csize_t,
38+
pointer(sizear)::Ptr{Int64},
39+
sharding.ptr::Ptr{Cvoid},
40+
0::Cint, # kAlwaysCopy
41+
)::Ptr{Cvoid}
42+
end
43+
return Array(buffer)
44+
end
45+
46+
@inline function free_ifrt_array(buffer::Array)
47+
sbuffer = buffer.buffer
48+
if sbuffer != C_NULL
49+
@ccall MLIR.API.mlir_c.ifrt_free_array(sbuffer::Ptr{Cvoid})::Cvoid
50+
end
51+
end
52+
53+
function Base.ndims(buffer::Array)
54+
GC.@preserve buffer begin
55+
return @ccall MLIR.API.mlir_c.ifrt_array_ndims(buffer.buffer::Ptr{Cvoid})::Int64
56+
end
57+
end
58+
59+
function Base.size(buffer::Array)
60+
GC.@preserve buffer begin
61+
sz = @ccall MLIR.API.mlir_c.ifrt_array_shape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64}
62+
end
63+
return Tuple(unsafe_wrap(Base.Array, sz, ndims(buffer)))
64+
end
65+
66+
function Base.eltype(buffer::Array)
67+
GC.@preserve buffer begin
68+
return XLA.julia_type(
69+
@ccall MLIR.API.mlir_c.ifrt_array_eltype(buffer.buffer::Ptr{Cvoid})::Cint
70+
)
71+
end
72+
end
73+
74+
function XLA.device(::Array)
75+
return error("IFRT.Array can be sharded/replicated across multiple devices. Hence, \
76+
`XLA.device` is not defined.")
77+
end
78+
79+
function XLA.client(buffer::Array)
80+
GC.@preserve buffer begin
81+
return Client(
82+
@ccall MLIR.API.mlir_c.ifrt_array_to_client(
83+
buffer.buffer::Ptr{Cvoid}
84+
)::Ptr{Cvoid}
85+
)
86+
end
87+
end
88+
89+
XLA.synced_buffer(buffer::Array) = buffer
90+
91+
function XLA.buffer_on_cpu(::Array)
92+
return error("IFRT.Array does not support `XLA.buffer_on_cpu`")
93+
end
94+
95+
function XLA.to_host(buffer::Array, data)
96+
GC.@preserve buffer data begin
97+
@ccall MLIR.API.mlir_c.ifrt_array_copy_to_host_buffer(
98+
buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid}
99+
)::Cvoid
100+
end
101+
return nothing
102+
end
103+
104+
function XLA.unsafe_buffer_pointer(::Array)
105+
return error("IFRT.Array does not support `XLA.unsafe_buffer_pointer`")
106+
end
107+
108+
function XLA.copy_buffer_to_device(::Array, ::Device)
109+
return error("IFRT.Array does not support `XLA.copy_buffer_to_device`")
110+
end
111+
112+
function XLA.sharding(buffer::Array)
113+
GC.@preserve buffer begin
114+
return Sharding(
115+
@ccall MLIR.API.mlir_c.ifrt_array_to_sharding(
116+
buffer.buffer::Ptr{Cvoid}
117+
)::Ptr{Cvoid}
118+
)
119+
end
120+
end

‎src/xla/IFRT/AsyncArray.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
mutable struct AsyncArray <: XLA.AbstractAsyncBuffer
2+
buffer::Array
3+
future::Union{Future,Nothing}
4+
end
5+
6+
const AsyncEmptyArray = AsyncArray(Array(C_NULL), nothing)
7+
8+
AsyncArray(args...; kwargs...) = AsyncArray(Array(args...; kwargs...), nothing)

‎src/xla/IFRT/Client.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
mutable struct Client <: XLA.AbstractClient
2+
client::Ptr{Cvoid}
3+
4+
function Client(client::Ptr{Cvoid})
5+
@assert client != C_NULL
6+
return new(client)
7+
end
8+
end
9+
10+
function XLA.free_client(client::Client)
11+
GC.@preserve client begin
12+
@ccall MLIR.API.mlir_c.ifrt_FreeClient(client.client::Ptr{Cvoid})::Cvoid
13+
end
14+
end
15+
16+
function XLA.num_devices(client::Client)
17+
GC.@preserve client begin
18+
return @ccall MLIR.API.mlir_c.ifrt_client_device_count(
19+
client.client::Ptr{Cvoid}
20+
)::Cint
21+
end
22+
end
23+
24+
function XLA.num_addressable_devices(client::Client)
25+
GC.@preserve client begin
26+
return @ccall MLIR.API.mlir_c.ifrt_client_addressable_device_count(
27+
client.client::Ptr{Cvoid}
28+
)::Cint
29+
end
30+
end
31+
32+
function XLA.process_index(client::Client)
33+
GC.@preserve client begin
34+
return @ccall MLIR.API.mlir_c.ifrt_ClientProcessIndex(
35+
client.client::Ptr{Cvoid}
36+
)::Cint
37+
end
38+
end
39+
40+
function XLA.get_device(client::Client, idx)
41+
GC.@preserve client begin
42+
return Device(
43+
@ccall MLIR.API.mlir_c.ifrt_client_lookup_device(
44+
client.client::Ptr{Cvoid}, idx::Cint
45+
)::Ptr{Cvoid}
46+
)
47+
end
48+
end
49+
50+
function XLA.get_addressable_device(client::Client, idx)
51+
GC.@preserve client begin
52+
return Device(
53+
@ccall MLIR.API.mlir_c.ifrt_client_lookup_addressable_device(
54+
client.client::Ptr{Cvoid}, idx::Cint
55+
)::Ptr{Cvoid}
56+
)
57+
end
58+
end
59+
60+
function XLA.platform_name(client::Client)
61+
GC.@preserve client begin
62+
str = @ccall MLIR.API.mlir_c.ifrt_ClientGetPlatformName(
63+
client.client::Ptr{Cvoid}
64+
)::Cstring
65+
end
66+
return XLA.unsafe_string_and_free(str)
67+
end
68+
69+
# Different Backends
70+
const cpu_client_count = Ref(0)
71+
const gpu_client_count = Ref(0)
72+
const tpu_client_count = Ref(0)
73+
74+
# XXX: We need other backends to support sharding
75+
for (backend, fname, counter) in (
76+
(:CPUClient, "ifrt_make_pjrt_cpu_client", :cpu_client_count),
77+
(:GPUClient, "ifrt_make_pjrt_gpu_client", :gpu_client_count),
78+
(:TPUClient, "ifrt_make_pjrt_tpu_client", :tpu_client_count),
79+
)
80+
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
81+
if checkcount
82+
@assert $(counter)[] == 0
83+
$(counter)[] += 1
84+
end
85+
return Client(XLA.$(backend)($(fname), args...; kwargs...))
86+
end
87+
end

‎src/xla/IFRT/Device.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
struct Device <: XLA.AbstractDevice
2+
device::Ptr{Cvoid}
3+
end
4+
5+
function XLA.client(device::Device)
6+
GC.@preserve device begin
7+
return Client(
8+
@ccall MLIR.API.mlir_c.ifrt_DeviceToClient(
9+
device.device::Ptr{Cvoid}
10+
)::Ptr{Cvoid}
11+
)
12+
end
13+
end
14+
15+
function XLA.device_ordinal(device::Device)
16+
GC.@preserve device begin
17+
return @ccall MLIR.API.mlir_c.ifrt_DeviceGetGlobalDeviceId(
18+
device.device::Ptr{Cvoid}
19+
)::Int64
20+
end
21+
end
22+
23+
function XLA.device_kind(device::Device)
24+
GC.@preserve device begin
25+
str = @ccall MLIR.API.mlir_c.ifrt_DeviceGetKind(device.device::Ptr{Cvoid})::Cstring
26+
end
27+
return XLA.unsafe_string_and_free(str)
28+
end
29+
30+
function XLA.get_local_device_id(::Device)
31+
return error("Not implemented for ifrt devices")
32+
end
33+
34+
function XLA.default_memory(device::Device)
35+
GC.@preserve device begin
36+
return Memory(
37+
@ccall MLIR.API.mlir_c.ifrt_DeviceGetDefaultMemory(
38+
device.device::Ptr{Cvoid}
39+
)::Ptr{Cvoid}
40+
)
41+
end
42+
end
43+
44+
function XLA.memories(device::Device)
45+
memories_size = Ref{Int32}(0)
46+
GC.@preserve device memories_size begin
47+
ptr = @ccall MLIR.API.mlir_c.ifrt_DeviceGetMemories(
48+
device.device::Ptr{Cvoid}, memories_size::Ptr{Int32}
49+
)::Ptr{Ptr{Cvoid}}
50+
end
51+
memories = Vector{Memory}(undef, memories_size[])
52+
for i in 1:memories_size[]
53+
memories[i] = Memory(unsafe_load(ptr, i))
54+
end
55+
return memories
56+
end
57+
58+
# Device List
59+
## TODO: This is semi-deprecated in openxla. At some point we want to just replace this with
60+
## a simple vector of devices
61+
struct BasicDeviceList <: AbstractVector{Device}
62+
ptr::Ptr{Cvoid}
63+
64+
function BasicDeviceList(devices::AbstractVector{Device})
65+
GC.@preserve devices begin
66+
ptr = @ccall MLIR.API.mlir_c.ifrt_CreateBasicDeviceListFromDevices(
67+
[d.device for d in devices]::Ptr{Ptr{Cvoid}}, length(devices)::Int32
68+
)::Ptr{Cvoid}
69+
end
70+
return new(ptr)
71+
end
72+
end
73+
74+
function Base.getindex(device_list::BasicDeviceList, index::Integer)
75+
if !(1 index length(device_list))
76+
throw(BoundsError(device_list, index))
77+
end
78+
GC.@preserve device_list begin
79+
device_ptr = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListGetDevice(
80+
device_list.ptr::Ptr{Cvoid}, (index - 1)::Int32
81+
)::Ptr{Cvoid}
82+
end
83+
return Device(device_ptr)
84+
end
85+
86+
function Base.size(device_list::BasicDeviceList)
87+
GC.@preserve device_list begin
88+
len = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListSize(
89+
device_list.ptr::Ptr{Cvoid}
90+
)::Int32
91+
end
92+
return (len,)
93+
end
94+
95+
function Base.string(device_list::BasicDeviceList)
96+
GC.@preserve device_list begin
97+
str = @ccall MLIR.API.mlir_c.ifrt_BasicDeviceListToString(
98+
device_list.ptr::Ptr{Cvoid}
99+
)::Cstring
100+
end
101+
return XLA.unsafe_string_and_free(str)
102+
end
103+
104+
function XLA.default_memory(device_list::AbstractVector{Device})
105+
default_memories = XLA.default_memory.(device_list)
106+
default_memory_kinds = convert.(MemoryKind, default_memories)
107+
if !allequal(default_memory_kinds)
108+
error("All devices must have the same default memory")
109+
end
110+
return first(default_memories)
111+
end

‎src/xla/IFRT/Future.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
mutable struct Future <: XLA.AbstractFuture
2+
future::Ptr{Cvoid}
3+
4+
function Future(future::Ptr{Cvoid})
5+
@assert future != C_NULL
6+
return finalizer(free_future, new(future))
7+
end
8+
end
9+
10+
@inline function free_future(future::Future)
11+
@ccall MLIR.API.mlir_c.ifrt_free_future(future.future::Ptr{Cvoid})::Cvoid
12+
end
13+
14+
function XLA.is_ready(future::Future)
15+
GC.@preserve future begin
16+
return (@ccall MLIR.API.mlir_c.ifrt_future_is_ready(
17+
future.future::Ptr{Cvoid}
18+
)::UInt8) != 0
19+
end
20+
end
21+
22+
@inline function XLA.await(future::Future)::Nothing
23+
GC.@preserve future begin
24+
@ccall MLIR.API.mlir_c.ifrt_future_await(future.future::Ptr{Cvoid})::Cvoid
25+
end
26+
return nothing
27+
end

‎src/xla/IFRT/IFRT.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module IFRT
2+
3+
using ..Reactant: Reactant, MLIR
4+
using ..XLA: XLA
5+
6+
include("Client.jl")
7+
include("Device.jl")
8+
include("Memory.jl")
9+
include("Future.jl")
10+
include("Sharding.jl")
11+
include("Array.jl")
12+
include("AsyncArray.jl")
13+
include("LoadedExecutable.jl")
14+
15+
end

‎src/xla/IFRT/LoadedExecutable.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
mutable struct LoadedExecutable <: XLA.AbstractLoadedExecutable
2+
exec::Ptr{Cvoid}
3+
4+
function LoadedExecutable(ptr::Ptr{Cvoid})
5+
@assert ptr != C_NULL
6+
return finalizer(free_exec, new(ptr))
7+
end
8+
end
9+
10+
function free_exec(exec::LoadedExecutable)
11+
GC.@preserve exec begin
12+
@ccall MLIR.API.mlir_c.ifrt_loaded_executable_dtor(exec.exec::Ptr{Cvoid})::Cvoid
13+
end
14+
end

‎src/xla/IFRT/Memory.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
mutable struct Memory <: XLA.AbstractMemory
2+
ptr::Ptr{Cvoid}
3+
end
4+
5+
function Base.show(io::IO, memory::Memory)
6+
GC.@preserve memory begin
7+
str = @ccall MLIR.API.mlir_c.ifrt_MemoryToString(memory.ptr::Ptr{Cvoid})::Cstring
8+
end
9+
print(io, "XLA.IFRT.Memory(\"", XLA.unsafe_string_and_free(str), "\")")
10+
return nothing
11+
end
12+
13+
mutable struct MemoryKind <: XLA.AbstractMemoryKind
14+
ptr::Ptr{Cvoid}
15+
end
16+
17+
function Base.convert(::Type{MemoryKind}, memory::Memory)
18+
GC.@preserve memory begin
19+
return MemoryKind(
20+
@ccall MLIR.API.mlir_c.ifrt_MemoryGetMemoryKind(
21+
memory.ptr::Ptr{Cvoid}
22+
)::Ptr{Cvoid}
23+
)
24+
end
25+
end
26+
27+
function Base.:(==)(a::MemoryKind, b::MemoryKind)
28+
GC.@preserve a b begin
29+
return @ccall MLIR.API.mlir_c.ifrt_MemoryKindsAreEqual(
30+
a.ptr::Ptr{Cvoid}, b.ptr::Ptr{Cvoid}
31+
)::Bool
32+
end
33+
end
34+
35+
function Base.string(memory_kind::MemoryKind)
36+
GC.@preserve memory_kind begin
37+
str = @ccall MLIR.API.mlir_c.ifrt_MemoryKindToString(
38+
memory_kind.ptr::Ptr{Cvoid}
39+
)::Cstring
40+
end
41+
return XLA.unsafe_string_and_free(str)
42+
end
43+
44+
function Base.show(io::IO, memory_kind::MemoryKind)
45+
print(io, "XLA.IFRT.MemoryKind(\"", string(memory_kind), "\")")
46+
return nothing
47+
end

‎src/xla/IFRT/Sharding.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# xla::ifrt::HloSharding (distinct from xla::HloSharding)
2+
mutable struct HloSharding
3+
ptr::Ptr{Cvoid}
4+
5+
function HloSharding(ptr::Ptr{Cvoid})
6+
@assert ptr != C_NULL
7+
return finalizer(free_hlo_sharding, new(ptr))
8+
end
9+
end
10+
11+
function free_hlo_sharding(hlo_sharding::HloSharding)
12+
@ccall MLIR.API.mlir_c.free_ifrt_hlo_sharding(hlo_sharding.ptr::Ptr{Cvoid})::Cvoid
13+
end
14+
15+
function HloSharding(device_list::BasicDeviceList, xla_hlo_sharding::XLA.HloSharding)
16+
default_memory_kind = convert(MemoryKind, XLA.default_memory(device_list))
17+
GC.@preserve device_list default_memory_kind xla_hlo_sharding begin
18+
return HloSharding(
19+
@ccall MLIR.API.mlir_c.ifrt_hlo_sharding_from_xla_hlo_sharding(
20+
device_list.ptr::Ptr{Cvoid},
21+
default_memory_kind.ptr::Ptr{Cvoid},
22+
xla_hlo_sharding.ptr::Ptr{Cvoid},
23+
)::Ptr{Cvoid}
24+
)
25+
end
26+
end
27+
28+
function Base.string(hlo_sharding::HloSharding)
29+
GC.@preserve hlo_sharding begin
30+
str = @ccall MLIR.API.mlir_c.ifrt_hlo_sharding_to_string(
31+
hlo_sharding.ptr::Ptr{Cvoid}
32+
)::Cstring
33+
end
34+
return XLA.unsafe_string_and_free(str)
35+
end
36+
37+
function Base.show(io::IO, ::MIME"text/plain", hlo_sharding::HloSharding)
38+
print(io, "XLA.IFRT.HloSharding(\"", string(hlo_sharding), "\")")
39+
return nothing
40+
end
41+
42+
# HloSharding is more specific than Sharding. But Sharding is a neater way to deal with
43+
# most of the IFRT APIs.
44+
mutable struct Sharding
45+
ptr::Ptr{Cvoid}
46+
47+
function Sharding(ptr::Ptr{Cvoid})
48+
@assert ptr != C_NULL
49+
return finalizer(free_sharding, new(ptr))
50+
end
51+
end
52+
53+
function Sharding(device_list::BasicDeviceList, xla_hlo_sharding::XLA.HloSharding)
54+
return convert(Sharding, HloSharding(device_list, xla_hlo_sharding))
55+
end
56+
57+
function free_sharding(sharding::Sharding)
58+
@ccall MLIR.API.mlir_c.free_ifrt_sharding(sharding.ptr::Ptr{Cvoid})::Cvoid
59+
end
60+
61+
function Base.convert(::Type{Sharding}, hlo_sharding::HloSharding)
62+
GC.@preserve hlo_sharding begin
63+
return Sharding(
64+
@ccall MLIR.API.mlir_c.ifrt_sharding_from_ifrt_hlo_sharding(
65+
hlo_sharding.ptr::Ptr{Cvoid}
66+
)::Ptr{Cvoid}
67+
)
68+
end
69+
end
70+
71+
function Base.string(sharding::Sharding)
72+
GC.@preserve sharding begin
73+
str = @ccall MLIR.API.mlir_c.ifrt_sharding_to_string(
74+
sharding.ptr::Ptr{Cvoid}
75+
)::Cstring
76+
end
77+
return XLA.unsafe_string_and_free(str)
78+
end
79+
80+
function Base.show(io::IO, ::MIME"text/plain", sharding::Sharding)
81+
print(io, "XLA.IFRT.Sharding(\"", string(sharding), "\")")
82+
return nothing
83+
end

‎src/xla/PJRT/AsyncBuffer.jl

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,8 @@
1-
mutable struct AsyncBuffer <: XLA.AbstractBuffer
1+
mutable struct AsyncBuffer <: XLA.AbstractAsyncBuffer
22
buffer::Buffer
33
future::Union{Future,Nothing}
44
end
55

66
const AsyncEmptyBuffer = AsyncBuffer(Buffer(C_NULL), nothing)
77

8-
function AsyncBuffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
9-
return AsyncBuffer(Buffer(client, array, device), nothing)
10-
end
11-
12-
Base.isempty(buffer::AsyncBuffer) = buffer == AsyncEmptyBuffer
13-
14-
function Base.convert(::Type{<:Array{T}}, buffer::AsyncBuffer) where {T}
15-
XLA.await(buffer)
16-
return convert(Array{T}, buffer.buffer)
17-
end
18-
19-
for op in (:(Base.ndims), :(Base.size), :device, :client)
20-
@eval $op(buffer::AsyncBuffer) = $op(buffer.buffer)
21-
end
22-
23-
function XLA.synced_buffer(buffer::AsyncBuffer)
24-
XLA.await(buffer)
25-
return buffer.buffer
26-
end
27-
28-
function XLA.await(buffer::AsyncBuffer)
29-
buffer.future === nothing && return nothing
30-
future = buffer.future
31-
buffer.future = nothing
32-
XLA.await(future)
33-
return nothing
34-
end
35-
36-
function XLA.is_ready(buffer::AsyncBuffer)
37-
buffer.future === nothing && return true
38-
return XLA.is_ready(buffer.future)
39-
end
40-
41-
XLA.buffer_on_cpu(buffer::AsyncBuffer) = XLA.buffer_on_cpu(buffer.buffer)
42-
43-
XLA.client(buffer::AsyncBuffer) = XLA.client(buffer.buffer)
44-
XLA.device(buffer::AsyncBuffer) = XLA.device(buffer.buffer)
8+
AsyncBuffer(args...; kwargs...) = AsyncBuffer(Buffer(args...; kwargs...), nothing)

‎src/xla/PJRT/Buffer.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,15 @@ function Base.size(buffer::Buffer)
3838
GC.@preserve buffer begin
3939
sz = @ccall MLIR.API.mlir_c.BufferShape(buffer.buffer::Ptr{Cvoid})::Ptr{Int64}
4040
end
41-
return [unsafe_load(sz, i) for i in 1:ndims(buffer)]
41+
return Tuple(unsafe_wrap(Array, sz, ndims(buffer)))
42+
end
43+
44+
function Base.eltype(buffer::Buffer)
45+
GC.@preserve buffer begin
46+
return XLA.julia_type(
47+
@ccall MLIR.API.mlir_c.BufferPrimitiveType(buffer.buffer::Ptr{Cvoid})::Cint
48+
)
49+
end
4250
end
4351

4452
function XLA.device(buffer::Buffer)
@@ -65,12 +73,6 @@ function XLA.buffer_on_cpu(buffer::Buffer)
6573
end
6674
end
6775

68-
function Base.convert(::Type{<:Array{T}}, buffer::Buffer) where {T}
69-
arr = zeros(T, reverse(size(buffer))...)
70-
XLA.to_host(buffer, arr)
71-
return arr
72-
end
73-
7476
function XLA.to_host(buffer::Buffer, data)
7577
GC.@preserve buffer begin
7678
@ccall MLIR.API.mlir_c.BufferToHost(
@@ -94,3 +96,5 @@ function XLA.copy_buffer_to_device(buffer::Buffer, dev::Device)
9496
)
9597
end
9698
end
99+
100+
XLA.sharding(::Buffer) = error("PJRT Buffers are not sharded.")

‎src/xla/Utils.jl

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,40 @@ function reactant_err(msg::Cstring)::Cvoid
1313
end
1414

1515
# https://github.com/openxla/xla/blob/4bfb5c82a427151d6fe5acad8ebe12cee403036a/xla/xla_data.proto#L29
16-
@inline primitive_type(::Type{Bool}) = 1
17-
18-
@inline primitive_type(::Type{Int8}) = 2
19-
@inline primitive_type(::Type{UInt8}) = 6
20-
21-
@inline primitive_type(::Type{Int16}) = 3
22-
@inline primitive_type(::Type{UInt16}) = 7
23-
24-
@inline primitive_type(::Type{Int32}) = 4
25-
@inline primitive_type(::Type{UInt32}) = 8
26-
27-
@inline primitive_type(::Type{Int64}) = 5
28-
@inline primitive_type(::Type{UInt64}) = 9
29-
30-
@inline primitive_type(::Type{Float16}) = 10
31-
@inline primitive_type(::Type{Float32}) = 11
32-
33-
@inline primitive_type(::Type{Reactant.F8E5M2}) = 19
34-
@inline primitive_type(::Type{Reactant.F8E4M3FN}) = 20
35-
@inline primitive_type(::Type{Reactant.F8E4M3B11FNUZ}) = 23
36-
@inline primitive_type(::Type{Reactant.F8E5M2FNUZ}) = 24
37-
@inline primitive_type(::Type{Reactant.F8E4M3FNUZ}) = 25
16+
primitive_types_list = [
17+
(1, Bool),
18+
(2, Int8),
19+
(6, UInt8),
20+
(3, Int16),
21+
(7, UInt16),
22+
(4, Int32),
23+
(8, UInt32),
24+
(5, Int64),
25+
(9, UInt64),
26+
(10, Float16),
27+
(11, Float32),
28+
(19, Reactant.F8E5M2),
29+
(20, Reactant.F8E4M3FN),
30+
(23, Reactant.F8E4M3B11FNUZ),
31+
(24, Reactant.F8E5M2FNUZ),
32+
(25, Reactant.F8E4M3FNUZ),
33+
(12, Float64),
34+
(15, Complex{Float32}),
35+
(18, Complex{Float64}),
36+
]
3837

3938
@static if isdefined(Core, :BFloat16)
40-
@inline primitive_type(::Type{Core.BFloat16}) = 16
39+
push!(primitive_types_list, (16, Core.BFloat16))
4140
end
4241

43-
@inline primitive_type(::Type{Float64}) = 12
42+
for (int_val, jl_type) in primitive_types_list
43+
@eval begin
44+
@inline primitive_type(::Type{$(jl_type)}) = $(int_val)
45+
@inline julia_type(::Val{$(int_val)}) = $(jl_type)
46+
end
47+
end
4448

45-
@inline primitive_type(::Type{Complex{Float32}}) = 15
46-
@inline primitive_type(::Type{Complex{Float64}}) = 18
49+
@inline julia_type(@nospecialize(x::Integer)) = julia_type(Val(Int64(x)))
4750

4851
function unsafe_string_and_free(str::Cstring, args...)
4952
str_jl = unsafe_string(str, args...)

‎src/xla/XLA.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ include("Memory.jl")
3333

3434
include("PJRT/PJRT.jl")
3535

36+
include("IFRT/IFRT.jl")
37+
3638
const backends = Dict{String,PJRT.Client}()
3739
const default_backend = Ref{PJRT.Client}()
3840
const default_device_idx = Ref{Int}(0)

‎test/ifrt_manual.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Testing manual IFRT buffer creation + compilation + execution
2+
using Reactant
3+
using Reactant: XLA
4+
using Reactant.XLA: IFRT

‎test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
6262
@safetestset "Custom Number Types" include("custom_number_types.jl")
6363
end
6464
@safetestset "Sharding" include("sharding.jl")
65+
@safetestset "IFRT Low-Level API" include("ifrt_manual.jl")
6566
end
6667

6768
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

0 commit comments

Comments
 (0)
Please sign in to comment.