Skip to content

Commit 59b9ad9

Browse files
committed
refactor: move tests around a bit
1 parent 1a576e2 commit 59b9ad9

File tree

4 files changed

+230
-101
lines changed

4 files changed

+230
-101
lines changed

test/basic.jl

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -419,37 +419,6 @@ end
419419
end
420420
end
421421

422-
function update_on_copy(x)
423-
y = x[1:2, 2:4, :]
424-
y[1:1, 1:1, :] = ones(1, 1, 3)
425-
return y
426-
end
427-
428-
@testset "view / setindex" begin
429-
x = rand(2, 4, 3)
430-
y = copy(x)
431-
x_concrete = Reactant.to_rarray(x)
432-
y_concrete = Reactant.to_rarray(y)
433-
434-
y1 = update_on_copy(x)
435-
y2 = @jit update_on_copy(x_concrete)
436-
@test x == y
437-
@test x_concrete == y_concrete
438-
@test y1 == y2
439-
440-
# function update_inplace(x)
441-
# y = view(x, 1:2, 1:2, :)
442-
# y[1, 1, :] .= 1
443-
# return y
444-
# end
445-
446-
# get_indices(x) = x[1:2, 1:2, :]
447-
# get_view(x) = view(x, 1:2, 1:2, :)
448-
449-
# get_indices_compiled = @compile get_indices(x_concrete)
450-
# get_view_compiled = @compile get_view(x_concrete)
451-
end
452-
453422
function write_with_broadcast1!(x, y)
454423
x[1, :, :] .= reshape(y, 4, 3)
455424
return x
@@ -483,63 +452,6 @@ end
483452
@test res[:, 1, :] view(y, :, 1:3)
484453
end
485454

486-
function masking(x)
487-
y = similar(x)
488-
y[1:2, :] .= 0
489-
y[3:4, :] .= 1
490-
return y
491-
end
492-
493-
function masking!(x)
494-
x[1:2, :] .= 0
495-
x[3:4, :] .= 1
496-
return x
497-
end
498-
499-
@testset "setindex! with views" begin
500-
x = rand(4, 4) .+ 2.0
501-
x_ra = Reactant.to_rarray(x)
502-
503-
y = masking(x)
504-
y_ra = @jit(masking(x_ra))
505-
@test y y_ra
506-
507-
x_ra_array = Array(x_ra)
508-
@test !(any(iszero, x_ra_array[1, :]))
509-
@test !(any(iszero, x_ra_array[2, :]))
510-
@test !(any(isone, x_ra_array[3, :]))
511-
@test !(any(isone, x_ra_array[4, :]))
512-
513-
y_ra = @jit(masking!(x_ra))
514-
@test y y_ra
515-
516-
x_ra_array = Array(x_ra)
517-
@test @allowscalar all(iszero, x_ra_array[1, :])
518-
@test @allowscalar all(iszero, x_ra_array[2, :])
519-
@test @allowscalar all(isone, x_ra_array[3, :])
520-
@test @allowscalar all(isone, x_ra_array[4, :])
521-
end
522-
523-
function non_contiguous_setindex!(x)
524-
x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0
525-
return x
526-
end
527-
528-
@testset "non-contiguous setindex!" begin
529-
x = rand(6, 6)
530-
x_ra = Reactant.to_rarray(x)
531-
532-
y = @jit(non_contiguous_setindex!(x_ra))
533-
y = Array(y)
534-
x_ra = Array(x_ra)
535-
@test all(isone, y[1:3, 1:4])
536-
@test all(isone, x_ra[1:3, 1:4])
537-
@test !all(isone, y[4:end, :])
538-
@test !all(isone, x_ra[4:end, :])
539-
@test !all(isone, y[:, 5:end])
540-
@test !all(isone, x_ra[:, 5:end])
541-
end
542-
543455
tuple_byref(x) = (; a=(; b=x))
544456
tuple_byref2(x) = abs2.(x), tuple_byref2(x)
545457

@@ -681,19 +593,6 @@ end
681593
end
682594
end
683595

684-
@testset "dynamic indexing" begin
685-
x = randn(5, 3)
686-
x_ra = Reactant.to_rarray(x)
687-
688-
idx = [1, 2, 3]
689-
idx_ra = Reactant.to_rarray(idx)
690-
691-
fn(x, idx) = @allowscalar x[idx, :]
692-
693-
y = @jit(fn(x_ra, idx_ra))
694-
@test y x[idx, :]
695-
end
696-
697596
@testset "aos_to_soa" begin
698597
using ArrayInterface
699598

