Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit ca8f3da

Browse files
committedFeb 21, 2025·
feat: construct IFRT clients with distributed options
1 parent 408e120 commit ca8f3da

File tree

5 files changed

+351
-184
lines changed

5 files changed

+351
-184
lines changed
 

‎deps/ReactantExtra/API.cpp

Lines changed: 131 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
342342
delete server;
343343
}
344344

345-
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
346-
int num_nodes) {
345+
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
347346
CpuClientOptions options;
348-
// options.kv_store = "etcd";
347+
349348
options.process_id = node_id;
350-
// options.num_nodes = num_nodes;
351-
// options.collectives = num_nodes;
352349
options.asynchronous = asynchronous != 0;
350+
353351
auto client = MyValueOrThrow(GetTfrtCpuClient(options));
354352
return client.release();
355353
}
@@ -1271,28 +1269,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
12711269
return wrap(entryFn);
12721270
}
12731271

1274-
extern "C" HeldPjRtClient *
1275-
pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) {
1276-
PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes);
1277-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1278-
}
1279-
1280-
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
1281-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1282-
double memory_fraction, bool preallocate, const char *platform_name,
1283-
const char **error, void *distributed_runtime_client) {
1284-
PjRtClient *client = MakeGPUClient(
1285-
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1286-
preallocate, platform_name, error, distributed_runtime_client);
1287-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1288-
}
1289-
1290-
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
1291-
const char **error) {
1292-
PjRtClient *client = MakeTPUClient(tpu_path, error);
1293-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1294-
}
1295-
12961272
extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; }
12971273

12981274
extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) {
@@ -1369,11 +1345,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
13691345
std::shared_ptr<PjRtClient>(buffer->ptr()->client()));
13701346
}
13711347

1372-
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) {
1373-
xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()};
1374-
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1375-
}
1376-
13771348
extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; }
13781349

13791350
// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
@@ -1575,61 +1546,62 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15751546

15761547
#pragma region IfRtClient
15771548

1578-
extern "C" ifrt::proxy::GrpcServer *
1579-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1580-
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
1581-
std::string address = c_address;
1582-
1583-
return MyValueOrThrow(
1584-
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1585-
address,
1586-
[asynchronous, node_id, num_nodes]()
1587-
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1588-
auto pjrt_client = std::shared_ptr<PjRtClient>(
1589-
MakeCPUClient(asynchronous, node_id, num_nodes));
1590-
return std::shared_ptr<ifrt::Client>(
1591-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1592-
}))
1593-
.release();
1594-
}
1595-
1596-
extern "C" ifrt::proxy::GrpcServer *
1597-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1598-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1599-
double memory_fraction, bool preallocate, const char *platform_name,
1600-
const char **error) {
1601-
return MyValueOrThrow(
1602-
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1603-
std::string(),
1604-
[node_id, num_nodes, allowed_devices, num_allowed_devices,
1605-
memory_fraction, preallocate, platform_name,
1606-
error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1607-
auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1608-
node_id, num_nodes, allowed_devices, num_allowed_devices,
1609-
memory_fraction, preallocate, platform_name, error));
1610-
return std::shared_ptr<ifrt::Client>(
1611-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1612-
}))
1613-
.release();
1614-
}
1549+
// XXX: Bring back with the correct API
1550+
// extern "C" ifrt::proxy::GrpcServer *
1551+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1552+
// const char *c_address, uint8_t asynchronous, int node_id) {
1553+
// std::string address = c_address;
1554+
1555+
// return MyValueOrThrow(
1556+
// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1557+
// address,
1558+
// [asynchronous,
1559+
// node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1560+
// auto pjrt_client = std::shared_ptr<PjRtClient>(
1561+
// MakeCPUClient(asynchronous, node_id));
1562+
// return std::shared_ptr<ifrt::Client>(
1563+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1564+
// }))
1565+
// .release();
1566+
// }
16151567

