Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme reverse rules #110

Merged
merged 17 commits into from
Jul 31, 2024
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ version = "2.9.4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
QuadGKEnzymeExt = "Enzyme"

[compat]
DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19"
julia = "1.2"
Expand Down
100 changes: 100 additions & 0 deletions ext/QuadGKEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@

module QuadGKEnzymeExt

using QuadGK, Enzyme, LinearAlgebra

function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T<:Real}
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
prims = map(x->x.val, segs)

Check warning on line 7 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L6-L7

Added lines #L6 - L7 were not covered by tests

retres, segbuf = if f isa Const
if EnzymeRules.needs_primal(config)
quadgk(f.val, prims...; kws...), nothing

Check warning on line 11 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L9-L11

Added lines #L9 - L11 were not covered by tests
else
nothing

Check warning on line 13 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L13

Added line #L13 was not covered by tests
end
else
I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...)
if EnzymeRules.needs_primal(config)
(I, E), segbuf

Check warning on line 18 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
else
nothing, segbuf

Check warning on line 20 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L20

Added line #L20 was not covered by tests
end
end

dres = if !Enzyme.EnzymeRules.needs_shadow(config)
nothing
elseif EnzymeRules.width(config) == 1
zero.(res...)

Check warning on line 27 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L24-L27

Added lines #L24 - L27 were not covered by tests
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
zero.(res...)

Check warning on line 31 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L29-L31

Added lines #L29 - L31 were not covered by tests
end
end

cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed
dres

Check warning on line 36 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
else
nothing

Check warning on line 38 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L38

Added line #L38 was not covered by tests
end
cache2 = segbuf, cache

Check warning on line 40 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L40

Added line #L40 was not covered by tests

return Enzyme.EnzymeRules.AugmentedReturn{

Check warning on line 42 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L42

Added line #L42 was not covered by tests
Enzyme.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing,
Enzyme.EnzymeRules.needs_shadow(config) ? (Enzyme.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{Enzyme.EnzymeRules.width(config), eltype(RT)}) : Nothing,
typeof(cache2)
}(retres, dres, cache2)
end

function call(f, x)
f(x)

Check warning on line 50 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L49-L50

Added lines #L49 - L50 were not covered by tests
end

struct ClosureVector{F}
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
f::F
end

@inline function guaranteed_nonactive(::Type{T}) where T
rt = Enzyme.Compiler.active_reg_inner(T, (), nothing)
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

function Base.:+(a::ClosureVector, b::ClosureVector)
Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)
end
wsmoses marked this conversation as resolved.
Show resolved Hide resolved

function Base.:-(a::ClosureVector, b::ClosureVector)
Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive)
end

function Base.:*(a::Number, b::ClosureVector)
# b + (a-1) * b = a * b
Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive)
end

function Base.:*(a::ClosureVector, b::Number)
return b*a
end

function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f, segs::Annotation{T}...; kws...) where {T<:Real}
df = if f isa Const
nothing
else
segbuf = cache[1]
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T})
_df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x
tape, prim, shad = fwd(Const(call), f, Const(x))
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
return ClosureVector(drev[1][1])
end
_df.f
end
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
return (df, # f
dsegs1,
ntuple(i -> nothing, Val(length(segs)-2))...,
dsegsn)
end

end # module
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[deps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,17 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...;
@inferred QuadGK.to_segbuf([0,1])
@inferred QuadGK.to_segbuf([(0,1+3im)])
end

# Extension package only supported in 1.9+
@static if VERSION >= v"1.9"
using Enzyme
f1(x) = quadgk(cos, 0., x)[1]
f2(x) = quadgk(cos, x, 1)[1]
f3(x) = quadgk(y->cos(x * y), 0., 1.)[1]

@testset "Enzyme" begin
@test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1]
@test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1]
@test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1]
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
end
end
Loading