Commit 90b0d1d

fix: multi-device execution and sharding [take III] (EnzymeAD#713)
* fix: use different API
* test: add simple 8 device test
* fix: correct result attributes
* fix: try running tests on single device
* chore: remove unused field
* fix: attributes
* feat: expose code_mhlo and initial OpSharding
* feat: add API to fetch XLA OpSharding
* feat: add conversion from JLOpSharding to OpSharding
* feat: load OpSharding of the outputs from XLA
* fix: don't force replicated sharding
* refactor: store internal shardinfo as tuples
* fix: make sure to reverse the dims
* feat: convert OpSharding to correct sharding info
* chore: bump jll
* test: more tests fixed
* chore: bump jll
1 parent a31b2e0 · commit 90b0d1d

22 files changed: +670 −442 lines

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -7,6 +7,7 @@ version = "0.2.26"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -61,6 +62,7 @@ ArrayInterface = "7.17.1"
 CEnum = "0.5"
 CUDA = "5.6"
 Downloads = "1.6"
+EnumX = "1"
 Enzyme = "0.13.28"
 EnzymeCore = "0.8.8"
 Functors = "0.5"
@@ -79,7 +81,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.5"
-Reactant_jll = "0.0.64"
+Reactant_jll = "0.0.66"
 Scratch = "1.2"
 Sockets = "1.10"
 SpecialFunctions = "2.4"

docs/src/api/api.md

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ within_compile
 
 ```@docs
 @code_hlo
+@code_mhlo
 ```
 
 ## Profile XLA

ext/ReactantCUDAExt.jl

Lines changed: 2 additions & 3 deletions
@@ -1067,9 +1067,8 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
     T = eltype(A)
     N = ndims(A)
     if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
-        if sharding isa Reactant.Sharding.NoSharding ||
-            sharding isa Reactant.Sharding.FinalizedNoSharding
-            return Reactant.ConcreteRArray{T,N,1,Reactant.Sharding.FinalizedNoSharding}
+        if !Sharding.is_sharded(sharding)
+            return Reactant.ConcreteRArray{T,N,1,Reactant.Sharding.NoShardInfo}
         else
             error("TODO: implement sharding")
         end
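Note: the extension now keys off `Sharding.is_sharded` instead of matching concrete sharding types, and unsharded arrays carry `Sharding.NoShardInfo` as their sharding type parameter. A minimal sketch of what that means for the concrete type, assuming the `ConcreteRArray{T,N,D,S}` layout shown above (CPU analogue; the exact assertion is illustrative):

    using Reactant

    # An unsharded conversion should take the NoShardInfo branch above:
    # a single-buffer array (D == 1) with S == Sharding.NoShardInfo.
    x = Reactant.to_rarray(rand(Float32, 4, 4))
    @assert x isa Reactant.ConcreteRArray{Float32,2,1,Reactant.Sharding.NoShardInfo}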

src/Compiler.jl

Lines changed: 98 additions & 69 deletions
@@ -36,16 +36,7 @@ end
     @nospecialize(obj::AbstractArray{T}), field, val
 ) where {T}
     ancestor_obj = ancestor(obj)
-    if isbitstype(T) || ancestor_obj isa RArray
-        if val isa XLA.AsyncBuffer
-            if Reactant.Sharding.is_sharded(ancestor_obj)
-                error("`val` can't be a buffer if `obj` is sharded")
-            else
-                return Base.setfield!(obj, field, (val,))
-            end
-        end
-        return Base.setfield!(obj, field, val)
-    end
+    (isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val)
     return Base.setindex!(obj, val, field)
 end
 
@@ -75,40 +66,48 @@ function create_result(
     return Expr(:new, T, elems...)
 end
 
+function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+    device_to_array_slices, partition_spec = path_to_shard_info[path]
+    delete!(path_to_shard_info, path)
+    sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
+    return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices)
+end
+
 function create_result(
-    tocopy::ConcreteRNumber{T}, path, result_stores, path_to_shard_info, sharding_mesh
-) where {T}
+    tocopy::ConcreteRNumber{T,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
+) where {T,D,S}
     if haskey(result_stores, path)
         restore = result_stores[path]
         delete!(result_stores, path)
-        return :(ConcreteRNumber{$T}($restore))
+        if path_to_shard_info !== nothing # restore sharding
+            sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+            return :(ConcreteRNumber{$T,length($(restore)),$(typeof(sharding))}(
+                ($(restore)...,), $sharding
+            ))
+        else
+            return :(ConcreteRNumber{$T}($restore))
+        end
+    end
+
+    if path_to_shard_info !== nothing # restore sharding
+        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+        return :(ConcreteRNumber{$T,length($(tocopy.data)),$(typeof(sharding))}(
+            ($(tocopy.data...,)), $sharding
+        ))
     end
     # We will set the data for this later
     return :(ConcreteRNumber{$T}($(tocopy.data)))
 end
 
-function __construct_sharding_for_carray(
-    ::ConcreteRArray{T,N,D,S}, path, _, path_to_shard_info, sharding_mesh
-) where {T,N,D,S}
-    device_to_array_slices, partition_spec = path_to_shard_info[path]
-    delete!(path_to_shard_info, path)
-    sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
-    return Reactant.Sharding.FinalizedNamedSharding{typeof(sharding),ndims(sharding_mesh)}(
-        sharding, device_to_array_slices
-    )
-end
-
 function create_result(
     tocopy::ConcreteRArray{T,N,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
 ) where {T,N,D,S}
     if haskey(result_stores, path)
         restore = result_stores[path]
         delete!(result_stores, path)
         if path_to_shard_info !== nothing # restore sharding
-            sharding = __construct_sharding_for_carray(
-                tocopy, path, result_stores, path_to_shard_info, sharding_mesh
-            )
-            return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}(
+            sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+            return :(ConcreteRArray{$T,$N,length($(restore)),$(typeof(sharding))}(
                 ($(restore)...,), $(tocopy.shape), $sharding
             ))
         else
@@ -117,10 +116,8 @@ function create_result(
     end
 
     if path_to_shard_info !== nothing # restore sharding
-        sharding = __construct_sharding_for_carray(
-            tocopy, path, result_stores, path_to_shard_info, sharding_mesh
-        )
-        return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}(
+        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+        return :(ConcreteRArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}(
             ($(tocopy.data)...,), $(tocopy.shape), $sharding
         ))
     end
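Note: both `create_result` methods now route through the new `__reconstruct_shardinfo` helper, which pops the `(device_to_array_slices, partition_spec)` tuple recorded for a result path and rebuilds it into a `ShardInfo` wrapping a `NamedSharding`. A pure-Julia toy of the lookup-and-consume bookkeeping (placeholder key and values, no Reactant types):

    path_to_shard_info = Dict{Tuple,Tuple}(
        (:result, 1) => ("device_to_array_slices", (:x, nothing)),
    )
    slices, partition_spec = path_to_shard_info[(:result, 1)]
    delete!(path_to_shard_info, (:result, 1))  # each path is consumed exactly once
    @assert isempty(path_to_shard_info)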
@@ -365,6 +362,7 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
         "binary_op_transpose_simplify_or",
         "binary_op_transpose_simplify_and",
         "binary_op_transpose_simplify_xor",
+        "associative_binary_op_reordering<1>",
         "transpose_unary_transpose_abs",
         "transpose_unary_transpose_neg",
         "transpose_unary_transpose_sqrt",
@@ -380,12 +378,15 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
         "transpose_unary_transpose_sine",
         "transpose_unary_transpose_tanh",
         "transpose_broadcast_in_dim_to_broadcast_in_dim<16>",
+        "scatter_indices_are_unique",
+        "transpose_reduce_simplify",
         "replace_neg_add_with_subtract",
         "log_const_prop<1>",
         "log_plus_one_const_prop<1>",
         "binop_const_simplify",
         "transpose_broadcast_in_dim_to_broadcast_in_dim",
         "not_select_simplify",
+        "scatter_update_computation_const_prop",
         "common_compare_expression_rewrite",
         "compare_select_simplify",
         "while_simplify<1>",
@@ -794,10 +795,12 @@ function compile_mlir!(
     results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
     nresults = MLIR.IR.Value[]
     linear_results2 = TracedType[]
+    results_mask = falses(length(results))
    for (i, op) in enumerate(results)
         if !MLIR.IR.is_block_arg(op)
             push!(nresults, op)
             push!(linear_results2, linear_results[i])
+            results_mask[i] = true
             continue
         end
         push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op)))
@@ -812,11 +815,18 @@ function compile_mlir!(
 
     out_tys2 = [MLIR.IR.type(a) for a in nresults]
 
+    res_attrs = MLIR.IR.attr(compiled_f, "res_attrs")
+    if res_attrs isa MLIR.IR.Attribute
+        res_attrs = [
+            res_attrs[i - 1] for (i, present) in enumerate(results_mask) if present
+        ]
+    end
+
     func3 = MLIR.Dialects.func.func_(;
         sym_name="main",
         function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
         arg_attrs=MLIR.IR.attr(compiled_f, "arg_attrs"),
-        res_attrs=MLIR.IR.attr(compiled_f, "res_attrs"),
+        res_attrs,
         no_inline=MLIR.IR.attr(compiled_f, "no_inline"),
         body=MLIR.IR.Region(),
     )
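Note: because preserved block arguments are dropped from the result list, the function's result attributes have to be filtered with the same `results_mask`; the `i - 1` reflects the 0-based indexing of the `MLIR.IR.Attribute` array through the C API. A plain-Julia toy of the mask filter (1-based here, no MLIR):

    results_mask = [true, false, true]          # which results survived
    attrs = ["attr_r1", "attr_r2", "attr_r3"]   # one attribute per original result
    kept = [attrs[i] for (i, present) in enumerate(results_mask) if present]
    @assert kept == ["attr_r1", "attr_r3"]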
@@ -837,7 +847,6 @@ function compile_mlir!(
         linear_args,
         in_tys,
         linear_results2,
-        mlir_fn_res.linear_result_shard_info,
         mlir_fn_res.num_partitions,
         mlir_fn_res.num_replicas,
         mlir_fn_res.is_sharded,
@@ -862,6 +871,22 @@ macro code_hlo(args...)
         $(first)($(compiled))))
 end
 
+"""
+    @code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)
+
+Similar to `@code_hlo`, but prints the module after running the XLA compiler.
+"""
+macro code_mhlo(args...)
+    default_options = Dict{Symbol,Any}(
+        :optimize => true, :no_nan => false, :client => nothing
+    )
+    compile_expr, (; compiled) = compile_call_expr(
+        __module__, compile_xla, default_options, args...
+    )
+    return esc(:($(compile_expr);
+        $(first)($(compiled))))
+end
+
 """
     @compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
 """
@@ -998,7 +1023,7 @@ function codegen_flatten!(
 
     if is_sharded
         carg = inv_seen_args[arg]
-        if carg isa ConcreteRArray && Reactant.Sharding.is_sharded(carg)
+        if Reactant.Sharding.is_sharded(carg)
             for j in 1:length(mesh)
                 sbuf = Symbol(:sbuf_, i, "_", j)
                 push!(flatten_names, sbuf)
@@ -1007,17 +1032,11 @@ function codegen_flatten!(
         else
             # Warn here first and then replicate the input across all devices on the
             # mesh
-            if carg isa ConcreteRArray
-                @warn "Input $carg is not sharded, replicating across all devices. It \
-                       is recommended to replicate the input across all devices on the \
-                       mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
-            end
+            @warn "Input $carg is not sharded, replicating across all devices. It \
+                   is recommended to replicate the input across all devices on the \
+                   mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
             buf = Symbol(:buf_, i)
-            if carg isa ConcreteRArray
-                push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
-            else
-                push!(flatten_code, :($buf = XLA.synced_buffer($usbuf)))
-            end
+            push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
             for j in 1:length(mesh)
                 device_id = mesh.device_ids[j]
                 device_ordinal = XLA.device_ordinal(client, device_id)
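Note: the replication warning now fires for any unsharded input to a sharded computation, not just `ConcreteRArray`s. A hedged sketch of avoiding it by sharding the input up front, assuming a `Sharding.Mesh` constructor over the available devices and a `sharding` keyword on `to_rarray` (both appear in this PR's tests; signatures unverified here):

    using Reactant

    # Eight devices laid out as a 4x2 mesh with named axes (assumed constructor).
    mesh = Reactant.Sharding.Mesh(reshape(Reactant.devices()[1:8], 4, 2), (:x, :y))

    # Partition rows along :x, leave columns unpartitioned (`nothing`).
    sharding = Reactant.Sharding.NamedSharding(mesh, (:x, nothing))
    x = Reactant.to_rarray(rand(Float32, 8, 8); sharding)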
@@ -1030,9 +1049,7 @@ function codegen_flatten!(
     else
         sbuf = Symbol(:sbuf_, i)
         push!(flatten_names, sbuf)
-        if arg isa TracedRNumber
-            push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf)))
-        elseif arg isa TracedRArray
+        if arg isa TracedRArray || arg isa TracedRNumber
             push!(flatten_code, :($sbuf = only(XLA.synced_buffer($usbuf))))
         else
             error("Unsupported type $(typeof(arg))")
@@ -1061,7 +1078,6 @@ function codegen_unflatten!(
     concrete_result,
     result_stores,
     path_to_shard_info,
-    is_sharded::Bool,
     linear_result_shard_info,
     sharding_mesh,
 )
@@ -1369,26 +1385,28 @@ function compile_xla(f, args; client=nothing, kwargs...)
             mlir_fn_res.is_sharded,
         )
 
-        mlir_fn_res.num_partitions > 1 && (device = nothing)
-
         # Attach a name, and partitioning attributes to the module
         __add_mhlo_attributes_and_name!(
             mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
         )
 
         # compile MLIR module to XLA executable
-        is_sharded = mlir_fn_res.num_partitions > 1
-        if is_sharded
-            # mesh_shape = collect(Int64, size(mlir_fn_res.sharding_mesh))
-            mesh_ids = collect(Int64, vec(mlir_fn_res.sharding_mesh.device_ids))
+        mlir_fn_res.is_sharded && (device = nothing)
+        mesh_ids = if mlir_fn_res.is_sharded
+            collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
         else
-            # mesh_shape = Int64[]
-            mesh_ids = Int64[]
+            Int64[]
         end
-        # exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids, mesh_shape)
-        exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids)
+        exec = XLA.Compile(
+            client,
+            device,
+            mod;
+            num_results=length(mlir_fn_res.linear_results),
+            mlir_fn_res.is_sharded,
+            mesh_ids,
+        )
 
-        return exec, mlir_fn_res, device, client
+        return mod, exec, mlir_fn_res, device, client
     finally
         MLIR.IR.deactivate!(ctx)
     end
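Note: `compile_xla` now returns the MLIR module as its first value, which is exactly what `@code_mhlo` surfaces via `first`. A hedged direct call (internal API; the `Reactant.Compiler` module path is assumed, and `fn`/`x` are placeholders):

    mod, exec, mlir_fn_res, device, client = Reactant.Compiler.compile_xla(fn, (x,))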
@@ -1398,7 +1416,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
 end
 
 function compile(f, args; sync=false, kwargs...)
-    exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
+    _, exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
     (; linear_args, seen_args, linear_results, preserved_args, concrete_result) =
         mlir_fn_res
 
@@ -1408,11 +1426,7 @@ function compile(f, args; sync=false, kwargs...)
     end
 
     result_stores = Dict{Tuple,Symbol}()
-    path_to_shard_info = if mlir_fn_res.is_sharded
-        Dict{Tuple,Tuple{Array{Vector{UnitRange{Int}}},Tuple}}()
-    else
-        nothing
-    end
+    path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing
 
     # generate Julia `Thunk` code
     flatten_arg_names, flatten_code = codegen_flatten!(
@@ -1431,9 +1445,25 @@ function compile(f, args; sync=false, kwargs...)
         donated_args_mask,
         length(linear_results),
         mlir_fn_res.is_sharded,
-        mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh.device_ids) : Int64[],
+        if mlir_fn_res.is_sharded
+            collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
+        else
+            Int64[]
+        end,
     )
 
+    linear_result_shard_info = if mlir_fn_res.is_sharded
+        # Generate a tuple of DeviceToArraySlices and PartitionSpecs
+        output_shardings = XLA.get_output_shardings(exec)
+        XLA.compute_array_indices_and_partition_spec.(
+            output_shardings,
+            size.(mlir_fn_res.linear_results),
+            (mlir_fn_res.sharding_mesh,),
+        )
+    else
+        ntuple(Returns(nothing), length(linear_results))
+    end
+
     unflatten_code = codegen_unflatten!(
         linear_args,
         preserved_args,
@@ -1442,8 +1472,7 @@ function compile(f, args; sync=false, kwargs...)
         concrete_result,
         result_stores,
         path_to_shard_info,
-        mlir_fn_res.is_sharded,
-        mlir_fn_res.linear_result_shard_info,
+        linear_result_shard_info,
         mlir_fn_res.sharding_mesh,
     )
 
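Note: output shard info is now read back from the compiled executable (`XLA.get_output_shardings`) instead of being threaded through tracing. In the broadcast above, wrapping the mesh in a 1-tuple makes it broadcast as a scalar, so every `(output_sharding, result_size)` pair is combined with the same mesh. A toy of that idiom:

    sizes = [(2, 4), (8,), ()]
    pair(sz, mesh) = (sz, mesh)
    @assert pair.(sizes, ("mesh",)) ==
        [((2, 4), "mesh"), ((8,), "mesh"), ((), "mesh")]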