
Commit 11a38ed

feat: support implicit padding from XLA
1 parent a39b055 commit 11a38ed

File tree

src/Compiler.jl
src/xla/Sharding.jl

2 files changed, +28 -26 lines

src/Compiler.jl

Lines changed: 9 additions & 10 deletions
@@ -1167,14 +1167,16 @@ function codegen_flatten!(
 
         if is_sharded
             carg = inv_seen_args[arg]
+            condensed_op_sharding = convert(
+                Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
+            )
             if Reactant.Sharding.is_sharded(carg)
-                # Currently disabling the error since we roundtrip from MHLO to generate
-                # the shardings
-                # # Check if the sharding provided is same as the one we have
-                # arg_condensed_op_sharding = Reactant.Sharding.XLA.CondensedOpSharding(
-                #     Reactant.Sharding.ShardingWithShape(carg.sharding, size(carg))
-                # )
-                # @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."
+                arg_condensed_op_sharding = convert(
+                    Reactant.Sharding.XLA.CondensedOpSharding,
+                    carg.sharding.sharding.hlo_sharding,
+                )
+                # Check if the sharding provided is same as the one we have
+                @assert arg_condensed_op_sharding == condensed_op_sharding "Sharding provided by the user ($arg_condensed_op_sharding) does not match the sharding computed by XLA ($condensed_op_sharding). This generally means that Reactant.jl made an error in generating the executable. Please open an issue with the error message and an MWE."
 
                 push!(flatten_code, :($usbuf = $flatcode.data))
                 for j in 1:length(mesh)
@@ -1183,9 +1185,6 @@ function codegen_flatten!(
                     push!(flatten_code, :($sbuf = XLA.synced_buffer(getindex($usbuf, $j))))
                 end
             else
-                condensed_op_sharding = convert(
-                    Reactant.Sharding.XLA.CondensedOpSharding, linear_parameter_shardings[i]
-                )
                 push!(flatten_code, :($usbuf = $flatcode))
                 device_to_array_slices = XLA.sharding_to_concrete_array_indices(
                     condensed_op_sharding, size(carg), mesh.logical_device_ids
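The conversion of linear_parameter_shardings[i] to CondensedOpSharding is hoisted above the branch so both the sharded and unsharded paths reuse condensed_op_sharding, and the previously commented-out consistency check against the user-provided sharding is re-enabled. A minimal standalone sketch of what that assertion expresses, using a hypothetical SimpleSharding stand-in rather than Reactant's actual CondensedOpSharding type, with made-up tile values:

# Hypothetical stand-in for Reactant.Sharding.XLA.CondensedOpSharding, used only
# to illustrate the equality-based consistency check re-enabled above.
struct SimpleSharding
    tile_dims::NTuple{2,Int}  # how many shards each array axis is split into
end

# Sharding the user attached to the argument (assumed values, for illustration).
user_sharding = SimpleSharding((2, 1))

# Sharding XLA reports for the corresponding executable parameter (assumed values).
xla_sharding = SimpleSharding((2, 1))

# Mirrors the @assert in codegen_flatten!: both views of the argument's sharding
# must agree, otherwise the generated executable does not match the user's intent.
@assert user_sharding == xla_sharding "user-provided and XLA-computed shardings disagree"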

src/xla/Sharding.jl

Lines changed: 19 additions & 16 deletions
@@ -257,25 +257,28 @@ function sharding_to_concrete_array_indices(
     @assert length(partitions) == length(shape)
     shape = reverse(shape)
 
+    # XLA will automatically pad the inputs that don't match the final shape
+    partitionable_shape = map(zip(shape, partitions)) do (dim, n_shards)
+        dim % n_shards == 0 && return dim
+        res = dim + n_shards ÷ 2
+        return res - res % n_shards
+    end
+    partitionable_shape = Tuple(partitionable_shape)
+
     # Calculate indices for each dimension
-    axis_indices = map(zip(shape, partitions)) do (dim, n_shards)
-        @assert dim > 0 "Invalid dimension: $dim"
-        @assert n_shards > 0 "Invalid number of shards: $n_shards"
-        n_shards == 1 && return [1:dim]
-        shard_size, remainder = divrem(dim, n_shards)
-
-        if remainder != 0
-            throw(
-                DimensionMismatch(
-                    "Dimension of Size $(dim) cannot be partitioned into $(n_shards) \
-                     shards each of size $(shard_size) (remainder = $(remainder)).",
-                ),
-            )
+    axis_indices =
+        map(zip(partitionable_shape, shape, partitions)) do (dim_padded, dim, n_shards)
+            @assert dim > 0 "Invalid dimension: $dim"
+            @assert n_shards > 0 "Invalid number of shards: $n_shards"
+            n_shards == 1 && return [1:dim]
+            shard_size = dim_padded ÷ n_shards
+
+            return [
+                (i * shard_size + 1):min((i + 1) * shard_size, dim) for
+                i in 0:(n_shards - 1)
+            ]
         end
 
-        return [(i * shard_size + 1):((i + 1) * shard_size) for i in 0:(n_shards - 1)]
-    end
-
     indices = Dict{Int,NTuple{N,UnitRange{Int}}}()
     device_idx = 1
     for _ in 1:num_replicas
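For a concrete sense of the padding behavior, take a dimension of size 5 split across 2 shards: instead of throwing a DimensionMismatch, the new code treats the dimension as padded to 6, gives each shard a nominal size of 3, and clamps the last slice back to the true extent. A minimal sketch re-deriving those slices outside of sharding_to_concrete_array_indices, with values chosen only for illustration:

# Standalone re-derivation of the padded per-shard slices for one dimension.
dim, n_shards = 5, 2

# Padded size the shards can cover evenly (same arithmetic as in the diff above).
res = dim + n_shards ÷ 2
dim_padded = dim % n_shards == 0 ? dim : res - res % n_shards   # 6

shard_size = dim_padded ÷ n_shards                               # 3

# Each shard's index range, with the last one clamped to the unpadded size.
slices = [(i * shard_size + 1):min((i + 1) * shard_size, dim) for i in 0:(n_shards - 1)]

@assert slices == [1:3, 4:5]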

Comments (0)