Skip to content

Commit 17ea1d6

Browse files
committed
feat: keep lazy indexing
1 parent 2172ff2 commit 17ea1d6

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

src/TracedRArray.jl

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -777,11 +777,28 @@ end
777777
function Base.partialsortperm(
778778
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
779779
)
780-
return partialsortperm!(similar(x, Int), x, k; kwargs...)
780+
idxs = overloaded_partialsortperm(x, k; kwargs...)
781+
k isa Integer && return @allowscalar idxs[k]
782+
return view(idxs, k)
781783
end
782784

783785
function Base.partialsortperm!(
784786
ix::AnyTracedRVector{Int},
787+
x::AnyTracedRVector,
788+
k::Union{Integer,OrdinalRange}; kwargs...
789+
)
790+
idxs = overloaded_partialsortperm(x, k; kwargs...)
791+
792+
if k isa Integer
793+
@allowscalar setindex!(ix, idxs[k], k)
794+
return idxs
795+
else
796+
setindex!(ix, idxs[k], k)
797+
return view(ix, k)
798+
end
799+
end
800+
801+
function overloaded_partialsortperm(
785802
x::AnyTracedRVector,
786803
k::Union{Integer,OrdinalRange};
787804
by=identity,
@@ -791,24 +808,11 @@ function Base.partialsortperm!(
791808
# TODO: general `lt` support
792809
@assert lt === isless "Only `isless` is supported for now in `partialsortperm!`"
793810

794-
by_x = by.(x)
795-
# XXX: If `maxk` is beyond a threshold should we emit a sort directly?
796-
if k isa Integer
797-
!rev && (k = length(x) - k + 1)
798-
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), k)
799-
indices = Ops.convert(TracedRArray{Int64,1}, indices)
800-
idx = @allowscalar indices[k]
801-
@allowscalar setindex!(ix, idx, k)
802-
return idx
803-
else
804-
klist = collect(Int64, k)
805-
!rev && (klist = length(x) .- klist .+ 1)
806-
maxk = maximum(klist)
807-
(; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk)
808-
indices = Ops.convert(TracedRArray{Int64,1}, indices)
809-
setindex!(ix, indices[klist], klist)
810-
return indices[klist]
811-
end
811+
# XXX: If `maxk` is beyond a threshold should we emit a sort directly? Or do a neg
812+
!rev && (k = length(x) .- k .+ 1)
813+
!(k isa Integer) && (k = maximum(k))
814+
(; indices) = Ops.top_k(materialize_traced_array(by.(x)), k)
815+
return Ops.convert(TracedRArray{Int64,1}, indices)
812816
end
813817

814818
function Base.argmin(x::AnyTracedRArray; kwargs...)

0 commit comments

Comments
 (0)