1616-
extern "C" ifrt::proxy::GrpcServer *
1617-
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1618-
const char *c_address, const char *tpu_path, const char **error) {
1619-
std::string address = c_address;
1568+
// extern "C" ifrt::proxy::GrpcServer *
1569+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_gpu(
1570+
// int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1571+
// double memory_fraction, bool preallocate, const char *platform_name,
1572+
// const char **error) {
1573+
// return MyValueOrThrow(
1574+
// ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1575+
// std::string(),
1576+
// [node_id, num_nodes, allowed_devices, num_allowed_devices,
1577+
// memory_fraction, preallocate, platform_name,
1578+
// error]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1579+
// auto pjrt_client = std::shared_ptr<PjRtClient>(MakeGPUClient(
1580+
// node_id, num_nodes, allowed_devices, num_allowed_devices,
1581+
// memory_fraction, preallocate, platform_name, error));
1582+
// return std::shared_ptr<ifrt::Client>(
1583+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1584+
// }))
1585+
// .release();
1586+
// }
16201587

1621-
return MyValueOrThrow(
1622-
xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1623-
address,
1624-
[tpu_path, error]()
1625-
-> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1626-
auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1627-
MakeTPUClient(tpu_path, error));
1628-
return std::shared_ptr<xla::ifrt::Client>(
1629-
xla::ifrt::PjRtClient::Create(pjrt_client).release());
1630-
}))
1631-
.release();
1632-
}
1588+
// extern "C" ifrt::proxy::GrpcServer *
1589+
// ifrt_proxy_grpc_server_create_from_ifrt_client_factory_tpu(
1590+
// const char *c_address, const char *tpu_path, const char **error) {
1591+
// std::string address = c_address;
1592+
1593+
// return MyValueOrThrow(
1594+
// xla::ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
1595+
// address,
1596+
// [tpu_path, error]()
1597+
// -> absl::StatusOr<std::shared_ptr<xla::ifrt::Client>> {
1598+
// auto pjrt_client = std::shared_ptr<xla::PjRtClient>(
1599+
// MakeTPUClient(tpu_path, error));
1600+
// return std::shared_ptr<xla::ifrt::Client>(
1601+
// xla::ifrt::PjRtClient::Create(pjrt_client).release());
1602+
// }))
1603+
// .release();
1604+
// }
16331605

16341606
extern "C" void ifrt_proxy_grpc_server_dtor(ifrt::proxy::GrpcServer *server) {
16351607
delete server;
@@ -1662,24 +1634,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16621634
.release();
16631635
}
16641636

1665-
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1666-
int node_id, int num_nodes) {
1667-
return ifrt_pjrt_make_client(
1668-
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
1637+
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client,
1638+
int node_id, int num_nodes,
1639+
void *distributed_runtime_client,
1640+
const char **error,
1641+
std::string key_prefix) {
1642+
ifrt::PjRtClient::CreateOptions options;
1643+
options.pjrt_client = pjrt_client->obj();
1644+
1645+
if (num_nodes > 1) {
1646+
if (distributed_runtime_client == nullptr) {
1647+
*error =
1648+
"`distributed_runtime_client` must be non-null if `num_nodes` > 1";
1649+
return nullptr;
1650+
}
1651+
auto typed_distributed_runtime_client = static_cast<
1652+
HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
1653+
distributed_runtime_client);
1654+
options.kv_store = GetDistributedKeyValueStore(
1655+
typed_distributed_runtime_client->obj(), key_prefix);
1656+
}
1657+
1658+
options.process_id = node_id;
1659+
options.num_processes = num_nodes;
1660+
1661+
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1662+
}
1663+
1664+
extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(uint8_t asynchronous,
1665+
int node_id) {
1666+
PjRtClient *client = MakeCPUClient(asynchronous, node_id);
1667+
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1668+
}
1669+
1670+
extern "C" ifrt::Client *
1671+
ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
1672+
void *distributed_runtime_client,
1673+
const char **error) {
1674+
HeldPjRtClient *pjrt_client =
1675+
pjrt_make_cpu_client_shared(asynchronous, node_id);
1676+
if (pjrt_client == nullptr)
1677+
return nullptr;
1678+
return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
1679+
distributed_runtime_client, error, "cpu");
1680+
}
1681+
1682+
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
1683+
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1684+
double memory_fraction, bool preallocate, const char *platform_name,
1685+
const char **error, void *distributed_runtime_client) {
1686+
PjRtClient *client = MakeGPUClient(
1687+
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1688+
preallocate, platform_name, error, distributed_runtime_client);
1689+
return reactant::capture(std::shared_ptr<PjRtClient>(client));
16691690
}
16701691

