Switch to DifferentiationInterface #29

Closed · wants to merge 9 commits
Changes from 1 commit
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6' # Replace this with the minimum Julia version that your package supports.
- '1.10' # Replace this with the minimum Julia version that your package supports.
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
os:
- ubuntu-latest
25 changes: 6 additions & 19 deletions Project.toml
@@ -1,43 +1,30 @@
name = "LogDensityProblemsAD"
uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
authors = ["Tamás K. Papp <tkpapp@gmail.com>"]
version = "1.9.0"
version = "1.10.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
LogDensityProblemsADADTypesExt = "ADTypes"
LogDensityProblemsADEnzymeExt = "Enzyme"
LogDensityProblemsADFiniteDifferencesExt = "FiniteDifferences"
LogDensityProblemsADForwardDiffBenchmarkToolsExt = ["BenchmarkTools", "ForwardDiff"]
LogDensityProblemsADForwardDiffExt = "ForwardDiff"
LogDensityProblemsADReverseDiffExt = "ReverseDiff"
LogDensityProblemsADTrackerExt = "Tracker"
LogDensityProblemsADZygoteExt = "Zygote"

[compat]
ADTypes = "0.1.7, 0.2, 1"
ADTypes = "1"
DifferentiationInterface = "0.3"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.11, 0.12"
FiniteDifferences = "0.12"
LogDensityProblems = "1, 2"
Requires = "0.5, 1"
SimpleUnPack = "1"
julia = "1.6"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
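
With this change, ADTypes and DifferentiationInterface move from optional to unconditional dependencies, and the per-backend compat entries disappear. A minimal sketch of the kind of call the package now delegates to, assuming DifferentiationInterface's exported `value_and_gradient` and an ADTypes backend (ForwardDiff and the toy log density are illustrative choices, not part of this PR):

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterface: value_and_gradient
import ForwardDiff  # the AD package behind the backend still has to be loaded

# Stand-in for Base.Fix1(logdensity, ℓ): a standard-normal log density.
logdensity_fn(x) = -sum(abs2, x) / 2

backend = AutoForwardDiff()
x0 = zeros(3)

# DifferentiationInterface computes the value and the gradient in one call.
y, ∇y = value_and_gradient(logdensity_fn, backend, x0)
```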
28 changes: 0 additions & 28 deletions ext/DiffResults_helpers.jl

This file was deleted.

53 changes: 0 additions & 53 deletions ext/LogDensityProblemsADADTypesExt.jl

This file was deleted.

79 changes: 7 additions & 72 deletions ext/LogDensityProblemsADEnzymeExt.jl
@@ -1,78 +1,13 @@
"""
Gradient AD implementation using Enzyme.
"""
module LogDensityProblemsADEnzymeExt

if isdefined(Base, :get_extension)
using LogDensityProblemsAD: ADGradientWrapper, logdensity
using LogDensityProblemsAD.SimpleUnPack: @unpack
using ADTypes: AutoEnzyme
using LogDensityProblemsAD: ADgradient, logdensity
using Enzyme: Enzyme

import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import Enzyme
else
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity
using ..LogDensityProblemsAD.SimpleUnPack: @unpack

import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ..Enzyme
end

struct EnzymeGradientLogDensity{L,M<:Union{Enzyme.ForwardMode,Enzyme.ReverseMode},S} <: ADGradientWrapper
ℓ::L
mode::M
shadow::S # only used in forward mode
end

"""
ADgradient(:Enzyme, ℓ; kwargs...)
ADgradient(Val(:Enzyme), ℓ; kwargs...)

Gradient using algorithmic/automatic differentiation via Enzyme.

# Keyword arguments

- `mode::Enzyme.Mode`: Differentiation mode (default: `Enzyme.Reverse`).
Currently only `Enzyme.Reverse` and `Enzyme.Forward` are supported.

- `shadow`: Collection of one-hot vectors for each entry of the inputs `x` to the log density
`ℓ`, or `nothing` (default: `nothing`). This keyword argument is only used in forward
mode. By default, it will be recomputed in every call of `logdensity_and_gradient(ℓ, x)`.
For performance reasons it is recommended to compute it only once when calling `ADgradient`.
The one-hot vectors can be constructed, e.g., with `Enzyme.onehot(x)`.
"""
function ADgradient(::Val{:Enzyme}, ℓ; mode::Enzyme.Mode = Enzyme.Reverse, shadow = nothing)
mode isa Union{Enzyme.ForwardMode,Enzyme.ReverseMode} ||
throw(ArgumentError("currently automatic differentiation via Enzyme only supports " *
"`Enzyme.Forward` and `Enzyme.Reverse` modes"))
if mode isa Enzyme.ReverseMode && shadow !== nothing
@info "keyword argument `shadow` is ignored in reverse mode"
shadow = nothing
end
return EnzymeGradientLogDensity(ℓ, mode, shadow)
end

function Base.show(io::IO, ∇ℓ::EnzymeGradientLogDensity)
print(io, "Enzyme AD wrapper for ", ∇ℓ.ℓ, " with ",
∇ℓ.mode isa Enzyme.ForwardMode ? "forward" : "reverse", " mode")
end

function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ForwardMode},
x::AbstractVector)
@unpack ℓ, mode, shadow = ∇ℓ
_shadow = shadow === nothing ? Enzyme.onehot(x) : shadow
y, ∂ℓ_∂x = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated,
Enzyme.Const(ℓ),
Enzyme.BatchDuplicated(x, _shadow))
return y, collect(∂ℓ_∂x)
end

function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ReverseMode},
x::AbstractVector)
@unpack ℓ = ∇ℓ
∂ℓ_∂x = zero(x)
_, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, logdensity, Enzyme.Active,
Enzyme.Const(ℓ), Enzyme.Duplicated(x, ∂ℓ_∂x))
y, ∂ℓ_∂x
function ADgradient(::Val{:Enzyme}, ℓ; mode = Enzyme.Reverse, shadow = nothing)
@info "keyword argument `shadow` is now ignored"
backend = AutoEnzyme(; mode)
return ADgradient(backend, ℓ)
end

end # module
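
Under the rewritten extension, the symbol-based constructor is a thin wrapper that builds an `AutoEnzyme` backend and forwards to the ADTypes method of `ADgradient`. A usage sketch under that assumption (the toy problem type is made up purely for illustration):

```julia
using ADTypes: AutoEnzyme
using LogDensityProblems, LogDensityProblemsAD
import Enzyme

# Minimal LogDensityProblems-compatible problem, for illustration only.
struct ToyProblem end
LogDensityProblems.logdensity(::ToyProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyProblem) = 2
LogDensityProblems.capabilities(::Type{ToyProblem}) = LogDensityProblems.LogDensityOrder{0}()

ℓ = ToyProblem()

# Symbol-based constructor: builds AutoEnzyme(; mode) and forwards to ADgradient(backend, ℓ).
∇ℓ = ADgradient(Val(:Enzyme), ℓ; mode = Enzyme.Forward)

# Equivalent direct construction with the ADTypes backend.
∇ℓ_direct = ADgradient(AutoEnzyme(; mode = Enzyme.Forward), ℓ)

LogDensityProblems.logdensity_and_gradient(∇ℓ, zeros(2))
```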
52 changes: 7 additions & 45 deletions ext/LogDensityProblemsADFiniteDifferencesExt.jl
@@ -1,51 +1,13 @@
"""
Gradient implementation using FiniteDifferences.
"""
module LogDensityProblemsADFiniteDifferencesExt

if isdefined(Base, :get_extension)
using LogDensityProblemsAD: ADGradientWrapper, logdensity
using LogDensityProblemsAD.SimpleUnPack: @unpack
using ADTypes: AutoFiniteDifferences
using LogDensityProblemsAD: ADgradient
import FiniteDifferences: central_fdm

import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import FiniteDifferences
else
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity
using ..LogDensityProblemsAD.SimpleUnPack: @unpack

import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient
import ..FiniteDifferences
end

struct FiniteDifferencesGradientLogDensity{L,M} <: ADGradientWrapper
ℓ::L
"finite difference method"
fdm::M
end

"""
ADgradient(:FiniteDifferences, ℓ; fdm = central_fdm(5, 1))
ADgradient(Val(:FiniteDifferences), ℓ; fdm = central_fdm(5, 1))

Gradient using FiniteDifferences, mainly intended for checking results from other algorithms.

# Keyword arguments

- `fdm`: the finite difference method. Defaults to `central_fdm(5, 1)`.
"""
function ADgradient(::Val{:FiniteDifferences}, ℓ; fdm = FiniteDifferences.central_fdm(5, 1))
FiniteDifferencesGradientLogDensity(ℓ, fdm)
end

function Base.show(io::IO, ∇ℓ::FiniteDifferencesGradientLogDensity)
print(io, "FiniteDifferences AD wrapper for ", ∇ℓ.ℓ, " with ", ∇ℓ.fdm)
end

function logdensity_and_gradient(∇ℓ::FiniteDifferencesGradientLogDensity, x::AbstractVector)
@unpack ℓ, fdm = ∇ℓ
y = logdensity(ℓ, x)
∇y = only(FiniteDifferences.grad(fdm, Base.Fix1(logdensity, ℓ), x))
y, ∇y
function ADgradient(::Val{:FiniteDifferences}, ℓ)
fdm = central_fdm(5, 1)
backend = AutoFiniteDifferences(; fdm)
ADgradient(backend, ℓ)
end

end # module
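
The symbol-based constructor now hard-codes `central_fdm(5, 1)`; choosing a different finite-difference method means passing an `AutoFiniteDifferences` backend directly. A sketch under those assumptions (the toy problem is again only illustrative):

```julia
using ADTypes: AutoFiniteDifferences
using FiniteDifferences: central_fdm
using LogDensityProblems, LogDensityProblemsAD

# Minimal LogDensityProblems-compatible problem, for illustration only.
struct ToyProblem end
LogDensityProblems.logdensity(::ToyProblem, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::ToyProblem) = 2
LogDensityProblems.capabilities(::Type{ToyProblem}) = LogDensityProblems.LogDensityOrder{0}()

ℓ = ToyProblem()

∇ℓ_default = ADgradient(Val(:FiniteDifferences), ℓ)                          # central_fdm(5, 1)
∇ℓ_custom  = ADgradient(AutoFiniteDifferences(; fdm = central_fdm(3, 1)), ℓ)  # custom method

LogDensityProblems.logdensity_and_gradient(∇ℓ_default, zeros(2))
```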
65 changes: 0 additions & 65 deletions ext/LogDensityProblemsADForwardDiffBenchmarkToolsExt.jl

This file was deleted.
