From 480f43472be9797142402dbeb5ee79b307286f81 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 26 Jul 2024 13:58:19 -0400 Subject: [PATCH 01/17] Add Enzyme reverse rules --- Project.toml | 7 +++++++ test/runtests.jl | 13 ++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2c44fbd..ae4bc2a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,12 @@ version = "2.9.4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[extensions] +QuadGKEnzymeExt = "Enzyme" + [compat] DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19" julia = "1.2" @@ -15,3 +21,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test"] + diff --git a/test/runtests.jl b/test/runtests.jl index 4e156f2..9e190b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -# This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license + # This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license using QuadGK, LinearAlgebra, Test @@ -426,3 +426,14 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...; @inferred QuadGK.to_segbuf([0,1]) @inferred QuadGK.to_segbuf([(0,1+3im)]) end + + +f1(x) -> quadgk(cos, 0., x) +f2(x) -> quadgk(cos, x, 1) +f3(x) -> quadgk(y->cos(x * y), 0., 1.) + +@testset "Enzyme" begin + @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] + @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1] + @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1] +end \ No newline at end of file From b5762c72cc156040bdf6a6f429b2ea9da7833259 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 26 Jul 2024 13:58:46 -0400 Subject: [PATCH 02/17] fix --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9e190b0..9e718dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -428,9 +428,9 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...; end -f1(x) -> quadgk(cos, 0., x) -f2(x) -> quadgk(cos, x, 1) -f3(x) -> quadgk(y->cos(x * y), 0., 1.) +f1(x) = quadgk(cos, 0., x) +f2(x) = quadgk(cos, x, 1) +f3(x) = quadgk(y->cos(x * y), 0., 1.) @testset "Enzyme" begin @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] From 80bbf0945d23514e061d347f73f9bc43b0869462 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 26 Jul 2024 14:02:27 -0400 Subject: [PATCH 03/17] fixup --- test/runtests.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9e718dc..7cf2848 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ # This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license -using QuadGK, LinearAlgebra, Test +using QuadGK, LinearAlgebra, Test, Enzyme ≅(x::Tuple, y::Tuple; kws...) = all(a -> isapprox(a[1],a[2]; kws...), zip(x,y)) @@ -428,9 +428,9 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...; end -f1(x) = quadgk(cos, 0., x) -f2(x) = quadgk(cos, x, 1) -f3(x) = quadgk(y->cos(x * y), 0., 1.) +f1(x) = quadgk(cos, 0., x)[1] +f2(x) = quadgk(cos, x, 1)[1] +f3(x) = quadgk(y->cos(x * y), 0., 1.)[1] @testset "Enzyme" begin @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] From 0978621850af24e2847bd668c1ffa57a7ac557fe Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 26 Jul 2024 14:42:09 -0400 Subject: [PATCH 04/17] Add test project file --- test/Project.toml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 test/Project.toml diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..cefe362 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,4 @@ +[deps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" \ No newline at end of file From 3fae4caaaa85b90adb5cce149a161dd4209d0eb0 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 26 Jul 2024 14:44:13 -0400 Subject: [PATCH 05/17] gate per extension package --- test/runtests.jl | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7cf2848..5b6e6d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ # This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license -using QuadGK, LinearAlgebra, Test, Enzyme +using QuadGK, LinearAlgebra, Test ≅(x::Tuple, y::Tuple; kws...) = all(a -> isapprox(a[1],a[2]; kws...), zip(x,y)) @@ -427,13 +427,16 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...; @inferred QuadGK.to_segbuf([(0,1+3im)]) end - -f1(x) = quadgk(cos, 0., x)[1] -f2(x) = quadgk(cos, x, 1)[1] -f3(x) = quadgk(y->cos(x * y), 0., 1.)[1] - -@testset "Enzyme" begin - @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] - @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1] - @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1] +# Extension package only supported in 1.9+ +@static if VERSION >= v"1.9" + using Enzyme + f1(x) = quadgk(cos, 0., x)[1] + f2(x) = quadgk(cos, x, 1)[1] + f3(x) = quadgk(y->cos(x * y), 0., 1.)[1] + + @testset "Enzyme" begin + @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] + @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1] + @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1] + end end \ No newline at end of file From 422fa47cdecee0c6d414ed70171834d196827c12 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 14:46:44 -0400 Subject: [PATCH 06/17] Update test/runtests.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5b6e6d2..5d759da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ - # This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license +# This file contains code that was formerly part of Julia. License is MIT: http://julialang.org/license using QuadGK, LinearAlgebra, Test From 3ed66d280a129152dc916e5e34a7e1055946b917 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 14:46:52 -0400 Subject: [PATCH 07/17] Update test/runtests.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5d759da..f6816c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -439,4 +439,4 @@ end @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1] @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1] end -end \ No newline at end of file +end From c93b6e6bd40d5cbff11bc63c26393ebd6b39a410 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 14:47:04 -0400 Subject: [PATCH 08/17] Update test/Project.toml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index cefe362..481c28e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,4 @@ [deps] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" \ No newline at end of file +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 7a02b26ee35324e49ce800f6040c37643e98bf4e Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 14:47:10 -0400 Subject: [PATCH 09/17] Update Project.toml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mosè Giordano --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index ae4bc2a..1b6a6eb 100644 --- a/Project.toml +++ b/Project.toml @@ -21,4 +21,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test"] - From 0d981f92d1ee1f2c286d10e103cc6667c9df9f15 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 26 Jul 2024 14:47:49 -0400 Subject: [PATCH 10/17] Add actual file --- ext/QuadGKEnzymeExt.jl | 100 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 ext/QuadGKEnzymeExt.jl diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl new file mode 100644 index 0000000..7723bf7 --- /dev/null +++ b/ext/QuadGKEnzymeExt.jl @@ -0,0 +1,100 @@ + +module QuadGKEnzymeExt + +using QuadGK, Enzyme, LinearAlgebra + +function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T<:Real} + prims = map(x->x.val, segs) + + retres, segbuf = if f isa Const + if EnzymeRules.needs_primal(config) + quadgk(f.val, prims...; kws...), nothing + else + nothing + end + else + I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...) + if EnzymeRules.needs_primal(config) + (I, E), segbuf + else + nothing, segbuf + end + end + + dres = if !Enzyme.EnzymeRules.needs_shadow(config) + nothing + elseif EnzymeRules.width(config) == 1 + zero.(res...) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + zero.(res...) + end + end + + cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + dres + else + nothing + end + cache2 = segbuf, cache + + return Enzyme.EnzymeRules.AugmentedReturn{ + Enzyme.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing, + Enzyme.EnzymeRules.needs_shadow(config) ? (Enzyme.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{Enzyme.EnzymeRules.width(config), eltype(RT)}) : Nothing, + typeof(cache2) + }(retres, dres, cache2) +end + +function call(f, x) + f(x) +end + +struct ClosureVector{F} + f::F +end + +@inline function guaranteed_nonactive(::Type{T}) where T + rt = Enzyme.Compiler.active_reg_inner(T, (), nothing) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + +function Base.:+(a::ClosureVector, b::ClosureVector) + Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive) +end + +function Base.:-(a::ClosureVector, b::ClosureVector) + Enzyme.Compiler.recursive_add(b, b, x->-x, guaranteed_nonactive) +end + +function Base.:*(a::Number, b::ClosureVector) + # b + (a-1) * b = a * b + Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive) +end + +function Base.:*(a::ClosureVector, b::Number) + return b*a +end + +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f, segs::Annotation{T}...; kws...) where {T<:Real} + df = if f isa Const + nothing + else + segbuf = cache[1] + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x + tape, prim, shad = fwd(Const(call), f, Const(x)) + drev = rev(Const(call), f, Const(x), dres.val[1], tape) + return ClosureVector(drev[1][1]) + end + _df.f + end + dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1]) + dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1]) + return (df, # f + dsegs1, + ntuple(i -> nothing, Val(length(segs)-2))..., + dsegsn) +end + +end # module \ No newline at end of file From f48b499e698dbadfdfd6c5bf8bd4f748f8310165 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 14:48:45 -0400 Subject: [PATCH 11/17] Update QuadGKEnzymeExt.jl --- ext/QuadGKEnzymeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index 7723bf7..fd5c89c 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -64,7 +64,7 @@ function Base.:+(a::ClosureVector, b::ClosureVector) end function Base.:-(a::ClosureVector, b::ClosureVector) - Enzyme.Compiler.recursive_add(b, b, x->-x, guaranteed_nonactive) + Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive) end function Base.:*(a::Number, b::ClosureVector) @@ -97,4 +97,4 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: dsegsn) end -end # module \ No newline at end of file +end # module From e0f3256cfe8217f711e906012305b9c61aa73e29 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 26 Jul 2024 16:36:30 -0400 Subject: [PATCH 12/17] Update ext/QuadGKEnzymeExt.jl Co-authored-by: Steven G. Johnson --- ext/QuadGKEnzymeExt.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index fd5c89c..2e4f79f 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -50,6 +50,14 @@ function call(f, x) f(x) end +# Wrapper around a function f that allows it to act as a vector space, and hence be usable as +# an integrand, where the vector operations act on the closed-over parameters of f that are +# begin differentiated with respect to. In particular, if we have a closure f = x -> g(x, p), and we want +# to differentiate with respect to p, then our reverse (vJp) rule needs an integrand given by the +# Jacobian-vector product (pullback) vᵀ∂g/∂p. But Enzyme wraps this in a closure so that it is the +# same "shape" as f, whereas to integrate it we need to be able to treat it as a vector space. +# ClosureVector calls Enzyme.Compiler.recursive_add, which is an internal function that "unwraps" +# the closure to access the internal state, which can then be added/subtracted/scaled. struct ClosureVector{F} f::F end From 296a866b632881fdebd7f4228099485e907ff706 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 29 Jul 2024 14:49:41 -0400 Subject: [PATCH 13/17] fixup --- ext/QuadGKEnzymeExt.jl | 40 ++++++++++++++++++++++++++++++++++------ test/runtests.jl | 17 +++++++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index 2e4f79f..baf78db 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -67,17 +67,17 @@ end return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end -function Base.:+(a::ClosureVector, b::ClosureVector) - Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive) +function Base.:+(a::CV, b::CV) where {CV <: ClosureVector} + Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)::CV end -function Base.:-(a::ClosureVector, b::ClosureVector) - Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive) +function Base.:-(a::CV, b::CV) where {CV <: ClosureVector} + Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive)::CV end -function Base.:*(a::Number, b::ClosureVector) +function Base.:*(a::Number, b::CV) where {CV <: ClosureVector} # b + (a-1) * b = a * b - Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive) + Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive)::CV end function Base.:*(a::ClosureVector, b::Number) @@ -105,4 +105,32 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: dsegsn) end +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f, segs::Annotation{T}...; kws...) where {T<:Real} + dres = cache[2] + @show dres + df = if f isa Const + nothing + else + segbuf = cache[1] + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x + @show x + tape, prim, shad = fwd(Const(call), f, Const(x)) + @show prim, shad + shad .= dres + drev = rev(Const(call), f, Const(x), tape) + @show drev + return ClosureVector(drev[1][1]) + end + _df.f + end + dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres) + dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres) + Enzyme.make_zero!(dres) + return (df, # f + dsegs1, + ntuple(i -> nothing, Val(length(segs)-2))..., + dsegsn) +end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index f6816c1..317c07a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -434,9 +434,26 @@ end f2(x) = quadgk(cos, x, 1)[1] f3(x) = quadgk(y->cos(x * y), 0., 1.)[1] + f1_count(x) = quadgk_count(cos, 0., x)[1] + f2_count(x) = quadgk_count(cos, x, 1)[1] + f3_count(x) = quadgk_count(y->cos(x * y), 0., 1.)[1] + + f_vec(x) = sum(quadgk(y->[cos(x[1] * y), cos(x[2] * y)], 0., 1.)[1]) + @testset "Enzyme" begin @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1] @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1] + + @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1_count, Active(0.3))[1][1] + @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2_count, Active(0.3))[1][1] + @test_broken (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1] + + x = [0.3, 0.7] + dx = [0.0, 0.0] + f_vec(x) + # TODO custom rule with mixed vector returns not yet supported + @test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_vec, Duplicated(x, dx)) + # @test dx ≈ [(0.3 * cos(0.3) - sin(0.3))/(0.3*0.3), (0.7 * cos(0.7) - sin(0.7))/(0.7*0.7)] end end From da1454dce1cf0498a73eaa284d10478bd02ccc3b Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 29 Jul 2024 20:09:21 -0400 Subject: [PATCH 14/17] fixup --- ext/QuadGKEnzymeExt.jl | 8 ++------ src/api.jl | 18 ++++++++++++------ test/runtests.jl | 2 +- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index baf78db..26d4263 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -84,7 +84,7 @@ function Base.:*(a::ClosureVector, b::Number) return b*a end -function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f, segs::Annotation{T}...; kws...) where {T<:Real} +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T<:Real} df = if f isa Const nothing else @@ -105,21 +105,17 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: dsegsn) end -function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f, segs::Annotation{T}...; kws...) where {T<:Real} +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T<:Real} dres = cache[2] - @show dres df = if f isa Const nothing else segbuf = cache[1] fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x - @show x tape, prim, shad = fwd(Const(call), f, Const(x)) - @show prim, shad shad .= dres drev = rev(Const(call), f, Const(x), tape) - @show drev return ClosureVector(drev[1][1]) end _df.f diff --git a/src/api.jl b/src/api.jl index 6087bbc..b375260 100644 --- a/src/api.jl +++ b/src/api.jl @@ -132,6 +132,15 @@ function quadgk!(f!, result, a::T,b::T,c::T...; atol=nothing, rtol=nothing, maxe return quadgk(f, a, b, c...; atol=atol, rtol=rtol, maxevals=maxevals, order=order, norm=norm, segbuf=segbuf, eval_segbuf=eval_segbuf) end +struct Counter{F} + f::F + count::Base.RefValue{Int} +end +function (c::Counter{F})(args...) where F + c.count[] += 1 + c.f(args...) +end + """ quadgk_count(f, args...; kws...) @@ -146,12 +155,9 @@ it may be possible to mathematically transform the problem in some way to improve the convergence rate. """ function quadgk_count(f, args...; kws...) - count = 0 - i = quadgk(args...; kws...) do x - count += 1 - f(x) - end - return (i..., count) + counter = Counter(f, Ref(0)) + i = quadgk(counter, args...; kws...) + return (i..., counter.count[]) end """ diff --git a/test/runtests.jl b/test/runtests.jl index 317c07a..3e1bbc9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -447,7 +447,7 @@ end @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1_count, Active(0.3))[1][1] @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2_count, Active(0.3))[1][1] - @test_broken (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1] + @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1] x = [0.3, 0.7] dx = [0.0, 0.0] From f277f08b4f9cf61353e06bef136c114ea38ac4e3 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Mon, 29 Jul 2024 20:11:05 -0400 Subject: [PATCH 15/17] Bump minimum to 1.9 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 62fdcae..b62607d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,7 +22,7 @@ jobs: fail-fast: false matrix: version: - - '1.2' + - '1.9' - '1' # - 'nightly' os: From a6e4e04e212405bf0bf994ff0944949eda869bf6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 30 Jul 2024 17:57:14 -0400 Subject: [PATCH 16/17] Update QuadGKEnzymeExt.jl --- ext/QuadGKEnzymeExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl index 26d4263..d10d21a 100644 --- a/ext/QuadGKEnzymeExt.jl +++ b/ext/QuadGKEnzymeExt.jl @@ -3,7 +3,7 @@ module QuadGKEnzymeExt using QuadGK, Enzyme, LinearAlgebra -function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T<:Real} +function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T} prims = map(x->x.val, segs) retres, segbuf = if f isa Const @@ -84,7 +84,7 @@ function Base.:*(a::ClosureVector, b::Number) return b*a end -function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T<:Real} +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} df = if f isa Const nothing else @@ -105,7 +105,7 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres:: dsegsn) end -function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T<:Real} +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} dres = cache[2] df = if f isa Const nothing From 5a1f362954f33d9f41c9a6435d2aa244f6dd32b2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 30 Jul 2024 18:02:25 -0400 Subject: [PATCH 17/17] Update runtests.jl --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3e1bbc9..dae5b9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -452,7 +452,7 @@ end x = [0.3, 0.7] dx = [0.0, 0.0] f_vec(x) - # TODO custom rule with mixed vector returns not yet supported + # TODO custom rule with mixed vector returns not yet supported x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/1692 @test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_vec, Duplicated(x, dx)) # @test dx ≈ [(0.3 * cos(0.3) - sin(0.3))/(0.3*0.3), (0.7 * cos(0.7) - sin(0.7))/(0.7*0.7)] end