@@ -777,11 +777,28 @@ end
777777function Base. partialsortperm (
778778 x:: AnyTracedRVector , k:: Union{Integer,OrdinalRange} ; kwargs...
779779)
780- return partialsortperm! (similar (x, Int), x, k; kwargs... )
780+ idxs = overloaded_partialsortperm (x, k; kwargs... )
781+ k isa Integer && return @allowscalar idxs[k]
782+ return view (idxs, k)
781783end
782784
783785function Base. partialsortperm! (
784786 ix:: AnyTracedRVector{Int} ,
787+ x:: AnyTracedRVector ,
788+ k:: Union{Integer,OrdinalRange} ; kwargs...
789+ )
790+ idxs = overloaded_partialsortperm (x, k; kwargs... )
791+
792+ if k isa Integer
793+ @allowscalar setindex! (ix, idxs[k], k)
794+ return idxs
795+ else
796+ setindex! (ix, idxs[k], k)
797+ return view (ix, k)
798+ end
799+ end
800+
801+ function overloaded_partialsortperm (
785802 x:: AnyTracedRVector ,
786803 k:: Union{Integer,OrdinalRange} ;
787804 by= identity,
@@ -791,24 +808,11 @@ function Base.partialsortperm!(
791808 # TODO : general `lt` support
792809 @assert lt === isless " Only `isless` is supported for now in `partialsortperm!`"
793810
794- by_x = by .(x)
795- # XXX : If `maxk` is beyond a threshold should we emit a sort directly?
796- if k isa Integer
797- ! rev && (k = length (x) - k + 1 )
798- (; values, indices) = Ops. top_k (materialize_traced_array (by_x), k)
799- indices = Ops. convert (TracedRArray{Int64,1 }, indices)
800- idx = @allowscalar indices[k]
801- @allowscalar setindex! (ix, idx, k)
802- return idx
803- else
804- klist = collect (Int64, k)
805- ! rev && (klist = length (x) .- klist .+ 1 )
806- maxk = maximum (klist)
807- (; values, indices) = Ops. top_k (materialize_traced_array (by_x), maxk)
808- indices = Ops. convert (TracedRArray{Int64,1 }, indices)
809- setindex! (ix, indices[klist], klist)
810- return indices[klist]
811- end
811+ # XXX : If `maxk` is beyond a threshold should we emit a sort directly? Or do a neg
812+ ! rev && (k = length (x) .- k .+ 1 )
813+ ! (k isa Integer) && (k = maximum (k))
814+ (; indices) = Ops. top_k (materialize_traced_array (by .(x)), k)
815+ return Ops. convert (TracedRArray{Int64,1 }, indices)
812816end
813817
814818function Base. argmin (x:: AnyTracedRArray ; kwargs... )
0 commit comments