Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ jobs:
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
- uses: julia-actions/julia-processcoverage@v1
if: steps.run_tests.outcome == 'success'
- uses: codecov/codecov-action@v5
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/downgrade.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
REACTANT_TEST_GROUP: ${{ matrix.test_group }}
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
- uses: julia-actions/julia-processcoverage@v1
if: steps.run_tests.outcome == 'success'
- uses: codecov/codecov-action@v5
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.5"
Reactant_jll = "0.0.66"
Reactant_jll = "0.0.69"
Scratch = "1.2"
Sockets = "1.10"
SpecialFunctions = "2.4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ within_compile
```@docs
@code_hlo
@code_mhlo
@code_xla
```

## Profile XLA
Expand Down
108 changes: 82 additions & 26 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -891,8 +891,48 @@ macro code_mhlo(args...)
compile_expr, (; compiled) = compile_call_expr(
__module__, compile_xla, default_options, args...
)
return esc(:($(compile_expr);
$(first)($(compiled))))
#! format: off
return esc(
:(
$(compile_expr);
$(first)($(compiled))
)
)
#! format: on
end

"""
    @code_xla [optimize = ...] [no_nan = <true/false>] f(args...)

Similar to `@code_hlo`, but prints the post-compilation XLA HLO module(s)
obtained from the compiled executable. If the computation is sharded, one HLO
module is printed per partition.
"""
macro code_xla(args...)
    default_options = Dict{Symbol,Any}(
        :optimize => true, :no_nan => false, :client => nothing
    )
    compile_expr, (; compiled) = compile_call_expr(
        __module__, compile_xla, default_options, args...
    )
    #! format: off
    # The whole generated expression is `esc`ed into the caller's scope, so all
    # temporaries must be confined to a `let` block. Without it, `exec`,
    # `hlo_modules`, and `hlo_module` would leak into (and potentially clobber
    # variables in) the caller's scope.
    return esc(
        :(
            $(compile_expr);
            let exec = $(compiled)[2]  # compile_xla returns (mod, exec, ...); take the executable
                hlo_modules = $(XLA.get_hlo_modules)(exec)
                if length(hlo_modules) == 1
                    println(only(hlo_modules))
                else
                    println("HLO modules:")
                    for (i, hlo_module) in enumerate(hlo_modules)
                        println("Partition $i:")
                        println(hlo_module)
                        println()
                    end
                end
            end
        )
    )
    #! format: on
end

"""
Expand Down Expand Up @@ -994,7 +1034,13 @@ The name is due to its similarity to the `flatten` function in `jax.tree_util.re
The _linearized arguments_ are the arguments that have been flattened into a single list.
"""
function codegen_flatten!(
linear_args, seen_args, result_stores, is_sharded::Bool, mesh, client
linear_args,
seen_args,
result_stores,
is_sharded::Bool,
mesh,
linear_parameter_shardings,
client,
)
flatten_names = Symbol[]
flatten_code = Expr[]
Expand Down Expand Up @@ -1027,34 +1073,50 @@ function codegen_flatten!(
for p in path[3:end]
flatcode = :(traced_getfield($flatcode, $(Meta.quot(p))))
end
push!(flatten_code, :($usbuf = $flatcode.data))

if is_sharded
carg = inv_seen_args[arg]
condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding(
linear_parameter_shardings[i]
)
if Reactant.Sharding.is_sharded(carg)
# Currently disabling the error since we roundtrip from MHLO to generate
# the shardings
# # Check if the sharding provided is same as the one we have
# arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding(
# Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg))
# )
# @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."

push!(flatten_code, :($usbuf = $flatcode.data))
for j in 1:length(mesh)
sbuf = Symbol(:sbuf_, i, "_", j)
push!(flatten_names, sbuf)
push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
end
else
# Warn here first and then replicate the input across all devices on the
# mesh
@warn "Input $carg is not sharded, replicating across all devices. It \
is recommended to replicate the input across all devices on the \
mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
buf = Symbol(:buf_, i)
push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
push!(flatten_code, :($usbuf = $flatcode))
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(carg), mesh
)
device_ids = vec(mesh)
for j in 1:length(mesh)
device_id = mesh.device_ids[j]
buf = Symbol(:buf_, i, :_, j)
device_id = device_ids[j]
slice = device_to_array_slices[j]
push!(
flatten_code,
:($buf = XLA.synced_buffer(only($usbuf[$(slice)...].data))),
)
device_ordinal = XLA.device_ordinal(client, device_id)
sbuf = Symbol(:sbuf_, i, "_", j)
sbuf = Symbol(:sbuf_, i, :_, j)
device = XLA.ClientGetAddressableDevice(client, device_ordinal)
push!(flatten_names, sbuf)
push!(flatten_code, :($sbuf = XLA.CopyBufferToDevice($buf, $device)))
end
end
else
push!(flatten_code, :($usbuf = $flatcode.data))
sbuf = Symbol(:sbuf_, i)
push!(flatten_names, sbuf)
if arg isa TracedRArray || arg isa TracedRNumber
Expand Down Expand Up @@ -1399,19 +1461,17 @@ function compile_xla(f, args; client=nothing, kwargs...)
)