16711692
extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
16721693
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
16731694
double memory_fraction, bool preallocate, const char *platform_name,
16741695
const char **error, void *distributed_runtime_client) {
1675-
return ifrt_pjrt_make_client(pjrt_make_gpu_client_shared(
1696+
HeldPjRtClient *pjrt_client = pjrt_make_gpu_client_shared(
16761697
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1677-
preallocate, platform_name, error, distributed_runtime_client));
1698+
preallocate, platform_name, error, distributed_runtime_client);
1699+
if (pjrt_client == nullptr)
1700+
return nullptr;
1701+
return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
1702+
distributed_runtime_client, error, "gpu");
1703+
}
1704+
1705+
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
1706+
const char **error) {
1707+
PjRtClient *client = MakeTPUClient(tpu_path, error);
1708+
return reactant::capture(std::shared_ptr<PjRtClient>(client));
16781709
}
16791710

1680-
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1681-
const char **error) {
1682-
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
1711+
extern "C" ifrt::Client *
1712+
ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id,
1713+
int num_nodes, void *distributed_runtime_client) {
1714+
HeldPjRtClient *pjrt_client = pjrt_make_tpu_client_shared(tpu_path, error);
1715+
if (pjrt_client == nullptr)
1716+
return nullptr;
1717+
return ifrt_pjrt_make_client(pjrt_client, node_id, num_nodes,
1718+
distributed_runtime_client, error, "tpu");
16831719
}
16841720

16851721
extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
@@ -1943,7 +1979,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
19431979

19441980
#pragma endregion
19451981

1946-
#pragma region PjRtDistributed
1982+
#pragma region xla::Distributed
19471983

19481984
extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
19491985
GetDistributedRuntimeClient(char *c_address, int32_t node_id,

‎src/xla/Client.jl

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,3 @@ function get_addressable_device end
1414
function platform_name end
1515

1616
default_device(client::AbstractClient) = first(addressable_devices(client))
17-
18-
# Clients for Different Backends
19-
function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true)
20-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
21-
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
22-
LLVMclopts("-nvptx-fma-level=1")
23-
return client
24-
end
25-
26-
function GPUClient(
27-
cfunc,
28-
node_id=0,
29-
num_nodes=1,
30-
platform="gpu";
31-
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
32-
distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing,
33-
)
34-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
35-
refstr = Ref{Cstring}()
36-
37-
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
38-
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
39-
distributed_runtime_client =
40-
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
41-
42-
client = ccall(
43-
f,
44-
Ptr{Cvoid},
45-
(Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}),
46-
node_id,
47-
num_nodes,
48-
allowed_devices,
49-
num_allowed_devices,
50-
XLA_REACTANT_GPU_MEM_FRACTION[],
51-
false,
52-
platform,
53-
refstr,
54-
distributed_runtime_client,
55-
)
56-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
57-
LLVMclopts("-nvptx-fma-level=1")
58-
return client
59-
end
60-
61-
function TPUClient(cfunc, tpu_path::String)
62-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
63-
refstr = Ref{Cstring}()
64-
client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr)
65-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
66-
LLVMclopts("-nvptx-fma-level=1")
67-
return client
68-
end

‎src/xla/IFRT/Client.jl

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
mutable struct Client <: XLA.AbstractClient
22
client::Ptr{Cvoid}
33

4-
function Client(client::Ptr{Cvoid})
5-
@assert client != C_NULL
4+
function Client(client::Ptr{Cvoid}; skip_check::Bool=false)
5+
skip_check || (@assert client != C_NULL)
66
return new(client)
77
end
88
end
@@ -66,22 +66,127 @@ function XLA.platform_name(client::Client)
6666
return XLA.unsafe_string_and_free(str)
6767
end
6868

