From 91fdd0590e60236c616edcfcbd040064637ca9bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Fri, 13 Dec 2024 17:26:58 +0100 Subject: [PATCH 1/5] `stablehlo.sort` Ops --- src/Ops.jl | 83 +++++++++++++++++++++++++++++++++++++++-------------- test/ops.jl | 12 ++++++++ 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 013e0dbc8e..135409041c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -914,28 +914,67 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -# sorting ops -# TODO need to trace over `comparator` -# function sort( -# x::TracedRArray{T,N}; -# comparator, -# dimension=-1, -# is_stable=false, -# location=mlir_stacktrace("sort", @__FILE__, @__LINE__), -# ) where {T,N} -# dimension = MLIR.IR.Attribute(dimension) -# is_stable = MLIR.IR.Attribute(is_stable) -# res = MLIR.IR.result( -# stablehlo.sort( -# x.mlir_data; -# result=mlir_type(TracedRArray{T,N}, size(x)), -# dimension, -# is_stable, -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end +function sort( + x::TracedRArray{T,N}; + comparator::Function, + dimension=1, + is_stable=false, + location=mlir_stacktrace("sort", @__FILE__, @__LINE__), +) where {T,N} + #C4: + @assert 0 < dimension <= ndims(x) "$x invalid dimension" + + #C5: + method = Base.methods( + comparator, (Reactant.TracedRArray{T,N}, Reactant.TracedRArray{T,N}) + ) + @assert size(method, 1) != 0 error("$comparator is not a valid comparator") + @assert size(method, 1) == 1 error("$comparator ambiguous candidates") + #TODO: move to @trace + (a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0))) + func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true)[2] + + fn_name = String( + MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) + ) + @assert fn_name == "main" "$comparator: no function generated" + @assert MLIR.IR.nregions(func) == 1 + ftype_attr = MLIR.IR.attr(func, "function_type") + ftype = MLIR.IR.Type(ftype_attr) + @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error( + "$comparator return type is not tensor" + ) + + #TODO: move takebody to utils? + comparator = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) + MLIR.IR.rmfromparent!(func) + global leaked = comparator + for block in MLIR.IR.BlockIterator(comparator) + return_op = MLIR.IR.terminator(block) + MLIR.IR.name(return_op) == "func.return" || continue + operands = [MLIR.IR.operand(return_op, i) for i in 1:MLIR.IR.noperands(return_op)] + MLIR.IR.block!(block) do + MLIR.Dialects.stablehlo.return_(operands; location=MLIR.IR.location(return_op)) + MLIR.IR.rmfromparent!(return_op) + end + end + + dimension = MLIR.IR.Attribute(dimension - 1) + is_stable = MLIR.IR.Attribute(is_stable) + + res = MLIR.IR.result( + stablehlo.sort( + [x.mlir_data]; + result_0=[mlir_type(TracedRArray{T,N}, size(x))], + dimension, + is_stable, + comparator, + location, + ), + ) + return TracedRArray{T,N}((), res, size(x)) +end function top_k( x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) diff --git a/test/ops.jl b/test/ops.jl index 07f911e88b..81e6d8feb8 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -646,6 +646,18 @@ end end end +@testset "sort" begin + basic_sort(x, dimension) = Reactant.Ops.sort(x; comparator=(a, b) -> a < b, dimension) + for i in 1:3 + t_size = tuple(fill(10, (i,))...) + x = Reactant.to_rarray(randn(t_size)) + + for j in 1:i + @test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(x, j) + end + end +end + @testset "slice" begin x = ConcreteRArray([1, 2, 3, 4]) @test [2, 3] == @jit Ops.slice(x, [2], [3]) From e1808968085ef132483d440c5838d6f9fba8eee5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Fri, 13 Dec 2024 19:25:53 +0100 Subject: [PATCH 2/5] review --- src/Ops.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 135409041c..ba883eb8f0 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -916,7 +916,7 @@ end function sort( x::TracedRArray{T,N}; - comparator::Function, + comparator, dimension=1, is_stable=false, location=mlir_stacktrace("sort", @__FILE__, @__LINE__), @@ -924,12 +924,6 @@ function sort( #C4: @assert 0 < dimension <= ndims(x) "$x invalid dimension" - #C5: - method = Base.methods( - comparator, (Reactant.TracedRArray{T,N}, Reactant.TracedRArray{T,N}) - ) - @assert size(method, 1) != 0 error("$comparator is not a valid comparator") - @assert size(method, 1) == 1 error("$comparator ambiguous candidates") #TODO: move to @trace (a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0))) func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true)[2] @@ -937,6 +931,7 @@ function sort( fn_name = String( MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) ) + #C5: @assert fn_name == "main" "$comparator: no function generated" @assert MLIR.IR.nregions(func) == 1 ftype_attr = MLIR.IR.attr(func, "function_type") @@ -949,7 +944,6 @@ function sort( comparator = MLIR.IR.Region() MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) MLIR.IR.rmfromparent!(func) - global leaked = comparator for block in MLIR.IR.BlockIterator(comparator) return_op = MLIR.IR.terminator(block) MLIR.IR.name(return_op) == "func.return" || continue From ab3a6530412eae2b7e691d810b7eccab43419fb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Sun, 15 Dec 2024 02:46:47 +0100 Subject: [PATCH 3/5] use `return_dialect` --- src/Ops.jl | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index ba883eb8f0..9a6c50106d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -926,33 +926,22 @@ function sort( #TODO: move to @trace (a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0))) - func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true)[2] - + func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2] + @assert MLIR.IR.nregions(func) == 1 fn_name = String( MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) ) #C5: @assert fn_name == "main" "$comparator: no function generated" - @assert MLIR.IR.nregions(func) == 1 ftype_attr = MLIR.IR.attr(func, "function_type") ftype = MLIR.IR.Type(ftype_attr) @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error( "$comparator return type is not tensor" ) - #TODO: move takebody to utils? comparator = MLIR.IR.Region() MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) MLIR.IR.rmfromparent!(func) - for block in MLIR.IR.BlockIterator(comparator) - return_op = MLIR.IR.terminator(block) - MLIR.IR.name(return_op) == "func.return" || continue - operands = [MLIR.IR.operand(return_op, i) for i in 1:MLIR.IR.noperands(return_op)] - MLIR.IR.block!(block) do - MLIR.Dialects.stablehlo.return_(operands; location=MLIR.IR.location(return_op)) - MLIR.IR.rmfromparent!(return_op) - end - end dimension = MLIR.IR.Attribute(dimension - 1) is_stable = MLIR.IR.Attribute(is_stable) From 30a8da47fcc856c55937bc24df38f162d92ca08f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 18 Dec 2024 23:47:21 +0100 Subject: [PATCH 4/5] feedback --- src/Ops.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index e35de1ce75..e7ffe033ae 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -950,7 +950,7 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -function sort( +@noinline function sort( x::TracedRArray{T,N}; comparator, dimension=1, @@ -960,9 +960,8 @@ function sort( #C4: @assert 0 < dimension <= ndims(x) "$x invalid dimension" - #TODO: move to @trace - (a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0))) - func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2] + (a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0))) + func = Reactant.TracedUtils.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2] @assert MLIR.IR.nregions(func) == 1 fn_name = String( MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) From 2b03cc0416007d4a05e026c4a32e78d9b77d7edd Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Thu, 19 Dec 2024 01:51:37 +0000 Subject: [PATCH 5/5] CompatHelper: add new compat entry for HypothesisTests at version 0.11 for package test, (keep existing compat) --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index d8861a1aae..5d0bc81512 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -34,6 +34,7 @@ Enzyme = "0.13.21" FFTW = "1.8" Flux = "0.15, 0.16" Functors = "0.5" +HypothesisTests = "0.11" InteractiveUtils = "1.10" LinearAlgebra = "1.10" Lux = "1.4.1"