Skip to content

Commit

Permalink
Test against Enzyme (#318)
Browse files Browse the repository at this point in the history
* Test against Enzyme

* Run JuliaFormatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Disable some CI tests for Enzyme testing purposes

* Import ChainRule for find_alpha for Enzyme

* Remove unnecessary whitespace

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fixes to Enzyme extension

* Enzyme test fixes

* Remove unnecessary Enzyme settings

* Code style

* Apply suggestions from reviewdog

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Check broken symbols in tests

* Add :Tapir to list of valid broken test marks

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* Add Enzyme compat bounds

* Code style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Adjust when Enzyme tests are run

* Improve Enzyme brokenness check

* Don't check Enzyme AD for Julia < v1.10

* Reenable CI for Julia 1.6

* Misc tiny typos

* Don't run Enzyme at all for Julia < v1.10

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 11, 2024
1 parent bfd5ce4 commit b79c425
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
arch:
- x64
AD:
- Enzyme
- ForwardDiff
- Tapir
- Tracker
Expand All @@ -29,6 +30,10 @@ jobs:
exclude:
- version: 1.6
AD: Tapir
# TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see
# discussion in https://github.com/TuringLang/Bijectors.jl/pull.
- version: 1.6
AD: Enzyme
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -34,6 +35,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BijectorsDistributionsADExt = "DistributionsAD"
BijectorsEnzymeExt = "Enzyme"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
Expand All @@ -50,6 +52,7 @@ Compat = "3.46, 4.2"
Distributions = "0.25.33"
DistributionsAD = "0.6"
DocStringExtensions = "0.9"
Enzyme = "0.12.22"
ForwardDiff = "0.10"
Functors = "0.1, 0.2, 0.3, 0.4"
InverseFunctions = "0.1"
Expand All @@ -69,6 +72,7 @@ julia = "1.6"

[extras]
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
14 changes: 14 additions & 0 deletions ext/BijectorsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module BijectorsEnzymeExt

if isdefined(Base, :get_extension)
using Enzyme: @import_frule, @import_rrule
using Bijectors: find_alpha
else
using ..Enzyme: @import_frule, @import_rrule
using ..Bijectors: find_alpha
end

@import_rrule typeof(find_alpha) Real Real Real
@import_frule typeof(find_alpha) Real Real Real

end
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function transform(t::Transform, x)
res = with_logabsdet_jacobian(t, x)
if res isa ChangesOfVariables.NoLogAbsDetJacobian
error(
"`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.",
"`transform` not implemented for $(typeof(t)); implement `transform` and/or `with_logabsdet_jacobian`.",
)
end

Expand Down
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -15,8 +16,8 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -31,6 +32,7 @@ ChangesOfVariables = "0.1"
Combinatorics = "1.0.2"
Compat = "3.46, 4.2"
DistributionsAD = "0.6.3"
Enzyme = "0.12.22"
FillArrays = "1"
FiniteDifferences = "0.11, 0.12"
ForwardDiff = "0.10.12"
Expand Down
48 changes: 48 additions & 0 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,24 @@
const AD = get(ENV, "AD", "All")

function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
for b in broken
if !(
b in (
:ForwardDiff,
:Zygote,
:Tapir,
:ReverseDiff,
:Enzyme,
:EnzymeForward,
:EnzymeReverse,
)
)
error("Unknown broken AD backend: $b")
end
end

finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1]
et = eltype(finitediff)

if AD == "All" || AD == "ForwardDiff"
if :ForwardDiff in broken
Expand Down Expand Up @@ -30,6 +47,37 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

# TODO(mhauru) The version bound should be relaxed once some Enzyme issues get
# sorted out. I think forward mode will remain broken for versions <= 1.6 due to
# some Julia bug. See https://github.com/EnzymeAD/Enzyme.jl/issues/1629 and
# discussion in https://github.com/TuringLang/Bijectors.jl/pull/318.
if (AD == "All" || AD == "Enzyme") && VERSION >= v"1.10"
forward_broken = :EnzymeForward in broken || :Enzyme in broken
reverse_broken = :EnzymeReverse in broken || :Enzyme in broken
if forward_broken
@test_broken(
collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) finitediff,
rtol = rtol,
atol = atol
)
else
@test(
collect(et, Enzyme.gradient(Enzyme.Forward, f, x)) finitediff,
rtol = rtol,
atol = atol
)
end
if reverse_broken
@test_broken(
Enzyme.gradient(Enzyme.Reverse, f, x) finitediff, rtol = rtol, atol = atol
)
else
@test(
Enzyme.gradient(Enzyme.Reverse, f, x) finitediff, rtol = rtol, atol = atol
)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if :tapir in broken
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Bijectors
using ChainRulesTestUtils
using Combinatorics
using DistributionsAD
using Enzyme
using FiniteDifferences
using ForwardDiff
using Functors
Expand Down

0 comments on commit b79c425

Please sign in to comment.