69+
function XLA.devices(client::Client)
70+
ndevices = Int(XLA.num_devices(client))
71+
devices = Ref{NTuple{ndevices,Ptr{Cvoid}}}()
72+
GC.@preserve client devices begin
73+
@ccall MLIR.API.mlir_c.ifrt_client_devices(
74+
client.client::Ptr{Cvoid}, devices::Ptr{Ptr{Cvoid}}
75+
)::Cvoid
76+
end
77+
return [Device(device) for device in devices[]]
78+
end
79+
80+
function XLA.addressable_devices(client::Client)
81+
naddressable_devices = Int(XLA.num_addressable_devices(client))
82+
addressable_devices = Ref{NTuple{naddressable_devices,Ptr{Cvoid}}}()
83+
GC.@preserve client addressable_devices begin
84+
@ccall MLIR.API.mlir_c.ifrt_client_addressable_devices(
85+
client.client::Ptr{Cvoid}, addressable_devices::Ptr{Ptr{Cvoid}}
86+
)::Cvoid
87+
end
88+
return [Device(device) for device in addressable_devices[]]
89+
end
90+
6991
# Different Backends
7092
const cpu_client_count = Ref(0)
7193
const gpu_client_count = Ref(0)
7294
const tpu_client_count = Ref(0)
7395

