Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,9 @@ function vendored_buildIntrinsicLoweringPipeline(
return LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
end

function vendored_buildScalarOptimizerPipeline(fpm, @nospecialize(job), opt_level; instcombine::Bool=false)
function vendored_buildScalarOptimizerPipeline(
fpm, @nospecialize(job), opt_level; instcombine::Bool=false
)
if opt_level >= 2
LLVM.add!(fpm, LLVM.Interop.AllocOptPass())
LLVM.add!(fpm, LLVM.SROAPass())
Expand All @@ -597,9 +599,9 @@ function vendored_buildScalarOptimizerPipeline(fpm, @nospecialize(job), opt_leve
LLVM.add!(fpm, LLVM.DCEPass())
LLVM.add!(fpm, LLVM.IRCEPass())
if instcombine
LLVM.add!(fpm, LLVM.InstCombinePass())
LLVM.add!(fpm, LLVM.InstCombinePass())
else
LLVM.add!(fpm, LLVM.InstSimplifyPass())
LLVM.add!(fpm, LLVM.InstSimplifyPass())
end
LLVM.add!(fpm, LLVM.JumpThreadingPass())
end
Expand Down
8 changes: 4 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ function codegen_flatten!(

push!(flatten_code, :($usbuf = $flatcode.data))
for j in 1:length(mesh)
sbuf = Symbol(:sbuf_, i, "_", mesh.device_ids[j])
sbuf = Symbol(:sbuf_, i, "_", mesh.logical_device_ids[j])
push!(flatten_names, sbuf)
push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
end
Expand All @@ -1188,10 +1188,10 @@ function codegen_flatten!(
)
push!(flatten_code, :($usbuf = $flatcode))
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(carg), mesh.device_ids
condensed_op_sharding, size(carg), mesh.logical_device_ids
)
for j in 1:length(mesh)
device_id = mesh.device_ids[j]
device_id = mesh.logical_device_ids[j]
buf = Symbol(:buf_, i, :_, device_id)
slice = device_to_array_slices[j]
push!(
Expand Down Expand Up @@ -1548,7 +1548,7 @@ function compile_xla(f, args; client=nothing, kwargs...)

# compile MLIR module to XLA executable
global_device_ids = if mlir_fn_res.is_sharded
collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
vec(mlir_fn_res.sharding_mesh.device_ids)
else
Int64[]
end
Expand Down
25 changes: 15 additions & 10 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2235,7 +2235,7 @@ We return a NamedTuple with the following fields:
cache !== nothing && haskey(cache, m) && return cache[m]
result = mesh(
[k => Int64(v) for (k, v) in zip(m.axis_names, size(m))],
collect(Int64, m.logical_device_ids);
m.logical_device_ids;
mod,
sym_name,
location,
Expand All @@ -2246,23 +2246,25 @@ end

@noinline function mesh(
mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
device_ids::Vector{Int64};
device_ids::AbstractVector{Int64};
mod::MLIR.IR.Module=MLIR.IR.mmodule(),
sym_name::String="mesh",
location=mlir_stacktrace("mesh", @__FILE__, @__LINE__),
)
# See https://github.com/openxla/shardy/blob/f9d83e779a58b811b848c4edfaf68e88b636787d/shardy/dialect/sdy/ir/verifiers.cc#L647-L699 for the checks
ndevices = prod(last, mesh_axes)

@assert allunique(first, mesh_axes) "mesh_axes must be unique"
@assert ndevices == length(device_ids) "length(device_ids) should be same as \
prod(last, mesh_axes)"
@assert all(x -> x ≥ 0, device_ids) "device_ids must be non-negative"
@assert Base.sort(device_ids) == collect(Int64, 0:(ndevices - 1)) "sorted device_ids must be the same as iota(product(axes)), got $(Base.sort(device_ids))"
@assert all(Base.Fix2(≥, 0), device_ids) "device_ids must be non-negative"
@assert Base.sort(device_ids) == 0:(ndevices - 1) "sorted device_ids must be the same \
as iota(product(axes)), got \
$(Base.sort(device_ids))"

if Base.sort(device_ids) == device_ids
# error: if the ordered device ids are the same as iota(product(axes)), no need to specify them for simplicity
device_ids = Int64[]
end
# error: if the ordered device ids are the same as iota(product(axes)), no need to
# specify them for simplicity
issorted(device_ids) && (device_ids = Int64[])

ctx = MLIR.IR.context()
mesh_axis_attrs = [
Expand All @@ -2273,7 +2275,7 @@ end
Int64(length(mesh_axis_attrs)),
mesh_axis_attrs,
Int64(length(device_ids)),
device_ids,
collect(Int64, device_ids),
)

sym_name = Reactant.TracedUtils.__lookup_unique_name_in_module(mod, sym_name)
Expand Down Expand Up @@ -2306,10 +2308,13 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
`input` and `sharding`.
"""
@noinline function sharding_constraint(
input::Union{TracedRArray,TracedRNumber},
input::Union{AbstractArray,Number},
sharding::Reactant.Sharding.AbstractSharding;
location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__),
)
!(input isa TracedRNumber || input isa TracedRArray) &&
(input = constant(input; location))

cache = Reactant.Compiler.sdycache()
haskey(cache, sharding.mesh) || Ops.mesh(sharding.mesh; location)
(; sym_name, mesh_attr) = cache[sharding.mesh]
Expand Down
76 changes: 40 additions & 36 deletions src/Sharding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,50 +21,54 @@ julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y));
```
"""
struct Mesh{D,ND}
device_ids::NTuple{ND,Int}
sorted_device_ids::NTuple{ND,Int}
logical_device_ids::NTuple{ND,Int}
shape::Dims{D}
struct Mesh{D}
device_ids::Array{Int64,D}
logical_device_ids::UnitRange{Int}
axis_names::NTuple{D,Symbol}

function Mesh(devices::AbstractArray{<:XLA.AbstractDevice}, axis_names)
return Mesh(XLA.device_ordinal.(devices), axis_names)
end

function Mesh(
devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}}
) where {D}
return Mesh(XLA.device_ordinal.(devices), shape, axis_names)
return new{D}(device_ids, 0:(length(device_ids) - 1), Symbol.(axis_names))
end

# XXX (Deprecated): remove in v0.3
function Mesh(
device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}}
devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
) where {D}
return Mesh(Tuple(vec(device_ids)), size(device_ids), axis_names)
Base.depwarn(
"Mesh(devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names) is \
deprecated, use Mesh(reshape(collect(XLA.device_ordinal.(devices)), shape), \
axis_names) instead",
:Mesh,
)
global_ids = reshape(collect(XLA.device_ordinal.(devices)), shape)
return Mesh(global_ids, axis_names)
end

