diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 57ecb45712..1741252356 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -585,7 +585,9 @@ function vendored_buildIntrinsicLoweringPipeline(
     return LLVM.add!(mpm, LLVM.AlwaysInlinerPass())
 end
 
-function vendored_buildScalarOptimizerPipeline(fpm, @nospecialize(job), opt_level; instcombine::Bool=false)
+function vendored_buildScalarOptimizerPipeline(
+    fpm, @nospecialize(job), opt_level; instcombine::Bool=false
+)
     if opt_level >= 2
         LLVM.add!(fpm, LLVM.Interop.AllocOptPass())
         LLVM.add!(fpm, LLVM.SROAPass())
@@ -597,9 +599,9 @@ function vendored_buildScalarOptimizerPipeline(fpm, @nospecialize(job), opt_leve
         LLVM.add!(fpm, LLVM.DCEPass())
         LLVM.add!(fpm, LLVM.IRCEPass())
         if instcombine
-            LLVM.add!(fpm, LLVM.InstCombinePass())
+            LLVM.add!(fpm, LLVM.InstCombinePass())
         else
-            LLVM.add!(fpm, LLVM.InstSimplifyPass())
+            LLVM.add!(fpm, LLVM.InstSimplifyPass())
         end
         LLVM.add!(fpm, LLVM.JumpThreadingPass())
     end
diff --git a/src/Compiler.jl b/src/Compiler.jl
index ba272a9ea7..b6fd0a7c1c 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1178,7 +1178,7 @@ function codegen_flatten!(
             push!(flatten_code, :($usbuf = $flatcode.data))
             for j in 1:length(mesh)
-                sbuf = Symbol(:sbuf_, i, "_", mesh.device_ids[j])
+                sbuf = Symbol(:sbuf_, i, "_", mesh.logical_device_ids[j])
                 push!(flatten_names, sbuf)
                 push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
             end
@@ -1188,10 +1188,10 @@
             )
             push!(flatten_code, :($usbuf = $flatcode))
             device_to_array_slices = XLA.sharding_to_concrete_array_indices(
-                condensed_op_sharding, size(carg), mesh.device_ids
+                condensed_op_sharding, size(carg), mesh.logical_device_ids
             )
             for j in 1:length(mesh)
-                device_id = mesh.device_ids[j]
+                device_id = mesh.logical_device_ids[j]
                 buf = Symbol(:buf_, i, :_, device_id)
                 slice = device_to_array_slices[j]
                 push!(
@@ -1548,7 +1548,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
     # compile MLIR module to XLA executable
     global_device_ids = if mlir_fn_res.is_sharded
-        collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
+        vec(mlir_fn_res.sharding_mesh.device_ids)
     else
         Int64[]
     end
diff --git a/src/Ops.jl b/src/Ops.jl
index 43fea7ae2a..aa00bbf257 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -2235,7 +2235,7 @@ We return a NamedTuple with the following fields:
     cache !== nothing && haskey(cache, m) && return cache[m]
     result = mesh(
         [k => Int64(v) for (k, v) in zip(m.axis_names, size(m))],
-        collect(Int64, m.logical_device_ids);
+        m.logical_device_ids;
         mod,
         sym_name,
         location,
@@ -2246,23 +2246,25 @@
 end
 
 @noinline function mesh(
     mesh_axes::Vector{<:Pair{<:Union{String,Symbol},Int64}},
-    device_ids::Vector{Int64};
+    device_ids::AbstractVector{Int64};
     mod::MLIR.IR.Module=MLIR.IR.mmodule(),
     sym_name::String="mesh",
     location=mlir_stacktrace("mesh", @__FILE__, @__LINE__),
 )
     # See https://github.com/openxla/shardy/blob/f9d83e779a58b811b848c4edfaf68e88b636787d/shardy/dialect/sdy/ir/verifiers.cc#L647-L699 for the checks
     ndevices = prod(last, mesh_axes)
+    @assert allunique(first, mesh_axes) "mesh_axes must be unique"
     @assert ndevices == length(device_ids) "length(device_ids) should be same as \
                                             prod(last, mesh_axes)"
-    @assert all(x -> x ≥ 0, device_ids) "device_ids must be non-negative"
-    @assert Base.sort(device_ids) == collect(Int64, 0:(ndevices - 1)) "sorted device_ids must be the same as iota(product(axes)), got $(Base.sort(device_ids))"
+    @assert all(Base.Fix2(≥, 0), device_ids) "device_ids must be non-negative"
+    @assert Base.sort(device_ids) == 0:(ndevices - 1) "sorted device_ids must be the same \
+                                                       as iota(product(axes)), got \
+                                                       $(Base.sort(device_ids))"
 
-    if Base.sort(device_ids) == device_ids
-        # error: if the ordered device ids are the same as iota(product(axes)), no need to specify them for simplicity
-        device_ids = Int64[]
-    end
+    # error: if the ordered device ids are the same as iota(product(axes)), no need to
+    # specify them for simplicity
+    issorted(device_ids) && (device_ids = Int64[])
 
     ctx = MLIR.IR.context()
     mesh_axis_attrs = [
@@ -2273,7 +2275,7 @@
         Int64(length(mesh_axis_attrs)),
         mesh_axis_attrs,
         Int64(length(device_ids)),
-        device_ids,
+        collect(Int64, device_ids),
     )
 
     sym_name = Reactant.TracedUtils.__lookup_unique_name_in_module(mod, sym_name)
@@ -2306,10 +2308,13 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
 `input` and `sharding`.
 """
 @noinline function sharding_constraint(
-    input::Union{TracedRArray,TracedRNumber},
+    input::Union{AbstractArray,Number},
     sharding::Reactant.Sharding.AbstractSharding;
     location=mlir_stacktrace("sharding_constraint", @__FILE__, @__LINE__),
 )
+    !(input isa TracedRNumber || input isa TracedRArray) &&
+        (input = constant(input; location))
+
     cache = Reactant.Compiler.sdycache()
     haskey(cache, sharding.mesh) || Ops.mesh(sharding.mesh; location)
     (; sym_name, mesh_attr) = cache[sharding.mesh]
diff --git a/src/Sharding.jl b/src/Sharding.jl
index 822d20bb4f..fc4e36a2c5 100644
--- a/src/Sharding.jl
+++ b/src/Sharding.jl
@@ -21,11 +21,9 @@ julia> mesh = Mesh(reshape(devices, 2, 2, 2), (:x, :y, :z));
 julia> mesh = Mesh(reshape(devices, 4, 2), (:x, :y));
 ```
 """
-struct Mesh{D,ND}
-    device_ids::NTuple{ND,Int}
-    sorted_device_ids::NTuple{ND,Int}
-    logical_device_ids::NTuple{ND,Int}
-    shape::Dims{D}
+struct Mesh{D}
+    device_ids::Array{Int64,D}
+    logical_device_ids::UnitRange{Int}
     axis_names::NTuple{D,Symbol}
 
     function Mesh(devices::AbstractArray{<:XLA.AbstractDevice}, axis_names)
@@ -33,38 +31,44 @@ struct Mesh{D,ND}
     end
 
     function Mesh(
-        devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
+        device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}}
     ) where {D}
-        return Mesh(XLA.device_ordinal.(devices), shape, axis_names)
+        return new{D}(device_ids, 0:(length(device_ids) - 1), Symbol.(axis_names))
     end
 
+    # XXX (Deprecated): remove in v0.3
     function Mesh(
-        device_ids::AbstractArray{<:Integer,D}, axis_names::NTuple{D,Union{String,Symbol}}
+        devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names
     ) where {D}
-        return Mesh(Tuple(vec(device_ids)), size(device_ids), axis_names)
+        Base.depwarn(
+            "Mesh(devices::NTuple{D,<:XLA.AbstractDevice}, shape::Dims{D}, axis_names) is \
+            deprecated, use Mesh(reshape(collect(XLA.device_ordinal.(devices)), shape), \
+            axis_names) instead",
+            :Mesh,
+        )
+        global_ids = reshape(collect(XLA.device_ordinal.(devices)), shape)
+        return Mesh(global_ids, axis_names)
     end
 
+    # XXX (Deprecated): remove in v0.3
     function Mesh(
-        device_ids::NTuple{D1,Int},
-        shape::Dims{D},
-        axis_names::NTuple{D,Union{String,Symbol}},
+        device_ids::Dims{D1}, shape::Dims{D}, axis_names::NTuple{D,Union{String,Symbol}}
     ) where {D,D1}
-        @assert allunique(device_ids)
-        return new{D,D1}(
-            device_ids,
-            Tuple(sort([device_ids...])),
-            ntuple(Base.Fix2(-, 1), D1),
-            shape,
-            Symbol.(axis_names),
+        Base.depwarn(
+            "Mesh(device_ids::Dims{D1}, shape::Dims{D}, \
+            axis_names::NTuple{D,Union{String,Symbol}}) is deprecated, use \
+            Mesh(reshape(collect(Int64, device_ids), shape), axis_names) instead",
+            :Mesh,
        )
+        return Mesh(reshape(collect(Int64, device_ids), shape), axis_names)
     end
 end
 
-Base.length(::Mesh{D,ND}) where {D,ND} = ND
+Base.length(m::Mesh) = length(m.device_ids)
 Base.ndims(::Mesh{D}) where {D} = D
-Base.size(mesh::Mesh) = mesh.shape
-Base.size(mesh::Mesh, axis::Int) = mesh.shape[axis]
+Base.size(mesh::Mesh) = size(mesh.device_ids)
+Base.size(mesh::Mesh, axis::Int) = size(mesh.device_ids, axis)
 function Base.size(mesh::Mesh, axis::Union{String,Symbol})
     return size(mesh, findfirst(==(Symbol(axis)), mesh.axis_names))
 end
@@ -146,18 +150,18 @@ julia> sharding = NamedSharding(mesh, (nothing, nothing)); # fully replicated Ma
 
 See also: [`Sharding.NoSharding`](@ref)
 """
-struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding
-    mesh::Mesh{D1,D2}
+struct NamedSharding{D1,D2,P<:Tuple} <: AbstractSharding
+    mesh::Mesh{D1}
     partition_spec::P
-    is_closed::NTuple{D3,Bool}
-    priority::NTuple{D3,Int}
+    is_closed::NTuple{D2,Bool}
+    priority::NTuple{D2,Int}
 
     function NamedSharding(
-        mesh::Mesh{D1,D2},
+        mesh::Mesh{D1},
         partition_spec::P;
-        is_closed::NTuple{D3,Bool}=ntuple(Returns(true), length(partition_spec)),
-        priority::NTuple{D3,Int}=ntuple(i -> -1, length(partition_spec)),
-    ) where {D1,D2,P<:Tuple,D3}
+        is_closed::NTuple{D2,Bool}=ntuple(Returns(true), length(partition_spec)),
+        priority::NTuple{D2,Int}=ntuple(i -> -1, length(partition_spec)),
+    ) where {D1,P<:Tuple,D2}
         axis_names = Symbol[]
         pspec = ()
         for p in partition_spec
@@ -177,7 +181,7 @@ struct NamedSharding{D1,D2,P<:Tuple,D3} <: AbstractSharding
         end
 
         @assert allunique(axis_names) "Duplicate axis names!"
-        return new{D1,D2,typeof(pspec),D3}(mesh, pspec, is_closed, priority)
+        return new{D1,D2,typeof(pspec)}(mesh, pspec, is_closed, priority)
     end
 end
 
@@ -226,17 +230,17 @@ end
 # This stores the sharding information in the form of XLA.HloSharding, and provides a
 # central type for the final storage. It also potentially saves us the pain of not having
 # to regenerate the partition spec from the HloSharding.
-struct HloSharding{M,D,D2} <: AbstractSharding
+struct HloSharding{D1,D2} <: AbstractSharding
     hlo_sharding::XLA.HloSharding
-    mesh::Mesh{M,D}
+    mesh::Mesh{D1}
     is_closed::NTuple{D2,Bool}
     priority::NTuple{D2,Int}
 
     function HloSharding(
-        hlo_sharding::XLA.HloSharding, mesh::Mesh{M,D}, is_closed, priority
-    ) where {M,D}
+        hlo_sharding::XLA.HloSharding, mesh::Mesh{D1}, is_closed, priority
+    ) where {D1}
         @assert length(is_closed) == length(priority)
-        return new{M,D,length(is_closed)}(hlo_sharding, mesh, is_closed, priority)
+        return new{D1,length(is_closed)}(hlo_sharding, mesh, is_closed, priority)
     end
 end
 
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index 0dd4507c4e..e6ee6d4870 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -350,7 +350,7 @@ function make_mlir_fn(
         # TODO: support multiple meshes
         if length(unique_meshes) > 1
            error("Currently we support using a single mesh")
-            sorted_devices = [m.sorted_device_ids for m in unique_meshes]
+            sorted_devices = [sort(vec(m.device_ids)) for m in unique_meshes]
             @assert allequal(sorted_devices) "All meshes must have the same device ids"
         end
         sharding_mesh = first(unique_meshes)
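
Usage sketch (reviewer note, not part of the patch): with this change `Sharding.Mesh` stores
the global device ids as an `Array{Int64,D}` and derives `logical_device_ids` as the range
`0:(length(mesh) - 1)`; the tuple-based constructors only survive as deprecation shims. A
minimal Julia sketch of the new API, assuming at least eight addressable XLA devices and the
existing helpers `Reactant.devices` / `Reactant.to_rarray`:

    using Reactant
    using Reactant: Sharding

    devices = Reactant.devices()[1:8]           # assumption: >= 8 devices are available
    mesh = Sharding.Mesh(reshape(devices, 4, 2), (:x, :y))

    length(mesh)              # 8
    size(mesh)                # (4, 2)
    mesh.logical_device_ids   # 0:7 (a UnitRange now, no longer an NTuple)

    # Shard rows across :x and replicate along :y; (nothing, nothing) would be fully replicated.
    sharding = Sharding.NamedSharding(mesh, (:x, nothing))
    x = Reactant.to_rarray(ones(Float32, 8, 16); sharding)

    # Deprecated path, still accepted but now emits a depwarn:
    # Sharding.Mesh((0, 1, 2, 3, 4, 5, 6, 7), (4, 2), (:x, :y))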