74-
# XXX: We need other backends to support sharding
75-
for (backend, fname, counter) in (
76-
(:CPUClient, "ifrt_make_pjrt_cpu_client", :cpu_client_count),
77-
(:GPUClient, "ifrt_make_pjrt_gpu_client", :gpu_client_count),
78-
(:TPUClient, "ifrt_make_pjrt_tpu_client", :tpu_client_count),
96+
for (backend, counter) in (
97+
(:CPUClient, :cpu_client_count),
98+
(:GPUClient, :gpu_client_count),
99+
(:TPUClient, :tpu_client_count),
79100
)
101+
main_fn = Symbol(:MakeIFRTPJRT, backend)
80102
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
81103
if checkcount
82104
@assert $(counter)[] == 0
105+
end
106+
client, refstr = $(main_fn)(args...; kwargs...)
107+
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
108+
XLA.LLVMclopts("-nvptx-fma-level=1")
109+
if checkcount
110+
# Only increment the counter if we successfully created a client
83111
$(counter)[] += 1
84112
end
85-
return Client(XLA.$(backend)($(fname), args...; kwargs...))
113+
return Client(client)
86114
end
87115
end
116+
117+
function MakeIFRTPJRTCPUClient(;
118+
node_id::Integer=0,
119+
num_nodes::Integer=1,
120+
asynchronous::Bool=true,
121+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
122+
)
123+
refstr = Ref{Cstring}()
124+
distributed_runtime_client =
125+
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
126+
127+
GC.@preserve refstr distributed_runtime_client begin
128+
client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_cpu_client(
129+
asynchronous::UInt8,
130+
node_id::Cint,
131+
num_nodes::Cint,
132+
distributed_runtime_client::Ptr{Cvoid},
133+
refstr::Ptr{Cstring},
134+
)::Ptr{Cvoid}
135+
end
136+
137+
return client, refstr
138+
end
139+
140+
function MakeIFRTPJRTGPUClient(;
141+
node_id::Integer=0,
142+
num_nodes::Integer=1,
143+
platform::String="gpu",
144+
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
145+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
146+
)
147+
refstr = Ref{Cstring}()
148+
149+
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
150+
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
151+
distributed_runtime_client =
152+
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
153+
154+
GC.@preserve refstr allowed_devices distributed_runtime_client begin
155+
client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_gpu_client(
156+
node_id::Cint,
157+
num_nodes::Cint,
158+
allowed_devices::Ptr{Cvoid},
159+
num_allowed_devices::Cint,
160+
XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble,
161+
XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool,
162+
platform::Cstring,
163+
refstr::Ptr{Cstring},
164+
distributed_runtime_client::Ptr{Cvoid},
165+
)::Ptr{Cvoid}
166+
end
167+
168+
return client, refstr
169+
end
170+
171+
function MakeIFRTPJRTTPUClient(;
172+
tpu_path::String,
173+
node_id::Integer=0,
174+
num_nodes::Integer=1,
175+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
176+
)
177+
refstr = Ref{Cstring}()
178+
distributed_runtime_client =
179+
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
180+
181+
GC.@preserve refstr distributed_runtime_client begin
182+
client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_tpu_client(
183+
tpu_path::Cstring,
184+
refstr::Ptr{Cstring},
185+
node_id::Cint,
186+
num_nodes::Cint,
187+
distributed_runtime_client::Ptr{Cvoid},
188+
)::Ptr{Cvoid}
189+
end
190+
191+
return client, refstr
192+
end

‎src/xla/PJRT/Client.jl

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,20 +89,91 @@ const cpu_client_count = Ref(0)
8989
const gpu_client_count = Ref(0)
9090
const tpu_client_count = Ref(0)
9191

92-
for (backend, fname, counter) in (
93-
(:CPUClient, "MakeCPUClient", :cpu_client_count),
94-
(:GPUClient, "MakeGPUClient", :gpu_client_count),
95-
(:TPUClient, "MakeTPUClient", :tpu_client_count),
92+
for (backend, counter) in (
93+
(:CPUClient, :cpu_client_count),
94+
(:GPUClient, :gpu_client_count),
95+
(:TPUClient, :tpu_client_count),
9696
)
97+
main_fn = Symbol(:Make, backend)
9798
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
9899
if checkcount
99100
@assert $(counter)[] == 0
100101
end
101-
client = Client(XLA.$(backend)($(fname), args...; kwargs...))
102+
client = Client($(main_fn)(args...; kwargs...))
103+
XLA.LLVMclopts("-nvptx-fma-level=1")
102104
if checkcount
103105
# Only increment the counter if we successfully created a client
104106
$(counter)[] += 1
105107
end
106108
return client
107109
end
108110
end
111+
112+
function MakeCPUClient(;
113+
node_id::Integer=0,
114+
num_nodes::Integer=1,
115+
asynchronous::Bool=true,
116+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
117+
)
118+
@assert num_nodes == 1 "`PJRT.MakeCPUClient` does not support num_nodes > 1"
119+
@assert distributed_runtime_client === nothing "`PJRT.MakeCPUClient` does not support \
120+
distributed_runtime_client"
121+
122+
return @ccall MLIR.API.mlir_c.MakeCPUClient(
123+
asynchronous::UInt8, node_id::Cint
124+
)::Ptr{Cvoid}
125+
end
126+
127+
function MakeGPUClient(;
128+
node_id::Integer=0,
129+
num_nodes::Integer=1,
130+
platform::String="gpu",
131+
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
132+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
133+
)
134+
refstr = Ref{Cstring}()
135+
136+
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
137+
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
138+
distributed_runtime_client =
139+
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
140+
141+
GC.@preserve refstr allowed_devices distributed_runtime_client begin
142+
client = @ccall MLIR.API.mlir_c.MakeGPUClient(
143+
node_id::Cint,
144+
num_nodes::Cint,
145+
allowed_devices::Ptr{Cvoid},
146+
num_allowed_devices::Cint,
147+
XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble,
148+
XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool,
149+
platform::Cstring,
150+
refstr::Ptr{Cstring},
151+
distributed_runtime_client::Ptr{Cvoid},
152+
)::Ptr{Cvoid}
153+
end
154+
155+
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
156+
return client
157+
end
158+
159+
function MakeTPUClient(;
160+
tpu_path::String,
161+
node_id::Integer=0,
162+
num_nodes::Integer=1,
163+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
164+
)
165+
@assert node_id == 0 "`PJRT.MakeTPUClient` does not support node_id"
166+
@assert num_nodes == 1 "`PJRT.MakeTPUClient` does not support num_nodes > 1"
167+
@assert distributed_runtime_client === nothing "`PJRT.MakeTPUClient` does not support \
168+
distributed_runtime_client"
169+
170+
refstr = Ref{Cstring}()
171+
GC.@preserve refstr begin
172+
client = @ccall MLIR.API.mlir_c.MakeTPUClient(
173+
tpu_path::Cstring, refstr::Ptr{Cstring}
174+
)::Ptr{Cvoid}
175+
end
176+
177+
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
178+
return client
179+
end

‎src/xla/XLA.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,25 @@ include("PJRT/PJRT.jl")
3737

3838
include("IFRT/IFRT.jl")
3939

40-
@kwdef mutable struct BackendState
40+
@kwdef mutable struct PJRTBackendState
4141
initialized::Bool = false
4242
clients::Dict{String,PJRT.Client} = Dict{String,PJRT.Client}()
4343
default_client::PJRT.Client = PJRT.Client(C_NULL; skip_check=true)
4444
end
4545

46-
function Base.getproperty(bs::BackendState, sym::Symbol)
46+
function Base.getproperty(bs::PJRTBackendState, sym::Symbol)
4747
(sym === :initialized || bs.initialized) && return getfield(bs, sym)
48-
initialize_default_clients!(bs)
48+
initialize_default_pjrt_clients!(bs)
4949
return getfield(bs, sym)
5050
end
5151

52-
function Base.setproperty!(bs::BackendState, sym::Symbol, val)
52+
function Base.setproperty!(bs::PJRTBackendState, sym::Symbol, val)
5353
(sym === :initialized || bs.initialized) && return setfield!(bs, sym, val)
54-
initialize_default_clients!(bs)
54+
initialize_default_pjrt_clients!(bs)
5555
return setfield!(bs, sym, val)
5656
end
5757

58-
const global_backend_state = BackendState()
58+
const global_backend_state = PJRTBackendState()
5959
const global_state = State()
6060

6161
client(backend::String) = global_backend_state.clients[backend]
@@ -74,8 +74,13 @@ end
7474

7575
function update_global_state!(args...; kwargs...)
7676
update!(global_state, args...; kwargs...)
77-
# We need to update the clients based on the new state
78-
initialize_default_clients!(global_backend_state)
77+
# We conditionally initialize for now, since a lot of options that are set are not
78+
# necessarily supported by PJRT. This makes testing for IFRT quite hard.
79+
# Once we move to IFRT completely, we can remove this.
80+
if global_backend_state.initialized
81+
# We need to update the clients based on the new state
82+
initialize_default_pjrt_clients!(global_backend_state)
83+
end
7984
return nothing
8085
end
8186

@@ -112,16 +117,27 @@ function __init__()
112117
return nothing
113118
end
114119

115-
function initialize_default_clients!(state::BackendState)
120+
function initialize_default_pjrt_clients!(state::PJRTBackendState)
116121
was_initialized = state.initialized
117122
state.initialized = true
123+
distributed_runtime_client = if global_state.num_processes > 1
124+
@assert global_state.client !== nothing
125+
global_state.client
126+
else
127+
nothing
128+
end
129+
common_kwargs = (;
130+
node_id=global_state.process_id,
131+
num_nodes=global_state.num_processes,
132+
distributed_runtime_client,
133+
)
118134

119135
# CPU
120136
if was_initialized && haskey(state.clients, "cpu")
121137
XLA.free_client(state.clients["cpu"])
122138
XLA.PJRT.cpu_client_count[] -= 1
123139
end
124-
cpu = PJRT.CPUClient(global_state.process_id, global_state.num_processes)
140+
cpu = PJRT.CPUClient(; common_kwargs..., asynchronous=true)
125141
state.clients["cpu"] = cpu
126142
state.default_client = cpu
127143

@@ -144,8 +160,9 @@ function initialize_default_clients!(state::BackendState)
144160
XLA.free_client(state.clients["tpu"])
145161
XLA.PJRT.tpu_client_count[] -= 1
146162
end
147-
# XXX: process_id? num_processes?
148-
tpu = PJRT.TPUClient(dataset_dir * "/libtpu.so")
163+
tpu = PJRT.TPUClient(;
164+
tpu_path=dataset_dir * "/libtpu.so", common_kwargs...
165+
)
149166
state.clients["tpu"] = tpu
150167
state.default_client = tpu
151168
catch e
@@ -154,22 +171,12 @@ function initialize_default_clients!(state::BackendState)
154171
else
155172
if !Reactant.precompiling()
156173
try
157-
distributed_runtime_client = if global_state.num_processes > 1
158-
@assert global_state.client !== nothing
159-
global_state.client
160-
else
161-
nothing
162-
end
163-
164174
if was_initialized && haskey(state.clients, "gpu")
165175
XLA.free_client(state.clients["gpu"])
166176
XLA.PJRT.gpu_client_count[] -= 1
167177
end
168-
gpu = PJRT.GPUClient(
169-
global_state.process_id,
170-
global_state.num_processes;
171-
allowed_devices=global_state.local_device_ids,
172-
distributed_runtime_client,
178+
gpu = PJRT.GPUClient(;
179+
common_kwargs..., allowed_devices=global_state.local_device_ids
173180
)
174181
state.clients["gpu"] = gpu
175182
state.default_client = gpu

0 commit comments

Comments
 (0)
Please sign in to comment.