# compile MLIR module to XLA executable
device_ids = mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[]
mlir_fn_res.is_sharded && (device = nothing)
mesh_ids = if mlir_fn_res.is_sharded
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
else
Int64[]
end

exec = XLA.Compile(
client,
device,
mod;
num_results=length(mlir_fn_res.linear_results),
num_outputs=length(mlir_fn_res.linear_results),
num_parameters=length(mlir_fn_res.linear_args),
mlir_fn_res.is_sharded,
mesh_ids,
device_ids,
)

return mod, exec, mlir_fn_res, device, client
Expand Down Expand Up @@ -1443,6 +1503,7 @@ function compile(f, args; sync=false, kwargs...)
result_stores,
mlir_fn_res.is_sharded,
mlir_fn_res.sharding_mesh,
XLA.get_parameter_shardings(exec),
client,
)

Expand All @@ -1453,15 +1514,10 @@ function compile(f, args; sync=false, kwargs...)
donated_args_mask,
length(linear_results),
mlir_fn_res.is_sharded,
if mlir_fn_res.is_sharded
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
else
Int64[]
end,
mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh) : Int64[],
)

linear_result_shard_info = if mlir_fn_res.is_sharded
# Generate a tuple of DeviceToArraySlices and PartitionSpecs
output_shardings = XLA.get_output_shardings(exec)
XLA.compute_array_indices_and_partition_spec.(
output_shardings,
Expand Down
18 changes: 7 additions & 11 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,20 @@ Base.isempty(x::WrappedConcreteRArray) = isempty(ancestor(x))

function Base.convert(::Type{<:Array}, X::ConcreteRArray{T,N}) where {T,N}
data = Array{T,N}(undef, size(X)...)
XLA.await(X)

if Sharding.is_sharded(X)
# TODO: We can we much more efficient here and only move data from the minimal
# slices that populates the array.
completed = Set{eltype(X.sharding.device_to_array_slices)}()
for idx in 1:length(X.data)
buffer = X.data[idx].buffer
# We can't use a pointer to a subarray since BufferToHost expects the data to
# be contiguous.
slice = X.sharding.device_to_array_slices[idx]
data_slice = data[slice...]
GC.@preserve data_slice buffer begin
XLA.BufferToHost(buffer, pointer(data_slice))
if slice ∉ completed
push!(completed, slice)
else
continue
end
data[slice...] = data_slice
data[slice...] = convert(Array{T}, X.data[idx])
end
else
buf = only(X.data).buffer
buf = XLA.synced_buffer(only(X.data))
GC.@preserve data buf begin
XLA.BufferToHost(buf, pointer(data))
end
Expand Down
7 changes: 1 addition & 6 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1960,8 +1960,6 @@ end
tr
else
if typeof(tr) != typeof(fr)
@show tr.mlir_data
@show fr.mlir_data
@assert typeof(tr) == typeof(fr) "$(typeof(tr)) vs $(typeof(fr))"
end
tr
Expand Down Expand Up @@ -2141,10 +2139,7 @@ end
location=mlir_stacktrace("mesh", @__FILE__, @__LINE__),
)
return mesh(
mod,
[k => Int64(v) for (k, v) in zip(m.axis_names, size(m))],
collect(Int64, m.device_ids);
location,
mod, [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))], vec(m); location
)
end

Expand Down
18 changes: 16 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,23 @@ function Enzyme.make_zero(
end

using .Compiler:
@compile, @code_hlo, @code_mhlo, @jit, traced_getfield, create_result, compile
@compile,
@code_hlo,
@code_mhlo,
@jit,
@code_xla,
traced_getfield,
create_result,
compile
export ConcreteRArray,
ConcreteRNumber, @compile, @code_hlo, @code_mhlo, @jit, @trace, within_compile
ConcreteRNumber,
@compile,
@code_hlo,
@code_mhlo,
@code_xla,
@jit,
@trace,
within_compile

const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

Expand Down
Loading
Loading