Skip to content

Commit 8026eb9

Browse files
committed
feat: construct IFRT clients with distributed options
1 parent 408e120 commit 8026eb9

File tree

5 files changed

+275
-132
lines changed

5 files changed

+275
-132
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 81 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -342,14 +342,12 @@ extern "C" void ProfilerServerStop(tsl::profiler::ProfilerServer *server) {
342342
delete server;
343343
}
344344

345-
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
346-
int num_nodes) {
345+
// Creates a CPU PjRt client for the process identified by `node_id`.
// `asynchronous` is a C-friendly flag: any non-zero value turns asynchronous
// mode on. The caller receives ownership of the returned raw pointer.
extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id) {
  CpuClientOptions cpu_options;
  cpu_options.process_id = node_id;
  cpu_options.asynchronous = (asynchronous != 0);

  // Unwrap the creation result and hand the raw pointer to the caller.
  return MyValueOrThrow(GetTfrtCpuClient(cpu_options)).release();
}
@@ -1271,28 +1269,6 @@ extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC,
12711269
return wrap(entryFn);
12721270
}
12731271

1274-
extern "C" HeldPjRtClient *
1275-
pjrt_make_cpu_client_shared(uint8_t asynchronous, int node_id, int num_nodes) {
1276-
PjRtClient *client = MakeCPUClient(asynchronous, node_id, num_nodes);
1277-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1278-
}
1279-
1280-
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
1281-
int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
1282-
double memory_fraction, bool preallocate, const char *platform_name,
1283-
const char **error, void *distributed_runtime_client) {
1284-
PjRtClient *client = MakeGPUClient(
1285-
node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
1286-
preallocate, platform_name, error, distributed_runtime_client);
1287-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1288-
}
1289-
1290-
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
1291-
const char **error) {
1292-
PjRtClient *client = MakeTPUClient(tpu_path, error);
1293-
return reactant::capture(std::shared_ptr<PjRtClient>(client));
1294-
}
1295-
12961272
extern "C" void pjrt_client_dtor(HeldPjRtClient *client) { delete client; }
12971273

12981274
extern "C" int pjrt_client_num_devices(HeldPjRtClient *client) {
@@ -1369,11 +1345,6 @@ extern "C" HeldPjRtClient *pjrt_buffer_get_client(HeldPjRtBuffer *buffer) {
13691345
std::shared_ptr<PjRtClient>(buffer->ptr()->client()));
13701346
}
13711347

1372-
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client) {
1373-
xla::ifrt::PjRtClient::CreateOptions options = {pjrt_client->obj()};
1374-
return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
1375-
}
1376-
13771348
extern "C" void ifrt_client_dtor(ifrt::Client *client) { delete client; }
13781349

