Skip to content

Commit 657f93d

Browse files
committed
fix: more tests and fixes for find functions
1 parent 75a2151 commit 657f93d

File tree

2 files changed

+58
-4
lines changed

2 files changed

+58
-4
lines changed

src/TracedRArray.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -836,14 +836,14 @@ Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x)
836836
Base.findlast(x::AnyTracedRArray) = findlast(identity, x)
837837

838838
function Base.findfirst(f::Function, x::AnyTracedRArray)
839-
fA = f.(x)
840-
(; indices) = Ops.top_k(materialize_traced_array(fA), 1)
839+
fA = materialize_traced_array(vec(f.(x)))
840+
(; indices) = Ops.top_k(fA, 1)
841841
return @allowscalar indices[1]
842842
end
843843

844844
function Base.findlast(f::Function, x::AnyTracedRArray)
845845
fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1])
846-
(; indices) = Ops.top_k(materialize_traced_array(fA), 1)
846+
(; indices) = Ops.top_k(fA, 1)
847847
return length(x) - @allowscalar(indices[1]) + 1
848848
end
849849

@@ -883,7 +883,9 @@ function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
883883
)
884884
end
885885

886-
return (Ops.negate(values), linear_indices)
886+
values = Ops.negate(values)
887+
ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
888+
return (values, linear_indices)
887889
end
888890

889891
function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
@@ -910,6 +912,7 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
910912
)
911913
end
912914

915+
ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
913916
return (values, linear_indices)
914917
end
915918

test/sorting.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,54 @@ using Reactant, Test
3939
@test argmin(abs2, x) == @jit(argmin(abs2, x_ra))
4040
@test argmax(abs2, x) == @jit(argmax(abs2, x_ra))
4141
end
42+
43+
@testset "findmin / findmax" begin
44+
xvec = randn(10)
45+
xvec_ra = Reactant.to_rarray(xvec)
46+
47+
x = randn(2, 3)
48+
x_ra = Reactant.to_rarray(x)
49+
50+
function fwithlinindices(g, f, x; kwargs...)
51+
values, indices = g(f, x; kwargs...)
52+
return values, LinearIndices(x)[indices]
53+
end
54+
55+
@test fwithlinindices(findmin, identity, x) == @jit(findmin(x_ra))
56+
@test fwithlinindices(findmax, identity, x) == @jit(findmax(x_ra))
57+
@test fwithlinindices(findmin, identity, xvec) == @jit(findmin(xvec_ra))
58+
@test fwithlinindices(findmax, identity, xvec) == @jit(findmax(xvec_ra))
59+
60+
fmindims(x, d) = findmin(x; dims=d)
61+
fmindims(f, x, d) = findmin(f, x; dims=d)
62+
fmaxdims(x, d) = findmax(x; dims=d)
63+
fmaxdims(f, x, d) = findmax(f, x; dims=d)
64+
65+
@test fwithlinindices(findmin, identity, x; dims=1) == @jit(fmindims(x_ra, 1))
66+
@test fwithlinindices(findmax, identity, x; dims=1) == @jit(fmaxdims(x_ra, 1))
67+
@test fwithlinindices(findmin, identity, x; dims=2) == @jit(fmindims(x_ra, 2))
68+
@test fwithlinindices(findmax, identity, x; dims=2) == @jit(fmaxdims(x_ra, 2))
69+
@test fwithlinindices(findmin, abs2, x; dims=1) == @jit(fmindims(abs2, x_ra, 1))
70+
@test fwithlinindices(findmax, abs2, x; dims=1) == @jit(fmaxdims(abs2, x_ra, 1))
71+
@test fwithlinindices(findmin, abs2, x; dims=2) == @jit(fmindims(abs2, x_ra, 2))
72+
@test fwithlinindices(findmax, abs2, x; dims=2) == @jit(fmaxdims(abs2, x_ra, 2))
73+
end
74+
75+
@testset "findfirst / findlast" begin
76+
x = rand(Bool, 3, 4)
77+
x_ra = Reactant.to_rarray(x)
78+
79+
ffirstlinindices(x) = LinearIndices(x)[findfirst(x)]
80+
ffirstlinindices(f, x) = LinearIndices(x)[findfirst(f, x)]
81+
flastlinindices(x) = LinearIndices(x)[findlast(x)]
82+
flastlinindices(f, x) = LinearIndices(x)[findlast(f, x)]
83+
84+
@test ffirstlinindices(x) == @jit(findfirst(x_ra))
85+
@test flastlinindices(x) == @jit(findlast(x_ra))
86+
87+
x = rand(1:256, 3, 4)
88+
x_ra = Reactant.to_rarray(x)
89+
90+
@test ffirstlinindices(iseven, x) == @jit(findfirst(iseven, x_ra))
91+
@test flastlinindices(iseven, x) == @jit(findlast(iseven, x_ra))
92+
end

0 commit comments

Comments
 (0)