@@ -39,3 +39,54 @@ using Reactant, Test
3939 @test argmin (abs2, x) == @jit (argmin (abs2, x_ra))
4040 @test argmax (abs2, x) == @jit (argmax (abs2, x_ra))
4141end
42+
43+ @testset " findmin / findmax" begin
44+ xvec = randn (10 )
45+ xvec_ra = Reactant. to_rarray (xvec)
46+
47+ x = randn (2 , 3 )
48+ x_ra = Reactant. to_rarray (x)
49+
50+ function fwithlinindices (g, f, x; kwargs... )
51+ values, indices = g (f, x; kwargs... )
52+ return values, LinearIndices (x)[indices]
53+ end
54+
55+ @test fwithlinindices (findmin, identity, x) == @jit (findmin (x_ra))
56+ @test fwithlinindices (findmax, identity, x) == @jit (findmax (x_ra))
57+ @test fwithlinindices (findmin, identity, xvec) == @jit (findmin (xvec_ra))
58+ @test fwithlinindices (findmax, identity, xvec) == @jit (findmax (xvec_ra))
59+
60+ fmindims (x, d) = findmin (x; dims= d)
61+ fmindims (f, x, d) = findmin (f, x; dims= d)
62+ fmaxdims (x, d) = findmax (x; dims= d)
63+ fmaxdims (f, x, d) = findmax (f, x; dims= d)
64+
65+ @test fwithlinindices (findmin, identity, x; dims= 1 ) == @jit (fmindims (x_ra, 1 ))
66+ @test fwithlinindices (findmax, identity, x; dims= 1 ) == @jit (fmaxdims (x_ra, 1 ))
67+ @test fwithlinindices (findmin, identity, x; dims= 2 ) == @jit (fmindims (x_ra, 2 ))
68+ @test fwithlinindices (findmax, identity, x; dims= 2 ) == @jit (fmaxdims (x_ra, 2 ))
69+ @test fwithlinindices (findmin, abs2, x; dims= 1 ) == @jit (fmindims (abs2, x_ra, 1 ))
70+ @test fwithlinindices (findmax, abs2, x; dims= 1 ) == @jit (fmaxdims (abs2, x_ra, 1 ))
71+ @test fwithlinindices (findmin, abs2, x; dims= 2 ) == @jit (fmindims (abs2, x_ra, 2 ))
72+ @test fwithlinindices (findmax, abs2, x; dims= 2 ) == @jit (fmaxdims (abs2, x_ra, 2 ))
73+ end
74+
75+ @testset " findfirst / findlast" begin
76+ x = rand (Bool, 3 , 4 )
77+ x_ra = Reactant. to_rarray (x)
78+
79+ ffirstlinindices (x) = LinearIndices (x)[findfirst (x)]
80+ ffirstlinindices (f, x) = LinearIndices (x)[findfirst (f, x)]
81+ flastlinindices (x) = LinearIndices (x)[findlast (x)]
82+ flastlinindices (f, x) = LinearIndices (x)[findlast (f, x)]
83+
84+ @test ffirstlinindices (x) == @jit (findfirst (x_ra))
85+ @test flastlinindices (x) == @jit (findlast (x_ra))
86+
87+ x = rand (1 : 256 , 3 , 4 )
88+ x_ra = Reactant. to_rarray (x)
89+
90+ @test ffirstlinindices (iseven, x) == @jit (findfirst (iseven, x_ra))
91+ @test flastlinindices (iseven, x) == @jit (findlast (iseven, x_ra))
92+ end
0 commit comments