Skip to content

Commit a815a88

Browse files
committed
test: stack tests
1 parent 4bfe30d commit a815a88

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ end
607607

608608
# stack
609609
function overloaded_stack(dims::Union{Integer,Colon}, xs)
610-
@assert allequal(ndims, xs) "All arrays must have the same number of dimensions..."
610+
@assert allequal(ndims.(xs)) "All arrays must have the same number of dimensions..."
611611
dims = dims isa Colon ? ndims(first(xs)) + 1 : dims
612612
res = map(xs) do x
613613
new_shape = ntuple(

test/basic.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,26 @@ end
844844
@test @jit(getindex_linear_vector(x_ra, idx_ra)) getindex_linear_vector(x, idx)
845845
@test @jit(getindex_linear_vector(x_ra, idx)) getindex_linear_vector(x, idx)
846846
end
847+
848+
@testset "stack" begin
849+
x = rand(4, 4)
850+
y = rand(4, 4)
851+
x_ra = Reactant.to_rarray(x)
852+
y_ra = Reactant.to_rarray(y)
853+
854+
s1(x) = stack((x, x))
855+
s2(x) = stack((x, x); dims=2)
856+
s3(x, y) = stack((x, y); dims=2)
857+
s4(x, y) = stack((x, y, x); dims=1)
858+
859+
@test @jit(s1(x_ra)) s1(x)
860+
@test @jit(s2(x_ra)) s2(x)
861+
@test @jit(s3(x_ra, y_ra)) s3(x, y)
862+
@test @jit(s4(x_ra, y_ra)) s4(x, y)
863+
864+
# Test that we don't hit illegal instruction; `x` is intentionally not a traced array
865+
@test @jit(s1(x)) isa Any
866+
@test @jit(s2(x)) isa Any
867+
@test @jit(s3(x, y)) isa Any
868+
@test @jit(s4(x, y)) isa Any
869+
end

0 commit comments

Comments
 (0)