diff --git a/src/Compiler.jl b/src/Compiler.jl
index 1d01ef41c9..0b9a1a075a 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1187,7 +1187,7 @@ function codegen_flatten!(
                 Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
             )
             push!(flatten_code, :($usbuf = $flatcode))
-            device_to_array_slices = XLA.sharding_to_concrete_array_indices(
+            device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices(
                 condensed_op_sharding, size(carg), mesh.logical_device_ids
             )
             for j in 1:length(mesh)
@@ -1363,7 +1363,8 @@ function codegen_unflatten!(
                 res = :(args[$(path[2])])
                 path = path[3:end]
             end
-            for p in path
+
+            for p in path[1:(end - 1)]
                 res = :(traced_getfield($res, $(Meta.quot(p))))
             end
@@ -1371,8 +1372,59 @@ function codegen_unflatten!(
             for p in argpath[3:end]
                 argres = :(traced_getfield($argres, $(Meta.quot(p))))
             end
+            argres = :($argres.data)
+
+            if length(path) > 0
+                final_val = gensym("final_val")
+                clocal = gensym("clocal")
+                if !has_cache_dict
+                    has_cache_dict = true
+                    push!(
+                        unflatten_code,
+                        :(
+                            $cache_dict = $(IdDict{
+                                Union{TracedRArray,TracedRNumber},
+                                Union{ConcretePJRTArray,ConcretePJRTNumber},
+                            }())
+                        ),
+                    )
+                end
+                res = quote
+                    $final_val = traced_getfield($res, $(Meta.quot(path[end])))
+                    if $final_val isa TracedRArray
+                        $clocal = if haskey($cache_dict, $final_val)
+                            $cache_dict[$final_val]
+                        else
+                            $cache_dict[$final_val] = ConcretePJRTArray{
+                                $(Reactant.unwrapped_eltype)($final_val),
+                                ndims($final_val),
+                            }(
+                                $argres, size($final_val)
+                            )
+                            $cache_dict[$final_val]
+                        end
+                        traced_setfield!($res, $(Meta.quot(path[end])), $clocal)
+                    elseif $final_val isa TracedRNumber
+                        $clocal = if haskey($cache_dict, $final_val)
+                            $cache_dict[$final_val]
+                        else
+                            $cache_dict[$final_val] = ConcretePJRTNumber{
+                                $(Reactant.unwrapped_eltype)($final_val)
+                            }(
+                                $argres
+                            )
+                            $cache_dict[$final_val]
+                        end
+                        traced_setfield!($res, $(Meta.quot(path[end])), $clocal)
+                    else
+                        traced_setfield!($res, :data, $argres)
+                    end
+                end
+            else
+                res = :(traced_setfield!($res, :data, $argres))
+            end
 
-            res = :($res.data = $argres.data)
+            # res = :(traced_setfield!($res, :data, $argres.data))
             push!(unflatten_code, res)
         end
     end
diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl
index 6e4d03a159..6eba439cf6 100644
--- a/src/ConcreteRArray.jl
+++ b/src/ConcreteRArray.jl
@@ -171,6 +171,7 @@ function Base.showarg(io::IO, a::ConcretePJRTArray{T,N}, toplevel) where {T,N}
     toplevel || print(io, "::")
     print(io, "ConcretePJRTArray{$T,$N}")
     Sharding.is_sharded(a) && print(io, " with sharding $(typeof(a.sharding.sharding))")
+    any(!iszero, a.padding) && print(io, " with padding ", a.padding)
     return nothing
 end
diff --git a/src/Sharding.jl b/src/Sharding.jl
index 205c0cbec2..aefcfb68fd 100644
--- a/src/Sharding.jl
+++ b/src/Sharding.jl
@@ -107,7 +107,7 @@ Base.getproperty(::NoSharding, x::Symbol) = NoSharding()
 function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number})
     device === nothing && (device = XLA.default_device(client))
     buffer = XLA.PJRT.AsyncBuffer(client, x, device)
-    return (buffer,), ShardInfo(NoSharding(), nothing)
+    return (buffer,), ShardInfo(NoSharding(), nothing), ntuple(Returns(0), ndims(x))
 end
 
 """
@@ -375,10 +375,17 @@ function (sharding::HloSharding)(
     )
 
     condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding)
 
-    device_to_array_slices = XLA.sharding_to_concrete_array_indices(
+    device_to_array_slices, padding = XLA.sharding_to_concrete_array_indices(
         condensed_op_sharding, size(x), sharding.mesh.logical_device_ids
     )
+    if any(!iszero, padding)
+        tmp = similar(x, size(x) .+ padding)
+        fill!(tmp, zero(eltype(x)))
+        tmp[[1:size(x, i) for i in 1:ndims(x)]...] .= x
+        x = tmp
+    end
+
     data = ntuple(length(sharding.mesh)) do i
         XLA.PJRT.AsyncBuffer(
             client,
@@ -387,7 +394,7 @@ function (sharding::HloSharding)(
         )
     end
 
-    return data, ShardInfo(sharding, device_to_array_slices)
+    return data, ShardInfo(sharding, device_to_array_slices), padding
 end
 
 function get_shardy_tensor_sharding_attribute(
diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl
index d9eaf6126e..e3db0cb3fe 100644
--- a/src/TracedUtils.jl
+++ b/src/TracedUtils.jl
@@ -193,10 +193,14 @@ function make_mlir_fn(
         )
     end
 
+    any_padded = false
     linear_args = Reactant.TracedType[]
     for (k, v) in seen_args
         v isa Reactant.TracedType || continue
         push!(linear_args, v)
+        if k isa Reactant.ConcretePJRTArray && any(!iszero, k.padding)
+            any_padded = true
+        end
     end
 
     in_tys = if toscalar
@@ -252,6 +256,14 @@ function make_mlir_fn(
         set_mlir_data!(arg, row_maj_arg)
     end
 
+    if any_padded # We need to un-pad the arguments
+        for i in 1:N
+            traced_args[i] = Reactant.make_tracer(
+                seen_args, args[i], (), Reactant.UnpadTracedArray;
+            )
+        end
+    end
+
     if isempty(kwargs)
         Reactant.call_with_reactant(f, traced_args...)
     else
diff --git a/src/Tracing.jl b/src/Tracing.jl
index 0c3cec5088..d7ed5c87e7 100644
--- a/src/Tracing.jl
+++ b/src/Tracing.jl
@@ -6,6 +6,7 @@
     TracedSetPath = 5
     TracedToTypes = 6
     NoStopTracedTrack = 7
+    UnpadTracedArray = 8
 end
 
 struct VisitedObject
@@ -893,6 +894,10 @@ function make_tracer(
                 "Mismatched sharding. Input has sharding $(prev.sharding), but requested sharding is $(typeof(sharding))",
             )
         end
+        if mode == UnpadTracedArray
+            tarray = seen[prev]::TracedRArray{T,N}
+            return view(tarray, [1:(size(prev, i) - prev.padding[i]) for i in 1:N]...)
+        end
         if mode != ConcreteToTraced
             throw("Cannot trace concrete")
         end
@@ -923,6 +928,9 @@ function make_tracer(
             return ConcretePJRTNumber(prev; sharding)
         end
     end
+    if mode == UnpadTracedArray
+        return seen[prev]::TracedRNumber{T}
+    end
     if mode != ConcreteToTraced
         throw("Cannot trace existing trace type")
     end
diff --git a/src/Types.jl b/src/Types.jl
index 951678fda5..80ae538f15 100644
--- a/src/Types.jl
+++ b/src/Types.jl
@@ -103,6 +103,28 @@ mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcret
     data::NTuple{D,XLA.PJRT.AsyncBuffer}
     shape::NTuple{N,Int}
     sharding::S
+    padding::NTuple{N,Int} # this is an internal field and is used for sharding mostly
+end
+
+# Note that these constructors are provided as a means to test the padding logic easily.
+# For all other purposes, padding is directly handled by the sharding logic.
+function ConcretePJRTArray(
+    x::ConcretePJRTArray{T,N,D,S}, padding::NTuple{N,Int}
+) where {T,N,D,S}
+    return ConcretePJRTArray{T,N,D,S}(x, padding)
+end
+
+function ConcretePJRTArray{T,N,D,S}(
+    x::ConcretePJRTArray{T,N,D,S}, padding::NTuple{N,Int}
+) where {T,N,D,S}
+    all(iszero, padding) && return x
+
+    data = convert(Array, x)
+    full_data = zeros(T, size(x) .+ padding)
+    full_data[[1:size(x, i) for i in 1:N]...] .= data
+
+    carray = ConcretePJRTArray(full_data; client=XLA.client(x), sharding=x.sharding)
+    return ConcretePJRTArray{T,N,D,S}(carray.data, size(carray), x.sharding, padding)
 end
 
 @leaf ConcretePJRTArray
@@ -115,6 +137,12 @@ Base.@deprecate ConcretePJRTArray(data::Number; kwargs...) ConcretePJRTNumber(
     data; kwargs...
 )
 
+function ConcretePJRTArray{T,N,D,S}(
+    data::NTuple{D,XLA.PJRT.AsyncBuffer}, shape::NTuple{N,Int}, sharding::S
+) where {T,N,D,S}
+    return ConcretePJRTArray{T,N,D,S}(data, shape, sharding, ntuple(Returns(0), N))
+end
+
 function ConcretePJRTArray{T,N}(
     data::Tuple{XLA.PJRT.AsyncBuffer}, shape::NTuple{N,Int}
 ) where {T,N}
@@ -144,13 +172,15 @@ function ConcretePJRTArray(
                 specified, `idx` must match `device`"
             end
         end
-        sdata, sharding = sharding(client, device, data)
-        return ConcretePJRTArray{T,N,1,typeof(sharding)}(sdata, size(data), sharding)
+        sdata, sharding, padding = sharding(client, device, data)
+        return ConcretePJRTArray{T,N,1,typeof(sharding)}(
+            sdata, size(data) .+ padding, sharding, padding
+        )
     end
 
     @assert device === nothing && idx === nothing "If `sharding` is not `NoSharding`, `device` and `idx` cannot be specified!"
-    sharded_data, sharding = sharding(client, nothing, data)
+    sharded_data, sharding, padding = sharding(client, nothing, data)
     return ConcretePJRTArray{T,N,length(sharded_data),typeof(sharding)}(
-        sharded_data, size(data), sharding
+        sharded_data, size(data) .+ padding, sharding, padding
     )
 end
@@ -169,8 +199,7 @@ const AnyConcretePJRTArray{T,N,D,S} = Union{
     ConcretePJRTArray{T,N,D,S},WrappedConcretePJRTArray{T,N,D,S}
 }
 
-const AnyConcreteRArray = AnyConcretePJRTArray
-
+## Helpful functions
 ConcretePJRTArray(x::AnyConcretePJRTArray) = ConcretePJRTArray{eltype(x),ndims(x)}(x)
 ConcretePJRTArray{T}(x::AnyConcretePJRTArray) where {T} = ConcretePJRTArray{T,ndims(x)}(x)
 ConcretePJRTArray{T,N}(x::ConcretePJRTArray{T,N}) where {T,N} = x
@@ -193,3 +222,4 @@ end
 ## Aliases to prevent breaking changes
 const ConcreteRArray = ConcretePJRTArray
 const ConcreteRNumber = ConcretePJRTNumber
+const AnyConcreteRArray = AnyConcretePJRTArray
diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl
index dd0e596db8..139f14ca6f 100644
--- a/src/xla/IFRT/Array.jl
+++ b/src/xla/IFRT/Array.jl
@@ -48,7 +48,7 @@ function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where
     end
 
     all_devices = XLA.devices(sharding)
-    array_slices = XLA.sharding_to_concrete_array_indices(
+    array_slices, _ = XLA.sharding_to_concrete_array_indices(
        convert(XLA.HloSharding, sharding),
        size(array),
        collect(Int64, 0:(length(all_devices) - 1)),
@@ -159,7 +159,7 @@ function XLA.to_host(buffer::Array, data)
     # avoid the complexity of supporting that for now.
     single_device_arrays = disassemble_into_single_device_arrays(buffer, true)
-    array_slices = XLA.sharding_to_concrete_array_indices(
+    array_slices, _ = XLA.sharding_to_concrete_array_indices(
        convert(XLA.HloSharding, sharding),
        size(data),
        collect(Int64, 0:(length(all_devices) - 1)),
diff --git a/src/xla/Sharding.jl b/src/xla/Sharding.jl
index 593b21228f..dc1fb4ecfd 100644
--- a/src/xla/Sharding.jl
+++ b/src/xla/Sharding.jl
@@ -251,28 +251,29 @@ function sharding_to_concrete_array_indices(
     sharding::CondensedOpSharding, shape::Dims{N}, device_ids
 ) where {N}
     if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal
-        return ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(device_ids))
+        return (
+            ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(device_ids)),
+            ntuple(Returns(0), N),
+        )
     elseif sharding.type == OpShardingType.Other
         partitions, num_replicas = get_number_of_ways_dim_sharded(sharding)
         @assert length(partitions) == length(shape)
         shape = reverse(shape)
 
+        partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards)
+            dim % n_shards == 0 && return dim
+            res = dim + n_shards ÷ 2
+            return res - res % n_shards
+        end
+        partitionable_shape = Tuple(partitionable_shape)
+        padding = partitionable_shape .- shape
+
         # Calculate indices for each dimension
-        axis_indices = map(zip(shape, partitions)) do (dim, n_shards)
+        axis_indices = map(zip(partitionable_shape, partitions)) do (dim, n_shards)
             @assert dim > 0 "Invalid dimension: $dim"
             @assert n_shards > 0 "Invalid number of shards: $n_shards"
             n_shards == 1 && return [1:dim]
-            shard_size, remainder = divrem(dim, n_shards)
-
-            if remainder != 0
-                throw(
-                    DimensionMismatch(
-                        "Dimension of Size $(dim) cannot be partitioned into $(n_shards) \
-                        shards each of size $(shard_size) (remainder = $(remainder)).",
-                    ),
-                )
-            end
-
+            shard_size = dim ÷ n_shards
             return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
         end
@@ -285,7 +286,7 @@ function sharding_to_concrete_array_indices(
             end
         end
 
-        return map(Base.Fix1(getindex, indices), device_ids)
+        return map(Base.Fix1(getindex, indices), device_ids), reverse(padding)
     else
         error("Unsupported sharding type: $(sharding.type)")
     end
@@ -295,7 +296,7 @@ function compute_array_indices_and_hlo_sharding(
     sharding::CondensedOpSharding, array_size, device_ids
 )
     return (
-        sharding_to_concrete_array_indices(sharding, array_size, device_ids),
+        first(sharding_to_concrete_array_indices(sharding, array_size, device_ids)),
         convert(HloSharding, sharding.opsharding),
     )
 end
diff --git a/test/basic.jl b/test/basic.jl
index d62e22f905..4b7f4ba99d 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -358,13 +358,11 @@ end
 
 @testset "repeat" begin
-    fn_inner(x, counts) = repeat(x; inner=counts)
-
     @testset for (size, counts) in Iterators.product(
         [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)],
         [(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)],
     )
-        x = rand(size...)
+        x = reshape(collect(Float32, 1:prod(size)), size...)
 
         @testset "outer repeat" begin
            @test (@jit repeat(Reactant.to_rarray(x), counts...)) ==
                repeat(x, counts...)
@@ -373,7 +371,8 @@ end
         length(counts) < length(size) && continue
 
         @testset "inner repeat" begin
-            @test (@jit fn_inner(Reactant.to_rarray(x), counts)) == fn_inner(x, counts)
+            @test (@jit repeat(Reactant.to_rarray(x); inner=counts)) ==
+                repeat(x; inner=counts)
         end
     end
 end
diff --git a/test/sharding.jl b/test/sharding.jl
index 7508e6e2a3..a55fefb45a 100644
--- a/test/sharding.jl
+++ b/test/sharding.jl
@@ -242,3 +242,35 @@ end
         [2, 1], #=iota_transpose_perm=#
     ) == [0, 2, 4, 6, 1, 3, 5, 7]
 end
+
+@testset "Sharding with non-divisible axes sizes" begin
+    # Test the low-level API where we automatically un-pad the arrays
+    x_test = Reactant.to_rarray(reshape(collect(Float32, 1:16), 4, 4))
+    x_padded = Reactant.ConcretePJRTArray(x_test, (1, 3))
+
+    # If not done correctly, this will throw an error
+    @test Array(@jit(fn_test2(x_padded))) ≈ Array(@jit(fn_test2(x_test)))
+
+    function fn_set_1(x)
+        x[:, 1:2] .= 1
+        return x
+    end
+
+    res1 = @jit fn_set_1(x_padded)
+    res2 = @jit fn_set_1(x_test)
+    @test Array(res1) ≈ Array(res2)
+    @test_broken all(Array(x_padded)[:, 1:2] .== 1)
+    @test all(Array(x_test)[:, 1:2] .== 1)
+
+    if length(Reactant.addressable_devices()) ≥ 8
+        mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model"))
+        x = reshape(collect(Float32, 1:14), 2, 7)
+        x_ra = Reactant.to_rarray(
+            x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
+        )
+
+        @test Array(@jit fn_test3(x_ra)) ≈ fn_test3(x)
+    else
+        @warn "Not enough addressable devices to run sharding tests"
+    end
+end
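
For reference, the rounding rule that `sharding_to_concrete_array_indices` now applies to axes that do not divide evenly across their shards can be reproduced in isolation. The snippet below is an illustrative sketch only, not part of the patch: it has no Reactant/XLA dependencies, and the names `shape`, `partitions`, `partitionable_shape`, and `padding` simply mirror the hunk in src/xla/Sharding.jl, using the 2×7 array on a 2×4 mesh from the multi-device test above.

# Illustrative sketch (assumed standalone, not part of the patch): the same rounding
# arithmetic as sharding_to_concrete_array_indices for OpShardingType.Other.
shape = (2, 7)       # axis lengths (the real code first reverses the shape)
partitions = (2, 4)  # shards per axis, i.e. a 2×4 ("data", "model") mesh

# Round each axis to the nearest multiple of its shard count; the difference is the
# zero-padding that the HloSharding callable applies before slicing the buffers.
partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards)
    dim % n_shards == 0 && return dim
    res = dim + n_shards ÷ 2
    return res - res % n_shards
end
partitionable_shape = Tuple(partitionable_shape)
padding = partitionable_shape .- shape   # (0, 1): the length-7 axis is padded to 8

# Equal-sized per-device slices of the padded array.
axis_indices = map(zip(partitionable_shape, partitions)) do (dim, n_shards)
    shard_size = dim ÷ n_shards
    [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
end
# axis_indices == [[1:1, 2:2], [1:2, 3:4, 5:6, 7:8]], so every device receives a
# 1×2 slice of the zero-padded 2×8 array.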