
Commit a87d5f9

feat: sharding with non-divisible dimensions [alternate approach] (#825)
* feat: support implicit padding from XLA
* feat: use XLA for shard-info if we need padding
* test: padding for sharding
* fix: return type
1 parent dc1df21 commit a87d5f9
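
A minimal usage sketch of what this enables, mirroring the test added in test/sharding.jl below (assumes at least 8 addressable devices and the exported Reactant Sharding / @jit APIs already used by the test suite):

using Reactant
using Reactant: Sharding

# 2x4 device mesh; the second array dimension (7) is not divisible by the size of
# the "model" axis (4), so shards are uneven and XLA supplies the implicit padding.
mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model"))
x = reshape(collect(Float32, 1:14), 2, 7)
x_ra = Reactant.to_rarray(x; sharding=Sharding.NamedSharding(mesh, ("data", "model")))

Array(@jit sum(x_ra; dims=2)) ≈ sum(x; dims=2)  # true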

File tree: 6 files changed, +122 -40 lines changed

src/Compiler.jl
src/Sharding.jl
src/TracedUtils.jl
src/xla/IFRT/Array.jl
src/xla/Sharding.jl
test/sharding.jl

src/Compiler.jl

Lines changed: 12 additions & 12 deletions

@@ -636,6 +636,7 @@ function compile_mlir!(
     backend="gpu",
     fn_kwargs=(),
     raise::Union{Bool,String}=false,
+    input_shardings=nothing,
 )
     # Explicitly don't use block! to avoid creating a closure, which creates
     # both compile-time and relocatability issues
@@ -652,7 +653,7 @@ function compile_mlir!(
     activate_raising!(is_raising)

     mlir_fn_res = try
-        Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true)
+        Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true; input_shardings)
     finally
         deactivate_raising!(is_raising)
         deactivate_sdycache!(sdycache)
@@ -1167,14 +1168,16 @@ function codegen_flatten!(

        if is_sharded
            carg = inv_seen_args[arg]
+           condensed_op_sharding = convert(
+               Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
+           )
            if Reactant.Sharding.is_sharded(carg)
-               # Currently disabling the error since we roundtrip from MHLO to generate
-               # the shardings
-               # # Check if the sharding provided is same as the one we have
-               # arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding(
-               #     Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg))
-               # )
-               # @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."
+               arg_condensed_op_sharding = convert(
+                   Reactant.Sharding.XLA.CondensedOpSharding,
+                   carg.sharding.sharding.hlo_sharding,
+               )
+               # Check if the sharding provided is same as the one we have
+               @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."

                push!(flatten_code, :($usbuf = $flatcode.data))
                for j in 1:length(mesh)
@@ -1183,11 +1186,8 @@ function codegen_flatten!(
                push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
            end
        else
-           condensed_op_sharding = convert(
-               Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
-           )
            push!(flatten_code, :($usbuf = $flatcode))
-           device_to_array_slices = XLA.sharding_to_concrete_array_indices(
+           device_to_array_slices, _ = XLA.sharding_to_concrete_array_indices(
                condensed_op_sharding, size(carg), mesh.logical_device_ids
            )
            for j in 1:length(mesh)

src/Sharding.jl

Lines changed: 53 additions & 3 deletions

@@ -200,7 +200,7 @@ function (sharding::NamedSharding)(
     client::XLA.PJRT.Client, device::Nothing, x::Union{AbstractArray,Number}
 )
     @assert length(sharding.partition_spec) == ndims(x)
-    return (convert(HloSharding, sharding))(client, device, x)
+    return HloSharding(sharding, client, device, x)
 end

 function get_shardy_tensor_sharding_attribute(
@@ -339,7 +339,9 @@ end
     return ShardInfo{HloSharding{D1,D2},Vector{NTuple{N,UnitRange{Int64}}}}
 end

-function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
+# This doesn't account for the size of the input so in-presence of padding this will be
+# incorrect. Hence always use the HloSharding constructor.
+function generate_hlo_sharding_from_tensor_attribute(sharding::NamedSharding)
     if MLIR.IR._has_context()
         ctx = MLIR.IR.context()
     else
@@ -370,14 +372,62 @@ function Base.convert(::Type{HloSharding}, sharding::NamedSharding)
     end
 end

+function HloSharding(sharding::NamedSharding, client::XLA.PJRT.Client, _, x)
+    hlo_sharding = generate_hlo_sharding_from_tensor_attribute(sharding)
+
+    # Check if the input needs to be padded. If so this sharding is not valid and we
+    # need to request the tensor sharding from XLA
+    condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding)
+    device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices(
+        condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids
+    )
+
+    if needs_padding
+        # Compile a dummy function to get the tensor sharding
+        tmp = if x isa Number
+            Reactant.ConcretePJRTNumber(zero(eltype(x)))
+        else
+            Reactant.ConcretePJRTArray(ones(eltype(x), size(x)...))
+        end
+        _, exec, _, _, _ = Reactant.Compiler.compile_xla(
+            Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => sharding)
+        )
+        xla_hlo_sharding = convert(
+            Reactant.XLA.HloSharding, only(Reactant.XLA.get_parameter_shardings(exec))
+        )
+        hlo_sharding = HloSharding(
+            xla_hlo_sharding,
+            hlo_sharding.mesh,
+            hlo_sharding.is_closed,
+            hlo_sharding.priority,
+        )
+
+        condensed_op_sharding = convert(XLA.CondensedOpSharding, hlo_sharding.hlo_sharding)
+        device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices(
+            condensed_op_sharding, size(x), hlo_sharding.mesh.logical_device_ids
+        )
+    end
+
+    data = ntuple(length(hlo_sharding.mesh)) do i
+        XLA.PJRT.AsyncBuffer(
+            client,
+            x[device_to_array_slices[i]...],
+            XLA.get_device(client, hlo_sharding.mesh.device_ids[i]),
+        )
+    end
+
+    return data, ShardInfo(hlo_sharding, device_to_array_slices)
+end
+
 function (sharding::HloSharding)(
     client::XLA.PJRT.Client, ::Nothing, x::Union{AbstractArray,Number}
 )
     condensed_op_sharding = convert(XLA.CondensedOpSharding, sharding.hlo_sharding)

-    device_to_array_slices = XLA.sharding_to_concrete_array_indices(
+    device_to_array_slices, needs_padding = XLA.sharding_to_concrete_array_indices(
        condensed_op_sharding, size(x), sharding.mesh.logical_device_ids
    )
+   @assert !needs_padding "This shouldn't happen. Open an issue on Reactant.jl"

    data = ntuple(length(sharding.mesh)) do i
        XLA.PJRT.AsyncBuffer(
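
The new HloSharding constructor above only takes the dummy-compile detour when some array dimension is not divisible by the number of shards along its mesh axis. A quick standalone check in plain Julia (shapes and mesh axes taken from the test added below; the helper name is made up for illustration):

# Does a given shape force the padding fallback on a ("data" = 2, "model" = 4) mesh?
padding_fallback_needed(shape, partitions) = any(rem.(shape, partitions) .!= 0)

padding_fallback_needed((2, 7), (2, 4))  # true:  7 % 4 != 0, so the dummy negate is compiled
padding_fallback_needed((5, 5), (2, 4))  # true:  5 % 2 != 0 and 5 % 4 != 0
padding_fallback_needed((8, 4), (2, 4))  # false: divisible, the direct slicing path is used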

src/TracedUtils.jl

Lines changed: 10 additions & 4 deletions

@@ -161,6 +161,7 @@ function make_mlir_fn(
     args_in_result::Symbol=:all,
     construct_function_without_args::Bool=false,
     do_transpose=true,
+    input_shardings=nothing, # This is not meant to be used by the user.
 )
     if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction
         mlir_fn_res = make_mlir_fn(
@@ -173,6 +174,7 @@
             return_dialect,
             do_transpose,
             args_in_result,
+            input_shardings,
         )
         mlir_fn_res.fnwrapped = true
         return mlir_fn_res
@@ -221,10 +223,14 @@
     # Insert meshes for the sharded arguments
     traced_args_to_shardings = OrderedIdDict()
     for (k, v) in seen_args
-        if (k isa Reactant.ConcretePJRTArray || k isa Reactant.ConcretePJRTNumber) &&
-            Reactant.Sharding.is_sharded(k)
-            Reactant.Ops.mesh(k.sharding.mesh)
-            traced_args_to_shardings[v] = k.sharding
+        if (k isa Reactant.ConcretePJRTArray || k isa Reactant.ConcretePJRTNumber)
+            if Reactant.Sharding.is_sharded(k)
+                Reactant.Ops.mesh(k.sharding.mesh)
+                traced_args_to_shardings[v] = k.sharding
+            elseif input_shardings !== nothing && haskey(input_shardings, k)
+                Reactant.Ops.mesh(input_shardings[k].mesh)
+                traced_args_to_shardings[v] = input_shardings[k]
+            end
         end
     end

src/xla/IFRT/Array.jl

Lines changed: 2 additions & 2 deletions

@@ -48,7 +48,7 @@ function Array(client::Client, array::Base.Array{T,N}, sharding::Sharding) where
     end

     all_devices = XLA.devices(sharding)
-    array_slices = XLA.sharding_to_concrete_array_indices(
+    array_slices, _ = XLA.sharding_to_concrete_array_indices(
         convert(XLA.HloSharding, sharding),
         size(array),
         collect(Int64, 0:(length(all_devices) - 1)),
@@ -159,7 +159,7 @@ function XLA.to_host(buffer::Array, data)
     # avoid the complexity of supporting that for now.
     single_device_arrays = disassemble_into_single_device_arrays(buffer, true)

-    array_slices = XLA.sharding_to_concrete_array_indices(
+    array_slices, _ = XLA.sharding_to_concrete_array_indices(
         convert(XLA.HloSharding, sharding),
         size(data),
         collect(Int64, 0:(length(all_devices) - 1)),

src/xla/Sharding.jl

Lines changed: 24 additions & 19 deletions

@@ -251,31 +251,36 @@ function sharding_to_concrete_array_indices(
    sharding::CondensedOpSharding, shape::Dims{N}, device_ids
 ) where {N}
    if sharding.type == OpShardingType.Replicated || sharding.type == OpShardingType.Maximal
-       return ntuple(Returns(ntuple(i -> 1:shape[i], N)), length(device_ids))
+       return map(Returns(UnitRange.(1, shape)), device_ids), false
    elseif sharding.type == OpShardingType.Other
        partitions, num_replicas = get_number_of_ways_dim_sharded(sharding)
        @assert length(partitions) == length(shape)
        shape = reverse(shape)

+       # XLA will automatically pad the inputs that don't match the final shape
+       partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards)
+           dim % n_shards == 0 && return dim
+           res = dim + n_shards ÷ 2
+           return res - res % n_shards
+       end
+       partitionable_shape = Tuple(partitionable_shape)
+
+       needs_padding = any(partitionable_shape .!= shape)
+
        # Calculate indices for each dimension
-       axis_indices = map(zip(shape, partitions)) do (dim, n_shards)
-           @assert dim > 0 "Invalid dimension: $dim"
-           @assert n_shards > 0 "Invalid number of shards: $n_shards"
-           n_shards == 1 && return [1:dim]
-           shard_size, remainder = divrem(dim, n_shards)
-
-           if remainder != 0
-               throw(
-                   DimensionMismatch(
-                       "Dimension of Size $(dim) cannot be partitioned into $(n_shards) \
-                        shards each of size $(shard_size) (remainder = $(remainder)).",
-                   ),
-               )
+       axis_indices =
+           map(zip(partitionable_shape, shape, partitions)) do (dim_padded, dim, n_shards)
+               @assert dim > 0 "Invalid dimension: $dim"
+               @assert n_shards > 0 "Invalid number of shards: $n_shards"
+               n_shards == 1 && return [1:dim]
+               shard_size = dim_padded ÷ n_shards
+
+               return [
+                   (i * shard_size + 1):min((i + 1) * shard_size, dim) for
+                   i in 0:(n_shards - 1)
+               ]
            end

-           return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
-       end
-
        indices = Dict{Int,NTuple{N,UnitRange{Int}}}()
        device_idx = 1
        for _ in 1:num_replicas
@@ -285,7 +290,7 @@
            end
        end

-       return map(Base.Fix1(getindex, indices), device_ids)
+       return map(Base.Fix1(getindex, indices), device_ids), needs_padding
    else
        error("Unsupported sharding type: $(sharding.type)")
    end
@@ -295,7 +300,7 @@ function compute_array_indices_and_hlo_sharding(
    sharding::CondensedOpSharding, array_size, device_ids
 )
    return (
-       sharding_to_concrete_array_indices(sharding, array_size, device_ids),
+       first(sharding_to_concrete_array_indices(sharding, array_size, device_ids)),
        convert(HloSharding, sharding.opsharding),
    )
 end
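
Worked through by hand, the new index computation in the first hunk gives the following for a dimension of size 7 split 4 ways (the non-divisible case from the test below); this is a standalone sketch of the same arithmetic, not part of the diff:

dim, n_shards = 7, 4

# Round to a partitionable size the same way the hunk does; per the commit message,
# XLA provides the implicit padding for the missing elements.
res = dim + n_shards ÷ 2
dim_padded = dim % n_shards == 0 ? dim : res - res % n_shards    # 8

shard_size = dim_padded ÷ n_shards                               # 2
slices = [(i * shard_size + 1):min((i + 1) * shard_size, dim) for i in 0:(n_shards - 1)]
# [1:2, 3:4, 5:6, 7:7]; the last shard holds only one element and is padded on device.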

test/sharding.jl

Lines changed: 21 additions & 0 deletions

@@ -221,6 +221,27 @@
     end
 end

+@testset "Sharding with non-divisible axes sizes" begin
+    if length(Reactant.addressable_devices()) ≥ 8
+        mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model"))
+        x = reshape(collect(Float32, 1:14), 2, 7)
+        x_ra = Reactant.to_rarray(
+            x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
+        )
+
+        @test Array(@jit sum(x_ra; dims=2)) ≈ sum(x; dims=2)
+
+        x = reshape(collect(Float32, 1:25), 5, 5)
+        x_ra = Reactant.to_rarray(
+            x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
+        )
+
+        @test Array(@jit fn_test2(x_ra)) ≈ fn_test2(x)
+    else
+        @warn "Not enough addressable devices to run sharding tests"
+    end
+end
+
 # Tests from the examples in
 # https://github.com/openxla/xla/blob/96d6678053d867099a42be9001c49b2ed7111afd/xla/hlo/ir/tile_assignment.h#L53-L68
 @testset "Device List from Iota Tile" begin
