Skip to content

Commit a15e5ca

Browse files
committed
feat: implement perm related functions
1 parent ce04478 commit a15e5ca

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/TracedRArray.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,4 +704,23 @@ function Base.sort!(
704704
return x
705705
end
706706

707+
Base.sortperm(x::AnyTracedRArray; kwargs...) = sortperm!(similar(x, Int), x; kwargs...)
708+
709+
function Base.sortperm!(
710+
ix::AnyTracedRArray{Int,N},
711+
x::AnyTracedRArray{<:Any,N};
712+
dims::Integer,
713+
lt=isless,
714+
by=identity,
715+
rev::Bool=false,
716+
kwargs..., # TODO: implement `order` and `alg` kwargs
717+
) where {N}
718+
comparator =
719+
rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b))
720+
idxs = Ops.constant(collect(LinearIndices(x)))
721+
_, res = Ops.sort(x, idxs; dimension=dims, comparator)
722+
set_mlir_data!(ix, get_mlir_data(res))
723+
return ix
724+
end
725+
707726
end

0 commit comments

Comments
 (0)