Skip to content

Commit 31b2111

Browse files
committed
test: padding for sharding
1 parent 6b53b75 commit 31b2111

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

test/sharding.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,27 @@ end
221221
end
222222
end
223223

224+
@testset "Sharding with non-divisible axes sizes" begin
225+
if length(Reactant.addressable_devices()) 8
226+
mesh = Sharding.Mesh(reshape(collect(Int64, 0:7), 2, 4), ("data", "model"))
227+
x = reshape(collect(Float32, 1:14), 2, 7)
228+
x_ra = Reactant.to_rarray(
229+
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
230+
)
231+
232+
@test Array(@jit sum(x_ra; dims=2)) sum(x; dims=2)
233+
234+
x = reshape(collect(Float32, 1:25), 5, 5)
235+
x_ra = Reactant.to_rarray(
236+
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
237+
)
238+
239+
@test Array(@jit fn_test2(x_ra)) fn_test2(x)
240+
else
241+
@warn "Not enough addressable devices to run sharding tests"
242+
end
243+
end
244+
224245
# Tests from the examples in
225246
# https://github.com/openxla/xla/blob/96d6678053d867099a42be9001c49b2ed7111afd/xla/hlo/ir/tile_assignment.h#L53-L68
226247
@testset "Device List from Iota Tile" begin

0 commit comments

Comments
 (0)