13791350
// generic version, but IFRT-PjRt backend only supports SingleDeviceSharding
@@ -1577,16 +1548,16 @@ FreeHloModule(HeldValue<std::shared_ptr<xla::HloModule>> *hlo_module) {
15771548

15781549
extern "C" ifrt::proxy::GrpcServer *
15791550
ifrt_proxy_grpc_server_create_from_ifrt_client_factory_cpu(
1580-
const char *c_address, uint8_t asynchronous, int node_id, int num_nodes) {
1551+
const char *c_address, uint8_t asynchronous, int node_id) {
15811552
std::string address = c_address;
15821553

15831554
return MyValueOrThrow(
15841555
ifrt::proxy::GrpcServer::CreateFromIfrtClientFactory(
15851556
address,
1586-
[asynchronous, node_id, num_nodes]()
1587-
-> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
1557+
[asynchronous,
1558+
node_id]() -> absl::StatusOr<std::shared_ptr<ifrt::Client>> {
15881559
auto pjrt_client = std::shared_ptr<PjRtClient>(
1589-
MakeCPUClient(asynchronous, node_id, num_nodes));
1560+
MakeCPUClient(asynchronous, node_id));
15901561
return std::shared_ptr<ifrt::Client>(
15911562
xla::ifrt::PjRtClient::Create(pjrt_client).release());
15921563
}))
@@ -1662,24 +1633,88 @@ ifrt_proxy_create_client(const char *c_proxy_server_address,
16621633
.release();
16631634
}
16641635

1665-
extern "C" ifrt::Client *ifrt_make_pjrt_cpu_client(uint8_t asynchronous,
1666-
int node_id, int num_nodes) {
1667-
return ifrt_pjrt_make_client(
1668-
pjrt_make_cpu_client_shared(asynchronous, node_id, num_nodes));
1636+
// Builds an IFRT client on top of an existing PjRt client.
//
// Parameters:
//   pjrt_client                 Handle holding the shared PjRt client to wrap.
//                               The handle itself is not freed here.
//   node_id, num_nodes          Stored as the IFRT `process_id` and
//                               `num_processes` options.
//   distributed_runtime_client  Opaque pointer to a
//                               HeldValue<std::shared_ptr<
//                               xla::DistributedRuntimeClient>>. Required when
//                               `num_nodes` > 1; not read otherwise.
//   error                       Optional out-parameter. Receives a static
//                               message when validation fails.
//   key_prefix                  Prefix handed to the distributed key-value
//                               store.
//
// Returns an owning pointer to the new client, or nullptr when an argument
// fails validation. Failures of the underlying `Create` call go through
// MyValueOrThrow instead of the `error` out-parameter.
//
// NOTE(review): `std::string` by value is not a C-compatible parameter type,
// so despite `extern "C"` this should only be called from C++ — confirm no
// foreign caller binds to this symbol directly.
extern "C" ifrt::Client *ifrt_pjrt_make_client(HeldPjRtClient *pjrt_client,
                                               int node_id, int num_nodes,
                                               void *distributed_runtime_client,
                                               const char **error,
                                               std::string key_prefix) {
  // Report a validation failure without dereferencing a null `error`.
  const auto fail = [error](const char *message) -> ifrt::Client * {
    if (error != nullptr)
      *error = message;
    return nullptr;
  };

  // A handle can be non-null while wrapping a null client (e.g. when the
  // backend factory failed); passing that on to `Create` is never valid.
  if (pjrt_client == nullptr || pjrt_client->obj() == nullptr)
    return fail("`pjrt_client` must hold a non-null PjRt client");

  ifrt::PjRtClient::CreateOptions options;
  options.pjrt_client = pjrt_client->obj();
  options.process_id = node_id;
  options.num_processes = num_nodes;

  if (num_nodes > 1) {
    if (distributed_runtime_client == nullptr)
      return fail(
          "`distributed_runtime_client` must be non-null if `num_nodes` > 1");

    auto *typed_distributed_runtime_client = static_cast<
        HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *>(
        distributed_runtime_client);
    options.kv_store = GetDistributedKeyValueStore(
        typed_distributed_runtime_client->obj(), key_prefix);
  }

  return MyValueOrThrow(xla::ifrt::PjRtClient::Create(options)).release();
}
1662+
1663+
// Creates a CPU PjRt client and returns it wrapped in a shared-ownership
// handle. Release the handle with `pjrt_client_dtor`.
extern "C" HeldPjRtClient *pjrt_make_cpu_client_shared(uint8_t asynchronous,
                                                       int node_id) {
  return reactant::capture(
      std::shared_ptr<PjRtClient>(MakeCPUClient(asynchronous, node_id)));
}
1668+
1669+
// Creates a CPU PjRt client and wraps it in an IFRT client.
//
// `distributed_runtime_client` must be non-null when `num_nodes` > 1. On a
// validation failure nullptr is returned and `*error` is set by
// `ifrt_pjrt_make_client`. The caller owns the returned client.
extern "C" ifrt::Client *
ifrt_make_pjrt_cpu_client(uint8_t asynchronous, int node_id, int num_nodes,
                          void *distributed_runtime_client,
                          const char **error) {
  // The handle is only needed while the IFRT client is being built. Owning it
  // here releases it on every path (success, validation failure, exception)
  // instead of leaking it. `ifrt_pjrt_make_client` copies the shared pointer
  // out of the handle, so the PjRt client itself stays alive.
  std::unique_ptr<HeldPjRtClient> pjrt_client(
      pjrt_make_cpu_client_shared(asynchronous, node_id));
  if (pjrt_client == nullptr)
    return nullptr;

  return ifrt_pjrt_make_client(pjrt_client.get(), node_id, num_nodes,
                               distributed_runtime_client, error, "cpu");
}
1680+
1681+
// Creates a GPU PjRt client and returns it wrapped in a shared-ownership
// handle. All arguments are forwarded unchanged to `MakeGPUClient`.
//
// Returns nullptr when `MakeGPUClient` returns nullptr, so callers can detect
// the failure with a plain null check; `error` is left exactly as
// `MakeGPUClient` wrote it. Release a non-null handle with `pjrt_client_dtor`.
extern "C" HeldPjRtClient *pjrt_make_gpu_client_shared(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error, void *distributed_runtime_client) {
  PjRtClient *client = MakeGPUClient(
      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
      preallocate, platform_name, error, distributed_runtime_client);

  // Wrapping a null client would yield a non-null handle around nothing and
  // hide the failure from every caller that checks the handle for null.
  if (client == nullptr)
    return nullptr;

  return reactant::capture(std::shared_ptr<PjRtClient>(client));
}
16701690

16711691
// Creates a GPU PjRt client and wraps it in an IFRT client.
//
// The device and memory arguments are forwarded to
// `pjrt_make_gpu_client_shared`; `node_id`, `num_nodes` and
// `distributed_runtime_client` are also used to configure the IFRT client.
// Returns nullptr when PjRt client creation fails (`error` is whatever the
// PjRt factory reported) or when IFRT argument validation fails. The caller
// owns the returned client.
extern "C" ifrt::Client *ifrt_make_pjrt_gpu_client(
    int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices,
    double memory_fraction, bool preallocate, const char *platform_name,
    const char **error, void *distributed_runtime_client) {
  // Own the temporary handle so it is released on every path instead of
  // leaking; the IFRT client copies the shared pointer it needs.
  std::unique_ptr<HeldPjRtClient> pjrt_client(pjrt_make_gpu_client_shared(
      node_id, num_nodes, allowed_devices, num_allowed_devices, memory_fraction,
      preallocate, platform_name, error, distributed_runtime_client));

  // The factory may hand back a non-null handle around a null client when
  // GPU client creation failed, so check the wrapped pointer as well.
  if (pjrt_client == nullptr || pjrt_client->obj() == nullptr)
    return nullptr;

  return ifrt_pjrt_make_client(pjrt_client.get(), node_id, num_nodes,
                               distributed_runtime_client, error, "gpu");
}
16791703

1680-
extern "C" ifrt::Client *ifrt_make_pjrt_tpu_client(const char *tpu_path,
1681-
const char **error) {
1682-
return ifrt_pjrt_make_client(pjrt_make_tpu_client_shared(tpu_path, error));
1704+
// Creates a TPU PjRt client from `tpu_path` and returns it wrapped in a
// shared-ownership handle.
//
// Returns nullptr when `MakeTPUClient` returns nullptr, so callers can detect
// the failure with a plain null check; `error` is left exactly as
// `MakeTPUClient` wrote it. Release a non-null handle with `pjrt_client_dtor`.
extern "C" HeldPjRtClient *pjrt_make_tpu_client_shared(const char *tpu_path,
                                                       const char **error) {
  PjRtClient *client = MakeTPUClient(tpu_path, error);

  // Wrapping a null client would yield a non-null handle around nothing and
  // hide the failure from every caller that checks the handle for null.
  if (client == nullptr)
    return nullptr;

  return reactant::capture(std::shared_ptr<PjRtClient>(client));
}
1709+
1710+
// Creates a TPU PjRt client from `tpu_path` and wraps it in an IFRT client.
//
// `node_id`, `num_nodes` and `distributed_runtime_client` configure the IFRT
// client; `distributed_runtime_client` must be non-null when `num_nodes` > 1.
// Returns nullptr when PjRt client creation fails (`error` is whatever the
// PjRt factory reported) or when IFRT argument validation fails. The caller
// owns the returned client.
extern "C" ifrt::Client *
ifrt_make_pjrt_tpu_client(const char *tpu_path, const char **error, int node_id,
                          int num_nodes, void *distributed_runtime_client) {
  // Own the temporary handle so it is released on every path instead of
  // leaking; the IFRT client copies the shared pointer it needs.
  std::unique_ptr<HeldPjRtClient> pjrt_client(
      pjrt_make_tpu_client_shared(tpu_path, error));

  // The factory may hand back a non-null handle around a null client when
  // TPU client creation failed, so check the wrapped pointer as well.
  if (pjrt_client == nullptr || pjrt_client->obj() == nullptr)
    return nullptr;

  return ifrt_pjrt_make_client(pjrt_client.get(), node_id, num_nodes,
                               distributed_runtime_client, error, "tpu");
}
16841719

16851720
extern "C" void ifrt_FreeClient(ifrt::Client *client) { delete client; }
@@ -1943,7 +1978,7 @@ extern "C" void ifrt_array_copy_to_host_buffer(HeldIfrtArray *array,
19431978

19441979
#pragma endregion
19451980

1946-
#pragma region PjRtDistributed
1981+
#pragma region xla::Distributed
19471982

19481983
extern "C" HeldValue<std::shared_ptr<xla::DistributedRuntimeClient>> *
19491984
GetDistributedRuntimeClient(char *c_address, int32_t node_id,

src/xla/Client.jl

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,3 @@ function get_addressable_device end
1414
function platform_name end
1515

1616
default_device(client::AbstractClient) = first(addressable_devices(client))
17-
18-
# Clients for Different Backends
19-
function CPUClient(cfunc, node_id=0, num_nodes=1; asynchronous=true)
20-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
21-
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
22-
LLVMclopts("-nvptx-fma-level=1")
23-
return client
24-
end
25-
26-
function GPUClient(
27-
cfunc,
28-
node_id=0,
29-
num_nodes=1,
30-
platform="gpu";
31-
allowed_devices::Union{Nothing,Vector{Int}}=nothing,
32-
distributed_runtime_client::Union{Nothing,DistributedRuntimeClient}=nothing,
33-
)
34-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
35-
refstr = Ref{Cstring}()
36-
37-
num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
38-
allowed_devices = allowed_devices === nothing ? C_NULL : allowed_devices
39-
distributed_runtime_client =
40-
distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client
41-
42-
client = ccall(
43-
f,
44-
Ptr{Cvoid},
45-
(Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}, Ptr{Cvoid}),
46-
node_id,
47-
num_nodes,
48-
allowed_devices,
49-
num_allowed_devices,
50-
XLA_REACTANT_GPU_MEM_FRACTION[],
51-
false,
52-
platform,
53-
refstr,
54-
distributed_runtime_client,
55-
)
56-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
57-
LLVMclopts("-nvptx-fma-level=1")
58-
return client
59-
end
60-
61-
function TPUClient(cfunc, tpu_path::String)
62-
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, string(cfunc))
63-
refstr = Ref{Cstring}()
64-
client = ccall(f, Ptr{Cvoid}, (Cstring, Ptr{Cstring}), tpu_path, refstr)
65-
client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
66-
LLVMclopts("-nvptx-fma-level=1")
67-
return client
68-
end

src/xla/IFRT/Client.jl

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,102 @@ const cpu_client_count = Ref(0)
7171
const gpu_client_count = Ref(0)
7272
const tpu_client_count = Ref(0)
7373

74-
# XXX: We need other backends to support sharding
75-
for (backend, fname, counter) in (
76-
(:CPUClient, "ifrt_make_pjrt_cpu_client", :cpu_client_count),
77-
(:GPUClient, "ifrt_make_pjrt_gpu_client", :gpu_client_count),
78-
(:TPUClient, "ifrt_make_pjrt_tpu_client", :tpu_client_count),
74+
# Generate the `CPUClient`, `GPUClient` and `TPUClient` constructors. Each one
# forwards its arguments to the matching `MakeIFRTPJRT<Backend>` factory and,
# when `checkcount` is true, asserts that no client of that backend has been
# created yet and bumps the per-backend counter after a successful creation.
for (backend, counter) in (
    (:CPUClient, :cpu_client_count),
    (:GPUClient, :gpu_client_count),
    (:TPUClient, :tpu_client_count),
)
    main_fn = Symbol(:MakeIFRTPJRT, backend)
    @eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
        if checkcount
            @assert $(counter)[] == 0
        end
        # The factory may already return a `Client`; wrap only raw handles so
        # the result is never pushed through `Client(::Client)`.
        made = $(main_fn)(args...; kwargs...)
        client = made isa Client ? made : Client(made)
        XLA.LLVMclopts("-nvptx-fma-level=1")
        if checkcount
            # Only increment the counter if we successfully created a client
            $(counter)[] += 1
        end
        return client
    end
end
93+
94+
# Create an IFRT client backed by a CPU PjRt client.
#
# Keyword arguments:
# - `node_id`, `num_nodes`: forwarded to the native constructor as `Cint`.
# - `asynchronous`: forwarded as a `UInt8` flag.
# - `distributed_runtime_client`: optional runtime client; its raw handle is
#   passed to the native side, or `C_NULL` when it is `nothing`.
#
# Throws an `AssertionError` carrying the native error message when the native
# call returns a null pointer.
function MakeIFRTPJRTCPUClient(;
    node_id::Integer=0,
    num_nodes::Integer=1,
    asynchronous::Bool=true,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    refstr = Ref{Cstring}()
    # Keep the owning object under its own name: preserving only the extracted
    # raw pointer would not root the object that owns it during the call.
    runtime_ptr =
        distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client

    GC.@preserve refstr distributed_runtime_client begin
        client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_cpu_client(
            asynchronous::UInt8,
            node_id::Cint,
            num_nodes::Cint,
            runtime_ptr::Ptr{Cvoid},
            refstr::Ptr{Cstring},
        )::Ptr{Cvoid}
    end

    client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
    return Client(client)
end
117+
118+
# Create an IFRT client backed by a GPU PjRt client.
#
# Keyword arguments:
# - `node_id`, `num_nodes`: forwarded to the native constructor as `Cint`.
# - `platform`: platform name passed as a C string.
# - `allowed_devices`: optional list of device ids; `nothing` passes a null
#   pointer and a count of zero.
# - `distributed_runtime_client`: optional runtime client; its raw handle is
#   passed to the native side, or `C_NULL` when it is `nothing`.
#
# The memory fraction and preallocation flag are read from
# `XLA.XLA_REACTANT_GPU_MEM_FRACTION` and `XLA.XLA_REACTANT_GPU_PREALLOCATE`.
#
# Throws an `AssertionError` carrying the native error message when the native
# call returns a null pointer.
function MakeIFRTPJRTGPUClient(;
    node_id::Integer=0,
    num_nodes::Integer=1,
    platform::String="gpu",
    allowed_devices::Union{Nothing,Vector{Int}}=nothing,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    refstr = Ref{Cstring}()

    num_allowed_devices = allowed_devices === nothing ? 0 : length(allowed_devices)
    # The native signature takes `int *`, so the ids must be stored as `Cint`.
    # Passing a `Vector{Int}` directly would expose 64-bit elements to code
    # that reads 32-bit ones.
    device_ids =
        allowed_devices === nothing ? C_NULL : convert(Vector{Cint}, allowed_devices)
    # Keep the owning object under its own name: preserving only the extracted
    # raw pointer would not root the object that owns it during the call.
    runtime_ptr =
        distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client

    GC.@preserve refstr device_ids distributed_runtime_client begin
        client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_gpu_client(
            node_id::Cint,
            num_nodes::Cint,
            device_ids::Ptr{Cvoid},
            num_allowed_devices::Cint,
            XLA.XLA_REACTANT_GPU_MEM_FRACTION[]::Cdouble,
            XLA.XLA_REACTANT_GPU_PREALLOCATE[]::Bool,
            platform::Cstring,
            refstr::Ptr{Cstring},
            runtime_ptr::Ptr{Cvoid},
        )::Ptr{Cvoid}
    end

    client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
    return Client(client)
end
149+
150+
# Create an IFRT client backed by a TPU PjRt client.
#
# Keyword arguments:
# - `tpu_path` (required): path passed to the native constructor as a C string.
# - `node_id`, `num_nodes`: forwarded to the native constructor as `Cint`.
# - `distributed_runtime_client`: optional runtime client; its raw handle is
#   passed to the native side, or `C_NULL` when it is `nothing`.
#
# Throws an `AssertionError` carrying the native error message when the native
# call returns a null pointer.
function MakeIFRTPJRTTPUClient(;
    tpu_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    refstr = Ref{Cstring}()
    # Keep the owning object under its own name: preserving only the extracted
    # raw pointer would not root the object that owns it during the call.
    runtime_ptr =
        distributed_runtime_client === nothing ? C_NULL : distributed_runtime_client.client

    GC.@preserve refstr distributed_runtime_client begin
        client = @ccall MLIR.API.mlir_c.ifrt_make_pjrt_tpu_client(
            tpu_path::Cstring,
            refstr::Ptr{Cstring},
            node_id::Cint,
            num_nodes::Cint,
            runtime_ptr::Ptr{Cvoid},
        )::Ptr{Cvoid}
    end

    client == C_NULL && throw(AssertionError(unsafe_string(refstr[])))
    return Client(client)
end

0 commit comments

Comments
 (0)