test/indexing.jl

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
using LinearAlgebra, Reactant, Test
2+
3+
function update_on_copy(x)
4+
y = x[1:2, 2:4, :]
5+
y[1:1, 1:1, :] = ones(1, 1, 3)
6+
return y
7+
end
8+
9+
@testset "view / setindex" begin
10+
x = rand(2, 4, 3)
11+
y = copy(x)
12+
x_concrete = Reactant.to_rarray(x)
13+
y_concrete = Reactant.to_rarray(y)
14+
15+
y1 = update_on_copy(x)
16+
y2 = @jit update_on_copy(x_concrete)
17+
@test x == y
18+
@test x_concrete == y_concrete
19+
@test y1 == y2
20+
21+
# function update_inplace(x)
22+
# y = view(x, 1:2, 1:2, :)
23+
# y[1, 1, :] .= 1
24+
# return y
25+
# end
26+
27+
# get_indices(x) = x[1:2, 1:2, :]
28+
# get_view(x) = view(x, 1:2, 1:2, :)
29+
30+
# get_indices_compiled = @compile get_indices(x_concrete)
31+
# get_view_compiled = @compile get_view(x_concrete)
32+
end
33+
34+
function masking(x)
35+
y = similar(x)
36+
y[1:2, :] .= 0
37+
y[3:4, :] .= 1
38+
return y
39+
end
40+
41+
function masking!(x)
42+
x[1:2, :] .= 0
43+
x[3:4, :] .= 1
44+
return x
45+
end
46+
47+
@testset "setindex! with views" begin
48+
x = rand(4, 4) .+ 2.0
49+
x_ra = Reactant.to_rarray(x)
50+
51+
y = masking(x)
52+
y_ra = @jit(masking(x_ra))
53+
@test y y_ra
54+
55+
x_ra_array = Array(x_ra)
56+
@test !(any(iszero, x_ra_array[1, :]))
57+
@test !(any(iszero, x_ra_array[2, :]))
58+
@test !(any(isone, x_ra_array[3, :]))
59+
@test !(any(isone, x_ra_array[4, :]))
60+
61+
y_ra = @jit(masking!(x_ra))
62+
@test y y_ra
63+
64+
x_ra_array = Array(x_ra)
65+
@test @allowscalar all(iszero, x_ra_array[1, :])
66+
@test @allowscalar all(iszero, x_ra_array[2, :])
67+
@test @allowscalar all(isone, x_ra_array[3, :])
68+
@test @allowscalar all(isone, x_ra_array[4, :])
69+
end
70+
71+
function non_contiguous_setindex!(x)
72+
x[[1, 3, 2], [1, 2, 3, 4]] .= 1.0
73+
return x
74+
end
75+
76+
@testset "non-contiguous setindex!" begin
77+
x = rand(6, 6)
78+
x_ra = Reactant.to_rarray(x)
79+
80+
y = @jit(non_contiguous_setindex!(x_ra))
81+
y = Array(y)
82+
x_ra = Array(x_ra)
83+
@test all(isone, y[1:3, 1:4])
84+
@test all(isone, x_ra[1:3, 1:4])
85+
@test !all(isone, y[4:end, :])
86+
@test !all(isone, x_ra[4:end, :])
87+
@test !all(isone, y[:, 5:end])
88+
@test !all(isone, x_ra[:, 5:end])
89+
end
90+
91+
@testset "dynamic indexing" begin
92+
x = randn(5, 3)
93+
x_ra = Reactant.to_rarray(x)
94+
95+
idx = [1, 2, 3]
96+
idx_ra = Reactant.to_rarray(idx)
97+
98+
fn(x, idx) = @allowscalar x[idx, :]
99+
100+
y = @jit(fn(x_ra, idx_ra))
101+
@test y x[idx, :]
102+
end
103+
104+
@testset "non-contiguous indexing" begin
105+
x = rand(4, 4, 3)
106+
x_ra = Reactant.to_rarray(x)
107+
108+
non_contiguous_indexing1(x) = x[[1, 3, 2], :, :]
109+
non_contiguous_indexing2(x) = x[:, [1, 2, 1, 3], [1, 3]]
110+
111+
@test @jit(non_contiguous_indexing1(x_ra)) non_contiguous_indexing1(x)
112+
@test @jit(non_contiguous_indexing2(x_ra)) non_contiguous_indexing2(x)
113+
114+
x = rand(4, 2)
115+
x_ra = Reactant.to_rarray(x)
116+
117+
non_contiguous_indexing3(x) = x[[1, 3, 2], :]
118+
non_contiguous_indexing4(x) = x[:, [1, 2, 2]]
119+
120+
@test @jit(non_contiguous_indexing3(x_ra)) non_contiguous_indexing3(x)
121+
@test @jit(non_contiguous_indexing4(x_ra)) non_contiguous_indexing4(x)
122+
123+
x = rand(4, 4, 3)
124+
x_ra = Reactant.to_rarray(x)
125+
126+
non_contiguous_indexing1!(x) = x[[1, 3, 2], :, :] .= 2
127+
non_contiguous_indexing2!(x) = x[:, [1, 2, 1, 3], [1, 3]] .= 2
128+
129+
@jit(non_contiguous_indexing1!(x_ra))
130+
non_contiguous_indexing1!(x)
131+
@test x_ra x
132+
133+
x = rand(4, 4, 3)
134+
x_ra = Reactant.to_rarray(x)
135+
136+
@jit(non_contiguous_indexing2!(x_ra))
137+
non_contiguous_indexing2!(x)
138+
@test x_ra x
139+
140+
x = rand(4, 2)
141+
x_ra = Reactant.to_rarray(x)
142+
143+
non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2
144+
non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2
145+
146+
@jit(non_contiguous_indexing3!(x_ra))
147+
non_contiguous_indexing3!(x)
148+
@test x_ra x
149+
150+
x = rand(4, 2)
151+
x_ra = Reactant.to_rarray(x)
152+
153+
@jit(non_contiguous_indexing4!(x_ra))
154+
non_contiguous_indexing4!(x)
155+
@test x_ra x
156+
end
157+
158+
@testset "indexing with traced arrays" begin
159+
x = rand(4, 4, 3)
160+
idx1 = [1, 3, 2]
161+
idx3 = [1, 2, 1, 3]
162+
163+
x_ra = Reactant.to_rarray(x)
164+
idx1_ra = Reactant.to_rarray(idx1)
165+
idx3_ra = Reactant.to_rarray(idx3)
166+
167+
getindex1(x, idx1) = x[idx1, :, :]
168+
getindex2(x, idx1) = x[:, idx1, :]
169+
getindex3(x, idx3) = x[:, :, idx3]
170+
getindex4(x, idx1, idx3) = x[idx1, :, idx3]
171+
172+
@test @jit(getindex1(x_ra, idx1_ra)) getindex1(x, idx1)
173+
@test @jit(getindex2(x_ra, idx1_ra)) getindex2(x, idx1)
174+
@test @jit(getindex3(x_ra, idx3_ra)) getindex3(x, idx3)
175+
@test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) getindex4(x, idx1, idx3)
176+
end
177+
178+
@testset "linear indexing" begin
179+
x = rand(4, 4, 3)
180+
x_ra = Reactant.to_rarray(x)
181+
182+
getindex_linear_scalar(x, idx) = @allowscalar x[idx]
183+
184+
@testset for i in 1:length(x)
185+
@test @jit(getindex_linear_scalar(x_ra, i)) getindex_linear_scalar(x, i)
186+
@test @jit(
187+
getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=Number))
188+
) getindex_linear_scalar(x, i)
189+
end
190+
191+
idx = rand(1:length(x), 8)
192+
idx_ra = Reactant.to_rarray(idx)
193+
194+
getindex_linear_vector(x, idx) = x[idx]
195+
196+
@test @jit(getindex_linear_vector(x_ra, idx_ra)) getindex_linear_vector(x, idx)
197+
@test @jit(getindex_linear_vector(x_ra, idx)) getindex_linear_vector(x, idx)
198+
end
199+
200+
@testset "Boolean Indexing" begin
201+
x_ra = Reactant.to_rarray(rand(Float32, 4, 16))
202+
idxs_ra = Reactant.to_rarray(rand(Bool, 16))
203+
204+
fn(x, idxs) = x[:, idxs]
205+
206+
@test_throws ErrorException @jit(fn(x_ra, idxs_ra))
207+
208+
res = @jit fn(x_ra, Array(idxs_ra))
209+
@test res fn(Array(x_ra), Array(idxs_ra))
210+
end
211+
212+
@testset "inconsistent indexing" begin
213+
x_ra = Reactant.to_rarray(rand(3, 4, 3))
214+
idx_ra = Reactant.to_rarray(1; track_numbers=Number)
215+
216+
fn1(x) = x[:, :, 1]
217+
fn2(x, idx) = x[:, :, idx]
218+
fn3(x, idx) = x[idx, :, 1]
219+
220+
@test ndims(@jit(fn1(x_ra))) == 2
221+
@test ndims(@jit(fn2(x_ra, idx_ra))) == 2
222+
@test ndims(@jit(fn3(x_ra, idx_ra))) == 1
223+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5757
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
5858
@safetestset "Control Flow" include("control_flow.jl")
5959
@safetestset "Sorting" include("sorting.jl")
60+
@safetestset "Indexing" include("indexing.jl")
6061
end
6162

6263
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"

test/wrapped_arrays.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,9 @@ end
228228
@jit(broadcast_reshaped_array(x_ra, idx1_ra, idx3))
229229
@jit(broadcast_reshaped_array(x_ra, Array(idx1_ra), Int64(idx3)))
230230
end
231+
232+
@testset "reshaped subarray indexing" begin
233+
fn(x) = view(x, 1:2) .+ 1
234+
x_ra = Reactant.to_rarray(rand(3, 4, 3))
235+
@test @jit(fn(x_ra)) == fn(Array(x_ra))
236+
end

0 commit comments

Comments
 (0)