fix: multi-device execution and sharding [take III] #713

Merged
17 commits merged on Feb 11, 2025
4 changes: 3 additions & 1 deletion Project.toml
@@ -7,6 +7,7 @@ version = "0.2.26"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -61,6 +62,7 @@ ArrayInterface = "7.17.1"
CEnum = "0.5"
CUDA = "5.6"
Downloads = "1.6"
EnumX = "1"
Enzyme = "0.13.28"
EnzymeCore = "0.8.8"
Functors = "0.5"
@@ -79,7 +81,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.5"
Reactant_jll = "0.0.64"
Reactant_jll = "0.0.66"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
1 change: 1 addition & 0 deletions docs/src/api/api.md
@@ -25,6 +25,7 @@ within_compile

```@docs
@code_hlo
@code_mhlo
```

## Profile XLA
5 changes: 2 additions & 3 deletions ext/ReactantCUDAExt.jl
@@ -1067,9 +1067,8 @@ Base.@nospecializeinfer function Reactant.traced_type_inner(
T = eltype(A)
N = ndims(A)
if mode == Reactant.ArrayToConcrete && T <: Reactant.ReactantPrimitive
if sharding isa Reactant.Sharding.NoSharding ||
sharding isa Reactant.Sharding.FinalizedNoSharding
return Reactant.ConcreteRArray{T,N,1,Reactant.Sharding.FinalizedNoSharding}
if !Sharding.is_sharded(sharding)
return Reactant.ConcreteRArray{T,N,1,Reactant.Sharding.NoShardInfo}
else
error("TODO: implement sharding")
end
167 changes: 98 additions & 69 deletions src/Compiler.jl
@@ -36,16 +36,7 @@ end
@nospecialize(obj::AbstractArray{T}), field, val
) where {T}
ancestor_obj = ancestor(obj)
if isbitstype(T) || ancestor_obj isa RArray
if val isa XLA.AsyncBuffer
if Reactant.Sharding.is_sharded(ancestor_obj)
error("`val` can't be a buffer if `obj` is sharded")
else
return Base.setfield!(obj, field, (val,))
end
end
return Base.setfield!(obj, field, val)
end
(isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val)
return Base.setindex!(obj, val, field)
end

@@ -75,40 +66,48 @@ function create_result(
return Expr(:new, T, elems...)
end

function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
device_to_array_slices, partition_spec = path_to_shard_info[path]
delete!(path_to_shard_info, path)
sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices)
end
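For orientation, here is a hypothetical sketch of the data this helper consumes; the path key, slices, and partition spec below are made-up illustrative values, not taken from the PR:

```julia
# Illustrative only: each path key maps to (device_to_array_slices, partition_spec).
path = (:result, 1)
device_to_array_slices = [[1:2, 1:4], [3:4, 1:4]]  # per-device index ranges into the array
partition_spec = (:x, nothing)                     # mesh axis (or nothing) for each array dimension
path_to_shard_info = Dict{Tuple,Tuple}(path => (device_to_array_slices, partition_spec))

# __reconstruct_shardinfo pops the entry, builds a NamedSharding from the mesh and the
# partition spec, and wraps it together with the slices in a Sharding.ShardInfo.
```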

function create_result(
tocopy::ConcreteRNumber{T}, path, result_stores, path_to_shard_info, sharding_mesh
) where {T}
tocopy::ConcreteRNumber{T,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
) where {T,D,S}
if haskey(result_stores, path)
restore = result_stores[path]
delete!(result_stores, path)
return :(ConcreteRNumber{$T}($restore))
if path_to_shard_info !== nothing # restore sharding
sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
return :(ConcreteRNumber{$T,length($(restore)),$(typeof(sharding))}(
($(restore)...,), $sharding
))
else
return :(ConcreteRNumber{$T}($restore))
end
end

if path_to_shard_info !== nothing # restore sharding
sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
return :(ConcreteRNumber{$T,length($(tocopy.data)),$(typeof(sharding))}(
($(tocopy.data...,)), $sharding
))
end
# We will set the data for this later
return :(ConcreteRNumber{$T}($(tocopy.data)))
end

function __construct_sharding_for_carray(
::ConcreteRArray{T,N,D,S}, path, _, path_to_shard_info, sharding_mesh
) where {T,N,D,S}
device_to_array_slices, partition_spec = path_to_shard_info[path]
delete!(path_to_shard_info, path)
sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
return Reactant.Sharding.FinalizedNamedSharding{typeof(sharding),ndims(sharding_mesh)}(
sharding, device_to_array_slices
)
end

function create_result(
tocopy::ConcreteRArray{T,N,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
) where {T,N,D,S}
if haskey(result_stores, path)
restore = result_stores[path]
delete!(result_stores, path)
if path_to_shard_info !== nothing # restore sharding
sharding = __construct_sharding_for_carray(
tocopy, path, result_stores, path_to_shard_info, sharding_mesh
)
return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}(
sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
return :(ConcreteRArray{$T,$N,length($(restore)),$(typeof(sharding))}(
($(restore)...,), $(tocopy.shape), $sharding
))
else
@@ -117,10 +116,8 @@ function create_result(
end

if path_to_shard_info !== nothing # restore sharding
sharding = __construct_sharding_for_carray(
tocopy, path, result_stores, path_to_shard_info, sharding_mesh
)
return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}(
sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
return :(ConcreteRArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}(
($(tocopy.data)...,), $(tocopy.shape), $sharding
))
end
@@ -365,6 +362,7 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
"binary_op_transpose_simplify_or",
"binary_op_transpose_simplify_and",
"binary_op_transpose_simplify_xor",
"associative_binary_op_reordering<1>",
"transpose_unary_transpose_abs",
"transpose_unary_transpose_neg",
"transpose_unary_transpose_sqrt",
@@ -380,12 +378,15 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
"transpose_unary_transpose_sine",
"transpose_unary_transpose_tanh",
"transpose_broadcast_in_dim_to_broadcast_in_dim<16>",
"scatter_indices_are_unique",
"transpose_reduce_simplify",
"replace_neg_add_with_subtract",
"log_const_prop<1>",
"log_plus_one_const_prop<1>",
"binop_const_simplify",
"transpose_broadcast_in_dim_to_broadcast_in_dim",
"not_select_simplify",
"scatter_update_computation_const_prop",
"common_compare_expression_rewrite",
"compare_select_simplify",
"while_simplify<1>",
@@ -794,10 +795,12 @@ function compile_mlir!(
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
nresults = MLIR.IR.Value[]
linear_results2 = TracedType[]
results_mask = falses(length(results))
for (i, op) in enumerate(results)
if !MLIR.IR.is_block_arg(op)
push!(nresults, op)
push!(linear_results2, linear_results[i])
results_mask[i] = true
continue
end
push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op)))
@@ -812,11 +815,18 @@

out_tys2 = [MLIR.IR.type(a) for a in nresults]

res_attrs = MLIR.IR.attr(compiled_f, "res_attrs")
if res_attrs isa MLIR.IR.Attribute
res_attrs = [
res_attrs[i - 1] for (i, present) in enumerate(results_mask) if present
]
end

func3 = MLIR.Dialects.func.func_(;
sym_name="main",
function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
arg_attrs=MLIR.IR.attr(compiled_f, "arg_attrs"),
res_attrs=MLIR.IR.attr(compiled_f, "res_attrs"),
res_attrs,
no_inline=MLIR.IR.attr(compiled_f, "no_inline"),
body=MLIR.IR.Region(),
)
@@ -837,7 +847,6 @@
linear_args,
in_tys,
linear_results2,
mlir_fn_res.linear_result_shard_info,
mlir_fn_res.num_partitions,
mlir_fn_res.num_replicas,
mlir_fn_res.is_sharded,
@@ -862,6 +871,22 @@ macro code_hlo(args...)
$(first)($(compiled))))
end

"""
@code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)

Similar to `@code_hlo`, but prints the module after running the XLA compiler.
"""
macro code_mhlo(args...)
default_options = Dict{Symbol,Any}(
:optimize => true, :no_nan => false, :client => nothing
)
compile_expr, (; compiled) = compile_call_expr(
__module__, compile_xla, default_options, args...
)
return esc(:($(compile_expr);
$(first)($(compiled))))
end
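A minimal usage sketch for the new macro; the function, input, and use of `Reactant.to_rarray` below are illustrative assumptions rather than part of this PR:

```julia
using Reactant

# Hypothetical example: a small computation whose compiled module we want to inspect.
f(x) = sum(abs2, x)
x = Reactant.to_rarray(rand(Float32, 4, 4))

@code_hlo f(x)                   # StableHLO module before XLA compilation
@code_mhlo optimize = true f(x)  # module after the XLA compiler has run
```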

"""
@compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
"""
@@ -998,7 +1023,7 @@ function codegen_flatten!(

if is_sharded
carg = inv_seen_args[arg]
if carg isa ConcreteRArray && Reactant.Sharding.is_sharded(carg)
if Reactant.Sharding.is_sharded(carg)
for j in 1:length(mesh)
sbuf = Symbol(:sbuf_, i, "_", j)
push!(flatten_names, sbuf)
Expand All @@ -1007,17 +1032,11 @@ function codegen_flatten!(
else
# Warn here first and then replicate the input across all devices on the
# mesh
if carg isa ConcreteRArray
@warn "Input $carg is not sharded, replicating across all devices. It \
is recommended to replicate the input across all devices on the \
mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
end
@warn "Input $carg is not sharded, replicating across all devices. It \
is recommended to replicate the input across all devices on the \
mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
buf = Symbol(:buf_, i)
if carg isa ConcreteRArray
push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
else
push!(flatten_code, :($buf = XLA.synced_buffer($usbuf)))
end
push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
for j in 1:length(mesh)
device_id = mesh.device_ids[j]
device_ordinal = XLA.device_ordinal(client, device_id)
@@ -1030,9 +1049,7 @@
else
sbuf = Symbol(:sbuf_, i)
push!(flatten_names, sbuf)
if arg isa TracedRNumber
push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf)))
elseif arg isa TracedRArray
if arg isa TracedRArray || arg isa TracedRNumber
push!(flatten_code, :($sbuf = only(XLA.synced_buffer($usbuf))))
else
error("Unsupported type $(typeof(arg))")
@@ -1061,7 +1078,6 @@ function codegen_unflatten!(
concrete_result,
result_stores,
path_to_shard_info,
is_sharded::Bool,
linear_result_shard_info,
sharding_mesh,
)
@@ -1369,26 +1385,28 @@ function compile_xla(f, args; client=nothing, kwargs...)
mlir_fn_res.is_sharded,
)

mlir_fn_res.num_partitions > 1 && (device = nothing)

# Attach a name, and partitioning attributes to the module
__add_mhlo_attributes_and_name!(
mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
)

# compile MLIR module to XLA executable
is_sharded = mlir_fn_res.num_partitions > 1
if is_sharded
# mesh_shape = collect(Int64, size(mlir_fn_res.sharding_mesh))
mesh_ids = collect(Int64, vec(mlir_fn_res.sharding_mesh.device_ids))
mlir_fn_res.is_sharded && (device = nothing)
mesh_ids = if mlir_fn_res.is_sharded
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
else
# mesh_shape = Int64[]
mesh_ids = Int64[]
Int64[]
end
# exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids, mesh_shape)
exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids)
exec = XLA.Compile(
client,
device,
mod;
num_results=length(mlir_fn_res.linear_results),
mlir_fn_res.is_sharded,
mesh_ids,
)

return exec, mlir_fn_res, device, client
return mod, exec, mlir_fn_res, device, client
finally
MLIR.IR.deactivate!(ctx)
end
@@ -1398,7 +1416,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
end

function compile(f, args; sync=false, kwargs...)
exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
_, exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
(; linear_args, seen_args, linear_results, preserved_args, concrete_result) =
mlir_fn_res

@@ -1408,11 +1426,7 @@ function compile(f, args; sync=false, kwargs...)
end

result_stores = Dict{Tuple,Symbol}()
path_to_shard_info = if mlir_fn_res.is_sharded
Dict{Tuple,Tuple{Array{Vector{UnitRange{Int}}},Tuple}}()
else
nothing
end
path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing

# generate Julia `Thunk` code
flatten_arg_names, flatten_code = codegen_flatten!(
@@ -1431,9 +1445,25 @@
donated_args_mask,
length(linear_results),
mlir_fn_res.is_sharded,
mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh.device_ids) : Int64[],
if mlir_fn_res.is_sharded
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
else
Int64[]
end,
)

linear_result_shard_info = if mlir_fn_res.is_sharded
# Generate a tuple of DeviceToArraySlices and PartitionSpecs
output_shardings = XLA.get_output_shardings(exec)
XLA.compute_array_indices_and_partition_spec.(
output_shardings,
size.(mlir_fn_res.linear_results),
(mlir_fn_res.sharding_mesh,),
)
else
ntuple(Returns(nothing), length(linear_results))
end

unflatten_code = codegen_unflatten!(
linear_args,
preserved_args,
Expand All @@ -1442,8 +1472,7 @@ function compile(f, args; sync=false, kwargs...)
concrete_result,
result_stores,
path_to_shard_info,
mlir_fn_res.is_sharded,
mlir_fn_res.linear_result_shard_info,
linear_result_shard_info,
mlir_fn_res.sharding_mesh,
)
