24 changes: 12 additions & 12 deletions src/Compiler.jl
@@ -636,6 +636,7 @@ function compile_mlir!(
backend="gpu",
fn_kwargs=(),
raise::Union{Bool,String}=false,
input_shardings=nothing,
)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
@@ -652,7 +653,7 @@
activate_raising!(is_raising)

mlir_fn_res = try
Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true)
Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true; input_shardings)
finally
deactivate_raising!(is_raising)
deactivate_sdycache!(sdycache)
@@ -1167,14 +1168,16 @@ function codegen_flatten!(

if is_sharded
carg = inv_seen_args[arg]
condensed_op_sharding = convert(
Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
)
if Reactant.Sharding.is_sharded(carg)
# Currently disabling the error since we roundtrip from MHLO to generate
# the shardings
# # Check if the sharding provided is same as the one we have
# arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding(
# Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg))
# )
# @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."
arg_condensed_op_sharding = convert(
Reactant.Sharding.XLA.CondensedOpSharding,
carg.sharding.sharding.hlo_sharding,
)
# Check if the sharding provided is the same as the one we have
@assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."

push!(flatten_code, :($usbuf = $flatcode.data))
for j in 1:length(mesh)
@@ -1183,11 +1186,8 @@
push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
end
else
condensed_op_sharding = convert(
Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
)
push!(flatten_code, :($usbuf = $flatcode))
device_to_array_slices = XLA.sharding_to_concrete_array_indices(
device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(carg), mesh.logical_device_ids
)
for j in 1:length(mesh)
56 changes: 53 additions & 3 deletions src/Sharding.jl
@@ -200,7 +200,7 @@ function (sharding::NamedSharding)(
client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number}
)
@assert length(sharding.partition_spec) == ndims(x)
return (convert(HloSharding, sharding))(client, device, x)
return HloSharding(sharding, client, device, x)
end

function get_shardy_tensor_sharding_attribute(
@@ -339,7 +339,9 @@ end
return ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}}
end

function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
# This doesn't account for the size of the input, so in the presence of padding this
# will be incorrect. Hence always use the HloSharding constructor.
function generate_hlo_sharding_from_tensor_attribute(sharding::NamedSharding)
if MLIR.IR._has_context()
ctx = MLIR.IR.context()
else
@@ -370,14 +372,62 @@ function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
end
end

function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x)
hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding)

# Check if the input needs to be padded. If so, this sharding is not valid and we
# need to request the tensor sharding from XLA.
condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding)
device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids
)

if needs_padding
# Compile a dummy function to get the tensor sharding
tmp = if x isa Number
Reactant.ConcretePJRTNumber(zero(eltype(x)))
else
Reactant.ConcretePJRTArray(ones(eltype(x), size(x)...))
end
_, exec, _, _, _ = Reactant.Compiler.compile_xla(
Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => sharding)
)
xla_hlo_sharding = convert(
Reactant.XLA.HloSharding, only(Reactant.XLA.get_parameter_shardings(exec))
)
hlo_sharding = HloSharding(
xla_hlo_sharding,
hlo_sharding.mesh,
hlo_sharding.is_closed,
hlo_sharding.priority,
)

condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding)
device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids
)
end

data = ntuple(length(hlo_sharding.mesh)) do i
XLA.PJRT.AsyncBuffer(
client,
x[device_to_array_slices[i]...],
XLA.get_device(client, hlo_sharding.mesh.device_ids[i]),
)
end

return data, ShardInfo(hlo_sharding, device_to_array_slices)
end

function (sharding::HloSharding)(
client::XLA.PJRT.Client, ::Nothing, x::Union{AbstractArray,Number}
)
condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding)

device_to_array_slices = XLA.sharding_to_concrete_array_indices(
device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices(
condensed_op_sharding, size(x), sharding.mesh.logical_device_ids
)
@assert !needs_padding "This shouldn't happen. Open an issue on Reactant.jl"

data = ntuple(length(sharding.mesh)) do i
XLA.PJRT.AsyncBuffer(
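The constructor added above only takes the slow path when the requested named sharding would require padding. The sketch below walks through that round trip on its own and also exercises the `input_shardings` keyword threaded through `compile_mlir!`/`compile_xla` in the `Compiler.jl` diff. It is a minimal sketch, not a public API: the 2×4 mesh, the 2×7 input, and `Ops.negate` are illustrative choices borrowed from the new test, a client with at least 8 addressable devices is assumed, and the function and field names are the ones that appear in the diff above.

```julia
using Reactant
using Reactant: Sharding, XLA

# A 2x4 mesh; the second axis of a 2x7 array is not divisible by the 4 "model"
# shards, so the locally generated tensor sharding would need padding.
mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model"))
named = Sharding.NamedSharding(mesh, ("data", "model"))
tmp = Reactant.ConcretePJRTArray(ones(Float32, 2, 7))

# Compile a trivial function with the requested input sharding and read back the
# parameter sharding XLA actually chose.
_, exec, _, _, _ = Reactant.Compiler.compile_xla(
    Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => named)
)
xla_hlo_sharding = convert(XLA.HloSharding, only(XLA.get_parameter_shardings(exec)))

# Recompute the per-device index ranges under XLA's sharding; the second return
# value reports whether padding is still required.
condensed = convert(XLA.CondensedOpSharding, xla_hlo_sharding)
slices, needs_padding = XLA.sharding_to_concrete_array_indices(
    condensed, (2, 7), mesh.logical_device_ids
)
```

This is the same sequence `HloSharding(sharding::NamedSharding, client, _, x)` performs internally before slicing `x` into per-device `AsyncBuffer`s.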
14 changes: 10 additions & 4 deletions src/TracedUtils.jl
@@ -161,6 +161,7 @@ function make_mlir_fn(
args_in_result::Symbol=:all,
construct_function_without_args::Bool=false,
do_transpose=true,
input_shardings=nothing, # This is not meant to be used by the user.
)
if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction
mlir_fn_res = make_mlir_fn(
@@ -173,6 +174,7 @@
return_dialect,
do_transpose,
args_in_result,
input_shardings,
)
mlir_fn_res.fnwrapped = true
return mlir_fn_res
@@ -221,10 +223,14 @@
# Insert meshes for the sharded arguments
traced_args_to_shardings = OrderedIdDict()
for (k, v) in seen_args
if (k isa Reactant.ConcretePJRTArray || k isa Reactant.ConcretePJRTNumber) &&
Reactant.Sharding.is_sharded(k)
Reactant.Ops.mesh(k.sharding.mesh)
traced_args_to_shardings[v] = k.sharding
if (k isa Reactant.ConcretePJRTArray || k isa Reactant.ConcretePJRTNumber)
if Reactant.Sharding.is_sharded(k)
Reactant.Ops.mesh(k.sharding.mesh)
traced_args_to_shardings[v] = k.sharding
elseif input_shardings !== nothing && haskey(input_shardings, k)
Reactant.Ops.mesh(input_shardings[k].mesh)
traced_args_to_shardings[v] = input_shardings[k]
end
end
end

4 changes: 2 additions & 2 deletions src/xla/IFRT/Array.jl
@@ -48,7 +48,7 @@ function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where
end

all_devices = XLA.devices(sharding)
array_slices = XLA.sharding_to_concrete_array_indices(
array_slices, _ = XLA.sharding_to_concrete_array_indices(
convert(XLA.HloSharding, sharding),
size(array),
collect(Int64, 0:(length(all_devices) - 1)),
@@ -159,7 +159,7 @@ function XLA.to_host(buffer::Array, data)
# avoid the complexity of supporting that for now.
single_device_arrays = disassemble_into_single_device_arrays(buffer, true)

array_slices = XLA.sharding_to_concrete_array_indices(
array_slices, _ = XLA.sharding_to_concrete_array_indices(
convert(XLA.HloSharding, sharding),
size(data),
collect(Int64, 0:(length(all_devices) - 1)),
43 changes: 24 additions & 19 deletions src/xla/Sharding.jl
@@ -251,31 +251,36 @@ function sharding_to_concrete_array_indices(
sharding::CondensedOpSharding, shape::Dims{N}, device_ids
) where {N}
if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal
return ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(device_ids))
return map(Returns(UnitRange.(1, shape)), device_ids), false
elseif sharding.type == OpShardingType.Other
partitions, num_replicas = get_number_of_ways_dim_sharded(sharding)
@assert length(partitions) == length(shape)
shape = reverse(shape)

# XLA will automatically pad the inputs that don't match the final shape
partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards)
dim % n_shards == 0 && return dim
res = dim + n_shards ÷ 2
return res - res % n_shards
end
partitionable_shape = Tuple(partitionable_shape)

needs_padding = any(partitionable_shape .!= shape)

# Calculate indices for each dimension
axis_indices = map(zip(shape, partitions)) do (dim, n_shards)
@assert dim > 0 "Invalid dimension: $dim"
@assert n_shards > 0 "Invalid number of shards: $n_shards"
n_shards == 1 && return [1:dim]
shard_size, remainder = divrem(dim, n_shards)

if remainder != 0
throw(
DimensionMismatch(
"Dimension of Size $(dim) cannot be partitioned into $(n_shards) \
shards each of size $(shard_size) (remainder = $(remainder)).",
),
)
axis_indices =
map(zip(partitionable_shape, shape, partitions)) do (dim_padded, dim, n_shards)
@assert dim > 0 "Invalid dimension: $dim"
@assert n_shards > 0 "Invalid number of shards: $n_shards"
n_shards == 1 && return [1:dim]
shard_size = dim_padded ÷ n_shards

return [
(i * shard_size + 1):min((i + 1) * shard_size, dim) for
i in 0:(n_shards - 1)
]
end

return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
end

indices = Dict{Int,NTuple{N,UnitRange{Int}}}()
device_idx = 1
for _ in 1:num_replicas
@@ -285,7 +290,7 @@ function sharding_to_concrete_array_indices(
end
end

return map(Base.Fix1(getindex, indices), device_ids)
return map(Base.Fix1(getindex, indices), device_ids), needs_padding
else
error("Unsupported sharding type: $(sharding.type)")
end
@@ -295,7 +300,7 @@ function compute_array_indices_and_hlo_sharding(
sharding::CondensedOpSharding, array_size, device_ids
)
return (
sharding_to_concrete_array_indices(sharding, array_size, device_ids),
first(sharding_to_concrete_array_indices(sharding, array_size, device_ids)),
convert(HloSharding, sharding.opsharding),
)
end
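The behavioral change in `sharding_to_concrete_array_indices` is that a dimension no longer has to divide evenly by its shard count: the dimension is rounded to a nearby multiple of the shard count, the last slice is clamped to the array's true extent, and the new second return value tells the caller that XLA will pad the short shard on device. A self-contained sketch of just that arithmetic; `padded_axis_slices` is a hypothetical helper name, not part of Reactant:

```julia
# Standalone illustration of the per-axis slice computation introduced above.
function padded_axis_slices(dim::Int, n_shards::Int)
    n_shards == 1 && return [1:dim]
    # Round `dim` to a nearby multiple of `n_shards`, mirroring
    # `res = dim + n_shards ÷ 2; res - res % n_shards` in the diff.
    res = dim + n_shards ÷ 2
    dim_padded = res - res % n_shards
    shard_size = dim_padded ÷ n_shards
    return [(i * shard_size + 1):min((i + 1) * shard_size, dim) for i in 0:(n_shards - 1)]
end

padded_axis_slices(7, 4)  # [1:2, 3:4, 5:6, 7:7] -- last shard is short; XLA pads it
padded_axis_slices(8, 4)  # [1:2, 3:4, 5:6, 7:8] -- evenly divisible, no padding needed
padded_axis_slices(2, 2)  # [1:1, 2:2]           -- the "data" axis of the new 2x7 test
```

For the 2×7 test array on the 2×4 ("data", "model") mesh, this gives row blocks 1:1 and 2:2 and column blocks 1:2, 3:4, 5:6, 7:7, so the devices assigned the final column block receive a single host column that XLA pads on device.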
21 changes: 21 additions & 0 deletions test/sharding.jl
@@ -221,6 +221,27 @@ end
end
end

@testset "Sharding with non-divisible axes sizes" begin
if length(Reactant.addressable_devices()) ≥ 8
mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model"))
x = reshape(collect(Float32, 1:14), 2, 7)
x_ra = Reactant.to_rarray(
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
)

@test Array(@jit sum(x_ra; dims=2)) ≈ sum(x; dims=2)

x = reshape(collect(Float32, 1:25), 5, 5)
x_ra = Reactant.to_rarray(
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
)

@test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x)
else
@warn "Not enough addressable devices to run sharding tests"
end
end

# Tests from the examples in
# https://github.com/openxla/xla/blob/96d6678053d867099a42be9001c49b2ed7111afd/xla/hlo/ir/tile_assignment.h#L53-L68
@testset "Device List from Iota Tile" begin