58 changes: 55 additions & 3 deletions src/Compiler.jl
@@ -1187,7 +1187,7 @@ function codegen_flatten!(
Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
)
push!(flatten_code, :($usbuf = $flatcode))
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(carg), mesh.logical_device_ids
)
for j in 1:length(mesh)
@@ -1363,16 +1363,68 @@ function codegen_unflatten!(
res = :(args[$(path[2])])
path = path[3:end]
end
for p in path

for p in path[1:(end - 1)]
res = :(traced_getfield($res, $(Meta.quot(p))))
end

argres = :(args[$(argpath[2])])
for p in argpath[3:end]
argres = :(traced_getfield($argres, $(Meta.quot(p))))
end
argres = :($argres.data)

if length(path) > 0
final_val = gensym("final_val")
clocal = gensym("clocal")
if !has_cache_dict
has_cache_dict = true
push!(
unflatten_code,
:(
$cache_dict = $(IdDict{
Union{TracedRArray,TracedRNumber},
Union{ConcretePJRTArray,ConcretePJRTNumber},
}())
),
)
end
res = quote
$final_val = traced_getfield($res, $(Meta.quot(path[end])))
if $final_val isa TracedRArray
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcretePJRTArray{
$(Reactant.unwrapped_eltype)($final_val),
ndims($final_val),
}(
$argres, size($final_val)
)
$cache_dict[$final_val]
end
traced_setfield!($res, $(Meta.quot(path[end])), $clocal)
elseif $final_val isa TracedRNumber
$clocal = if haskey($cache_dict, $final_val)
$cache_dict[$final_val]
else
$cache_dict[$final_val] = ConcretePJRTNumber{
$(Reactant.unwrapped_eltype)($final_val)
}(
$argres
)
$cache_dict[$final_val]
end
traced_setfield!($res, $(Meta.quot(path[end])), $clocal)
else
traced_setfield!($res, :data, $argres)
end
end
else
res = :(traced_setfield!($res, :data, $argres))
end

res = :($res.data = $argres.data)
# res = :(traced_setfield!($res, :data, $argres.data))
push!(unflatten_code, res)
end
end
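For orientation, here is the caching pattern the generated code above implements, reduced to a plain-Julia sketch. `cached_wrap!` and `make_concrete` are hypothetical names, not part of the PR; the real code builds equivalent expressions at codegen time.

# Identity-keyed cache: two result paths that alias the same traced value
# must receive the *same* concrete wrapper, not two independent copies.
function cached_wrap!(cache::IdDict, traced, make_concrete::Function)
    haskey(cache, traced) && return cache[traced]
    return cache[traced] = make_concrete(traced)
end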
1 change: 1 addition & 0 deletions src/ConcreteRArray.jl
@@ -171,6 +171,7 @@ function Base.showarg(io::IO, a::ConcretePJRTArray{T,N}, toplevel) where {T,N}
toplevel || print(io, "::")
print(io, "ConcretePJRTArray{$T,$N}")
Sharding.is_sharded(a) && print(io, " with sharding $(typeof(a.sharding.sharding))")
any(!iszero, a.padding) && print(io, " with padding ", a.padding)
return nothing
end

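With the extra clause above, the summary of a padded array would read along these lines (illustrative, not captured output):

ConcretePJRTArray{Float32,2} with padding (0, 2)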
13 changes: 10 additions & 3 deletions src/Sharding.jl
@@ -107,7 +107,7 @@ Base.getproperty(::NoSharding, x::Symbol) = NoSharding()
function (::NoSharding)(client::XLA.PJRT.Client, device, x::Union{AbstractArray,Number})
device === nothing && (device = XLA.default_device(client))
buffer = XLA.PJRT.AsyncBuffer(client, x, device)
return (buffer,), ShardInfo(NoSharding(), nothing)
return (buffer,), ShardInfo(NoSharding(), nothing), ntuple(Returns(0), ndims(x))
end

"""
@@ -375,10 +375,17 @@ function (sharding::HloSharding)(
)
condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding)

device_to_array_slices = XLA.sharding_to_concrete_array_indices(
device_to_array_slices, padding = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(x), sharding.mesh.logical_device_ids
)

if any(!iszero, padding)
tmp = similar(x, size(x) .+ padding)
fill!(tmp, zero(eltype(x)))
tmp[[1:size(x, i) for i in 1:ndims(x)]...] .= x
x = tmp
end

data = ntuple(length(sharding.mesh)) do i
XLA.PJRT.AsyncBuffer(
client,
@@ -387,7 +394,7 @@
)
end

return data, ShardInfo(sharding, device_to_array_slices)
return data, ShardInfo(sharding, device_to_array_slices), padding
end
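The zero-padding step above, reduced to plain Julia: a minimal sketch assuming a length-5 vector split across 2 devices, so one element of padding.

x = collect(Float32, 1:5)
padding = (1,)  # 5 is not divisible by 2; pad to 6 so each shard holds 3
tmp = zeros(eltype(x), size(x) .+ padding)
tmp[[1:size(x, i) for i in 1:ndims(x)]...] .= x  # data in the leading corner
x = tmp  # now length 6, with a trailing zero as padding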

function get_shardy_tensor_sharding_attribute(
12 changes: 12 additions & 0 deletions src/TracedUtils.jl
@@ -193,10 +193,14 @@ function make_mlir_fn(
)
end

any_padded = false
linear_args = Reactant.TracedType[]
for (k, v) in seen_args
v isa Reactant.TracedType || continue
push!(linear_args, v)
if k isa Reactant.ConcretePJRTArray && any(!iszero, k.padding)
any_padded = true
end
end

in_tys = if toscalar
@@ -252,6 +256,14 @@
set_mlir_data!(arg, row_maj_arg)
end

if any_padded # We need to un-pad the arguments
for i in 1:N
traced_args[i] = Reactant.make_tracer(
seen_args, args[i], (), Reactant.UnpadTracedArray;
)
end
end

if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
else
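The `any_padded` flag above gates the extra un-padding pass on whether any concrete input actually carries padding. The same check written as a standalone predicate (a sketch against the surrounding function's `seen_args`):

any_padded = any(keys(seen_args)) do k
    k isa Reactant.ConcretePJRTArray && any(!iszero, k.padding)
end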
8 changes: 8 additions & 0 deletions src/Tracing.jl
@@ -6,6 +6,7 @@
TracedSetPath = 5
TracedToTypes = 6
NoStopTracedTrack = 7
UnpadTracedArray = 8
end

struct VisitedObject
@@ -893,6 +894,10 @@ function make_tracer(
"Mismatched sharding. Input has sharding $(prev.sharding), but requested sharding is $(typeof(sharding))",
)
end
if mode == UnpadTracedArray
tarray = seen[prev]::TracedRArray{T,N}
return view(tarray, [1:(size(prev, i) - prev.padding[i]) for i in 1:N]...)
end
if mode != ConcreteToTraced
throw("Cannot trace concrete")
end
Expand Down Expand Up @@ -923,6 +928,9 @@ function make_tracer(
return ConcretePJRTNumber(prev; sharding)
end
end
if mode == UnpadTracedArray
return seen[prev]::TracedRNumber{T}
end
if mode != ConcreteToTraced
throw("Cannot trace existing trace type")
end
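The `UnpadTracedArray` branch un-pads with a view that drops each dimension's trailing padded region, avoiding a copy. The same indexing in plain Julia, with an ordinary `Array` standing in for the traced array:

A = reshape(collect(Float32, 1:12), 3, 4)  # padded buffer of size (3, 4)
padding = (1, 2)                           # logical size is (2, 2)
unpadded = view(A, [1:(size(A, i) - padding[i]) for i in 1:ndims(A)]...)
@assert size(unpadded) == (2, 2)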
42 changes: 36 additions & 6 deletions src/Types.jl
@@ -103,6 +103,28 @@ mutable struct ConcretePJRTArray{T,N,D,S<:Sharding.ShardInfo} <: AbstractConcret
data::NTuple{D,XLA.PJRT.AsyncBuffer}
shape::NTuple{N,Int}
sharding::S
padding::NTuple{N,Int} # internal field; used mainly by the sharding logic
end

# These constructors exist to make the padding logic easy to test.
# For all other purposes, padding is handled directly by the sharding logic.
function ConcretePJRTArray(
x::ConcretePJRTArray{T,N,D,S}, padding::NTuple{N,Int}
) where {T,N,D,S}
return ConcretePJRTArray{T,N,D,S}(x, padding)
end

function ConcretePJRTArray{T,N,D,S}(
x::ConcretePJRTArray{T,N,D,S}, padding::NTuple{N,Int}
) where {T,N,D,S}
all(iszero, padding) && return x

data = convert(Array, x)
full_data = zeros(T, size(x) .+ padding)
full_data[[1:size(x, i) for i in 1:N]...] .= data

carray = ConcretePJRTArray(full_data; client=XLA.client(x), sharding=x.sharding)
return ConcretePJRTArray{T,N,D,S}(carray.data, size(carray), x.sharding, padding)
end
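A hypothetical round-trip through the test-only constructor above (assuming a default client is available; `Reactant.to_rarray` appears here purely for illustration):

x = Reactant.to_rarray(ones(Float32, 4, 4))
xp = ConcretePJRTArray(x, (1, 1))  # re-materialized as 5×5, zero-padded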

@leaf ConcretePJRTArray
@@ -115,6 +137,12 @@ Base.@deprecate ConcretePJRTArray(data::Number; kwargs...) ConcretePJRTNumber(
data; kwargs...
)

function ConcretePJRTArray{T,N,D,S}(
data::NTuple{D,XLA.PJRT.AsyncBuffer}, shape::NTuple{N,Int}, sharding::S
) where {T,N,D,S}
return ConcretePJRTArray{T,N,D,S}(data, shape, sharding, ntuple(Returns(0), N))
end

function ConcretePJRTArray{T,N}(
data::Tuple{XLA.PJRT.AsyncBuffer}, shape::NTuple{N,Int}
) where {T,N}
@@ -144,13 +172,15 @@ function ConcretePJRTArray(
specified, `idx` must match `device`"
end
end
sdata, sharding = sharding(client, device, data)
return ConcretePJRTArray{T,N,1,typeof(sharding)}(sdata, size(data), sharding)
sdata, sharding, padding = sharding(client, device, data)
return ConcretePJRTArray{T,N,1,typeof(sharding)}(
sdata, size(data) .+ padding, sharding, padding
)
end
@assert device === nothing && idx === nothing "If `sharding` is not `NoSharding`, `device` and `idx` cannot be specified!"
sharded_data, sharding = sharding(client, nothing, data)
sharded_data, sharding, padding = sharding(client, nothing, data)
return ConcretePJRTArray{T,N,length(sharded_data),typeof(sharding)}(
sharded_data, size(data), sharding
sharded_data, size(data) .+ padding, sharding, padding
)
end

@@ -169,8 +199,7 @@ const AnyConcretePJRTArray{T,N,D,S} = Union{
ConcretePJRTArray{T,N,D,S},WrappedConcretePJRTArray{T,N,D,S}
}

const AnyConcreteRArray = AnyConcretePJRTArray

## Helpful functions
ConcretePJRTArray(x::AnyConcretePJRTArray) = ConcretePJRTArray{eltype(x),ndims(x)}(x)
ConcretePJRTArray{T}(x::AnyConcretePJRTArray) where {T} = ConcretePJRTArray{T,ndims(x)}(x)
ConcretePJRTArray{T,N}(x::ConcretePJRTArray{T,N}) where {T,N} = x
@@ -193,3 +222,4 @@ end
## Aliases to prevent breaking changes
const ConcreteRArray = ConcretePJRTArray
const ConcreteRNumber = ConcretePJRTNumber
const AnyConcreteRArray = AnyConcretePJRTArray
4 changes: 2 additions & 2 deletions src/xla/IFRT/Array.jl
Expand Up @@ -48,7 +48,7 @@ function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where
end

all_devices = XLA.devices(sharding)
array_slices = XLA.sharding_to_concrete_array_indices(
array_slices, _ = XLA.sharding_to_concrete_array_indices(
convert(XLA.HloSharding, sharding),
size(array),
collect(Int64, 0:(length(all_devices) - 1)),
@@ -159,7 +159,7 @@ function XLA.to_host(buffer::Array, data)
# avoid the complexity of supporting that for now.
single_device_arrays = disassemble_into_single_device_arrays(buffer, true)

array_slices = XLA.sharding_to_concrete_array_indices(
array_slices, _ = XLA.sharding_to_concrete_array_indices(
convert(XLA.HloSharding, sharding),
size(data),
collect(Int64, 0:(length(all_devices) - 1)),
31 changes: 16 additions & 15 deletions src/xla/Sharding.jl
@@ -251,28 +251,29 @@ function sharding_to_concrete_array_indices(
sharding::CondensedOpSharding, shape::Dims{N}, device_ids
) where {N}
if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal
return ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(device_ids))
return (
ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(device_ids)),
ntuple(Returns(0), N),
)
elseif sharding.type == OpShardingType.Other
partitions, num_replicas = get_number_of_ways_dim_sharded(sharding)
@assert length(partitions) == length(shape)
shape = reverse(shape)

partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards)
# Round dim *up* to the next multiple of n_shards. Rounding to the
# nearest multiple (dim + n_shards ÷ 2, minus the remainder) can shrink
# the dimension (e.g. 5 with 4 shards -> 4), producing negative padding
# and dropping data.
return cld(dim, n_shards) * n_shards
end
partitionable_shape = Tuple(partitionable_shape)
padding = partitionable_shape .- shape

# Calculate indices for each dimension
axis_indices = map(zip(shape, partitions)) do (dim, n_shards)
axis_indices = map(zip(partitionable_shape, partitions)) do (dim, n_shards)
@assert dim > 0 "Invalid dimension: $dim"
@assert n_shards > 0 "Invalid number of shards: $n_shards"
n_shards == 1 && return [1:dim]
shard_size, remainder = divrem(dim, n_shards)

if remainder != 0
throw(
DimensionMismatch(
"Dimension of Size $(dim) cannot be partitioned into $(n_shards) \
shards each of size $(shard_size) (remainder = $(remainder)).",
),
)
end

shard_size = dim ÷ n_shards
return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
end

@@ -285,7 +286,7 @@
end
end

return map(Base.Fix1(getindex, indices), device_ids)
return map(Base.Fix1(getindex, indices), device_ids), reverse(padding)
else
error("Unsupported sharding type: $(sharding.type)")
end
Expand All @@ -295,7 +296,7 @@ function compute_array_indices_and_hlo_sharding(
sharding::CondensedOpSharding, array_size, device_ids
)
return (
sharding_to_concrete_array_indices(sharding, array_size, device_ids),
first(sharding_to_concrete_array_indices(sharding, array_size, device_ids)),
convert(HloSharding, sharding.opsharding),
)
end
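A worked example of the round-up performed in `partitionable_shape` above (a sketch; `round_up` is a hypothetical helper, not library API):

round_up(dim, n_shards) = cld(dim, n_shards) * n_shards

@assert round_up(5, 2) == 6  # padding 1; two shards of size 3
@assert round_up(5, 4) == 8  # padding 3; four shards of size 2
@assert round_up(6, 3) == 6  # divisible; no padding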
7 changes: 3 additions & 4 deletions test/basic.jl
@@ -358,13 +358,11 @@ end
end

@testset "repeat" begin
fn_inner(x, counts) = repeat(x; inner=counts)

@testset for (size, counts) in Iterators.product(
[(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)],
[(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)],
)
x = rand(size...)
x = reshape(collect(Float32, 1:prod(size)), size...)

@testset "outer repeat" begin
@test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
@@ -373,7 +371,8 @@ end
length(counts) < length(size) && continue

@testset "inner repeat" begin
@test (@jit fn_inner(Reactant.to_rarray(x), counts)) == fn_inner(x, counts)
@test (@jit repeat(Reactant.to_rarray(x); inner=counts)) ==
repeat(x; inner=counts)
end
end
end