# XXX (Deprecated): remove in v0.3
function Mesh(
device_ids::NTuple{D1,Int},
shape::Dims{D},
axis_names::NTuple{D,Union{String,Symbol}},
device_ids::Dims{D1}, shape::Dims{D}, axis_names::NTuple{D,Union{String,Symbol}}
) where {D,D1}
@assert allunique(device_ids)
return new{D,D1}(
device_ids,
Tuple(sort([device_ids...])),
ntuple(Base.Fix2(-, 1), D1),
shape,
Symbol.(axis_names),
Base.depwarn(
"Mesh(device_ids::Dims{D1}, shape::Dims{D}, \
axis_names::NTuple{D,Union{String,Symbol}}) is deprecated, use \
Mesh(reshape(collect(Int64, device_ids), shape), axis_names) instead",
:Mesh,
)
return Mesh(reshape(collect(Int64, device_ids), shape), axis_names)
end
end

Base.length(::Mesh{D,ND}) where {D,ND} = ND
Base.length(m::Mesh) = length(m.device_ids)
Base.ndims(::Mesh{D}) where {D} = D

Base.size(mesh::Mesh) = mesh.shape
Base.size(mesh::Mesh, axis::Int) = mesh.shape[axis]
Base.size(mesh::Mesh) = size(mesh.device_ids)
Base.size(mesh::Mesh, axis::Int) = size(mesh.device_ids, axis)
function Base.size(mesh::Mesh, axis::Union{String,Symbol})
return size(mesh, findfirst(==(Symbol(axis)), mesh.axis_names))
end
Expand Down Expand Up @@ -146,18 +150,18 @@ julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated Ma

See also: [`Sharding.NoSharding`](@ref)
"""
struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding
mesh::Mesh{D1,D2}
struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding
mesh::Mesh{D1}
partition_spec::P
is_closed::NTuple{D3,Bool}
priority::NTuple{D3,Int}
is_closed::NTuple{D2,Bool}
priority::NTuple{D2,Int}

function NamedSharding(
mesh::Mesh{D1,D2},
mesh::Mesh{D1},
partition_spec::P;
is_closed::NTuple{D3,Bool}=ntuple(Returns(true), length(partition_spec)),
priority::NTuple{D3,Int}=ntuple(i -> -1, length(partition_spec)),
) where {D1,D2,P<:Tuple,D3}
is_closed::NTuple{D2,Bool}=ntuple(Returns(true), length(partition_spec)),
priority::NTuple{D2,Int}=ntuple(i -> -1, length(partition_spec)),
) where {D1,P<:Tuple,D2}
axis_names = Symbol[]
pspec = ()
for p in partition_spec
Expand All @@ -177,7 +181,7 @@ struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding
end
@assert allunique(axis_names) "Duplicate axis names!"

return new{D1,D2,typeof(pspec),D3}(mesh, pspec, is_closed, priority)
return new{D1,D2,typeof(pspec)}(mesh, pspec, is_closed, priority)
end
end

Expand Down Expand Up @@ -226,17 +230,17 @@ end
# This stores the sharding information in the form of XLA.HloSharding, and provides a
# central type for the final storage. It also potentially saves us the pain of having to
# regenerate the partition spec from the HloSharding.
struct HloSharding{M,D,D2} <: AbstractSharding
struct HloSharding{D1,D2} <: AbstractSharding
hlo_sharding::XLA.HloSharding
mesh::Mesh{M,D}
mesh::Mesh{D1}
is_closed::NTuple{D2,Bool}
priority::NTuple{D2,Int}

function HloSharding(
hlo_sharding::XLA.HloSharding, mesh::Mesh{M,D}, is_closed, priority
) where {M,D}
hlo_sharding::XLA.HloSharding, mesh::Mesh{D1}, is_closed, priority
) where {D1}
@assert length(is_closed) == length(priority)
return new{M,D,length(is_closed)}(hlo_sharding, mesh, is_closed, priority)
return new{D1,length(is_closed)}(hlo_sharding, mesh, is_closed, priority)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ function make_mlir_fn(
# TODO: support multiple meshes
if length(unique_meshes) > 1
error("Currently we support using a single mesh")
sorted_devices = [m.sorted_device_ids for m in unique_meshes]
sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes]
@assert allequal(sorted_devices) "All meshes must have the same device ids"
end
sharding_mesh = first(unique_meshes)
Expand Down
Loading