Skip to content

Commit

Permalink
Merge pull request #42 from TuringLang/dw/enzyme
Browse files Browse the repository at this point in the history
Add Enzyme support
  • Loading branch information
yebai authored Feb 24, 2023
2 parents c2c4d8b + e7265af commit 1aee09e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,34 @@ function __init__()
return out
end
end
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin
include("compat/enzyme.jl")
export EnzymeAD

function AdvancedVI.grad!(
vo,
alg::VariationalInference{<:AdvancedVI.EnzymeAD},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) = if (q isa Distribution)
- vo(alg, update(q, θ), model, args...)
else
- vo(alg, q(θ), model, args...)
end
# Use `Enzyme.ReverseWithPrimal` once it is released:
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
y = f(θ)
DiffResults.value!(out, y)
dy = DiffResults.gradient(out)
fill!(dy, 0)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
return out
end
end
end

export
Expand Down
5 changes: 5 additions & 0 deletions src/compat/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
struct EnzymeAD <: ADBackend end
ADBackend(::Val{:enzyme}) = EnzymeAD
function setadbackend(::Val{:enzyme})
ADBACKEND[] = :enzyme
end

0 comments on commit 1aee09e

Please sign in to comment.