Skip to content

Commit

Permalink
Lean into bang-bang convention (#92)
Browse files Browse the repository at this point in the history
* Lean into bang-bang convention

* Fix docs
  • Loading branch information
gdalle authored Mar 24, 2024
1 parent e2e04ae commit 224f8cf
Show file tree
Hide file tree
Showing 31 changed files with 370 additions and 347 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
Taped = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -50,6 +51,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = [
"DiffResults",
]
DifferentiationInterfaceReverseDiffExt = ["DiffResults", "ReverseDiff"]
DifferentiationInterfaceTapedExt = ["Taped"]
DifferentiationInterfaceTrackerExt = ["Tracker"]
DifferentiationInterfaceZygoteExt = ["Zygote"]

Expand All @@ -73,6 +75,7 @@ PolyesterForwardDiff = "0.1"
ReverseDiff = "1.15"
RuntimeGeneratedFunctions = "0.5"
Test = "1"
Taped = "1"
Tracker = "0.2"
Zygote = "0.6"
julia = "1.10"
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ This package provides a backend-agnostic syntax to differentiate functions of th

We support most of the backends defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

| Backend | Object |
| backend | object |
| :------------------------------------------------------------------------------ | :----------------------------------------------------------- |
| [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) | `AutoChainRules(ruleconfig)` |
| [Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) | `AutoDiffractor()` |
Expand All @@ -41,7 +41,7 @@ We support most of the backends defined by [ADTypes.jl](https://github.com/SciML

We also provide one additional backend:

| Backend | Object |
| backend | object |
| :------------------------------------------------------------------------------- | :-------------------------- |
| [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) | `AutoFastDifferentiation()` |

Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ This is not part of the public API.
Modules = [DifferentiationInterface]
Public = false
Order = [:function, :type]
Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
```

```@autodocs
Expand Down
2 changes: 1 addition & 1 deletion docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ Modules = [
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceTrackerExt),
Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceZygoteExt)
]
Filter = t -> !(t <: ADTypes.AbstractADType)
Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
```
4 changes: 3 additions & 1 deletion docs/src/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ For simplicity, we remove `value_` in the operator names below.

```mermaid
flowchart LR
pushforward!! --> pushforward
derivative --> pushforward
derivative!! --> pushforward!!
gradient .-> |n|pushforward
Expand All @@ -40,10 +41,11 @@ flowchart LR

```mermaid
flowchart LR
pullback!! --> pullback
derivative .-> |m|pullback
derivative!! .-> |m|pullback!!
gradient --> pullback
gradient!! --> pullback!!
jacobian .-> |m|pullback
jacobian!! .-> |m|pullback!!
```
```
30 changes: 29 additions & 1 deletion docs/src/overview.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Getting started
# Overview

## [Operators](@id operators)

Expand Down Expand Up @@ -54,6 +54,34 @@ Second-order differentiation is also supported, with the following operators:
!!! danger
This is an experimental functionality, use at your own risk.

## Preparation

In many cases, automatic differentiation can be accelerated if the function has been run at least once (e.g. to record a tape) and if some cache objects are provided.
This is a backend-specific procedure, but we expose a common syntax to achieve it.

| operator | preparation function |
| :------------------ | :---------------------------------- |
| `derivative` | [`prepare_derivative`](@ref) |
| `gradient` | [`prepare_gradient`](@ref) |
| `jacobian` | [`prepare_jacobian`](@ref) |
| `second_derivative` | [`prepare_second_derivative`](@ref) |
| `hessian` | [`prepare_hessian`](@ref) |
| `pushforward` | [`prepare_pushforward`](@ref) |
| `pullback` | [`prepare_pullback`](@ref) |
| `hvp` | [`prepare_hvp`](@ref) |

If you run `prepare_operator(backend, f, x)`, it will create an object called `extras` containing the necessary information to speed up `operator` and its variants.
This information is specific to `backend` and `f`, as well as the _type and size_ of the input `x`, but it should work with different _values_ of `x`.

You can then call `operator(backend, f, similar_x, extras)`, which should be faster than `operator(backend, f, similar_x)`.
This is especially worth it if you plan to call `operator` several times in similar settings: you can think of it as a warm up.

By default, all the preparation functions return `nothing`.
We do not make any guarantees on their implementation for each backend, or on the performance gains that can be expected.

!!! warning
We haven't fully figured out what must happen when an `extras` object is prepared for a specific operator but then given to a lower-level one (i.e. prepare it for `jacobian` but then give it to `pushforward` inside `jacobian`).

## Multiple inputs/outputs

Restricting the API to one input and one output has many coding advantages, but it is not very flexible.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module DifferentiationInterfaceChainRulesCoreExt
using ADTypes: ADTypes, AutoChainRules
using ChainRulesCore:
HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad
using DifferentiationInterface: myupdate!!
import DifferentiationInterface as DI

ruleconfig(backend::AutoChainRules) = backend.ruleconfig
Expand All @@ -25,13 +24,6 @@ function DI.value_and_pushforward(
return y, new_dy
end

function DI.value_and_pushforward!!(
f::F, dy, backend::AutoForwardChainRules, x, dx, extras
) where {F}
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, myupdate!!(dy, new_dy)
end

function DI.value_and_pullback(
f::F, backend::AutoReverseChainRules, x, dy, extras::Nothing
) where {F}
Expand All @@ -41,11 +33,4 @@ function DI.value_and_pullback(
return y, new_dx
end

function DI.value_and_pullback!!(
f::F, dx, backend::AutoReverseChainRules, x, dy, extras
) where {F}
y, new_dx = DI.value_and_pullback(f, backend, x, dy, extras)
return y, myupdate!!(dx, new_dx)
end

end
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module DifferentiationInterfaceDiffractorExt

import AbstractDifferentiation as AD # public API for Diffractor
using ADTypes: ADTypes, AutoChainRules, AutoDiffractor
using DifferentiationInterface: myupdate!!
import DifferentiationInterface as DI
using Diffractor: DiffractorForwardBackend, DiffractorRuleConfig

Expand All @@ -16,11 +15,4 @@ function DI.value_and_pushforward(f::F, ::AutoDiffractor, x, dx, extras::Nothing
return y, dy
end

function DI.value_and_pushforward!!(
f::F, dy, backend::AutoDiffractor, x, dx, extras
) where {F}
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, myupdate!!(dy, new_dy)
end

end
7 changes: 0 additions & 7 deletions ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,3 @@ function DI.value_and_pushforward(
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
return y, new_dy
end

function DI.value_and_pushforward!!(
f::F, dy, backend::AutoForwardEnzyme, x, dx, extras
) where {F}
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, myupdate!!(dy, new_dy)
end
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module DifferentiationInterfaceFiniteDifferencesExt

using ADTypes: AutoFiniteDifferences
using DifferentiationInterface: myupdate!!
import DifferentiationInterface as DI
using FillArrays: OneElement
using FiniteDifferences: FiniteDifferences, jvp
Expand All @@ -20,11 +19,4 @@ function DI.value_and_pushforward(
return y, jvp(backend.fdm, f, (x, dx))
end

function DI.value_and_pushforward!!(
f::F, dy, backend::AutoFiniteDifferences, x, dx, extras
) where {F}
y, new_dy = DI.value_and_pushforward(f, backend, x, dx, extras)
return y, myupdate!!(dy, new_dy)
end

end
46 changes: 0 additions & 46 deletions ext/DifferentiationInterfaceForwardDiffExt/TestCorrectness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ function DT.test_correctness(
@test dy_out2 dy_true rtol = 1e-3
@test dy_out3 dy_true rtol = 1e-3
@test dy_out4 dy_true rtol = 1e-3
if ismutable(dy_true)
@testset "Mutation" begin
@test dy_in2 dy_true rtol = 1e-3
@test dy_in4 dy_true rtol = 1e-3
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand All @@ -65,19 +59,9 @@ function DT.test_correctness(

@testset "Primal value" begin
@test y_out y
@testset "Mutation" begin
if ismutable(y)
@test y_in y
end
end
end
@testset "Tangent value" begin
@test dy_out dy_true rtol = 1e-3
@testset "Mutation" begin
if ismutable(dy_true)
@test dy_in dy_true rtol = 1e-3
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand Down Expand Up @@ -105,12 +89,6 @@ function DT.test_correctness(ba::AbstractADType, ::typeof(pullback), scen::Scena
@test dx_out2 dx_true rtol = 1e-3
@test dx_out3 dx_true rtol = 1e-3
@test dx_out4 dx_true rtol = 1e-3
if ismutable(dx_true)
@testset "Mutation" begin
@test dx_in2 dx_true rtol = 1e-3
@test dx_in4 dx_true rtol = 1e-3
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand All @@ -134,13 +112,6 @@ function DT.test_correctness(ba::AbstractADType, ::typeof(pullback), scen::Scena
end
@testset "Cotangent value" begin
@test dx_out dx_true rtol = 1e-3
if ismutable(dx_true)
@testset "Mutation" begin
if ismutable(dx_true)
@test dx_in dx_true rtol = 1e-3
end
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand Down Expand Up @@ -170,12 +141,6 @@ function DT.test_correctness(
@test der_out2 der_true rtol = 1e-3
@test der_out3 der_true rtol = 1e-3
@test der_out4 der_true rtol = 1e-3
@testset "Mutation" begin
if ismutable(der_true)
@test der_in2 der_true rtol = 1e-3
@test der_in4 der_true rtol = 1e-3
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand All @@ -199,11 +164,6 @@ function DT.test_correctness(ba::AbstractADType, ::typeof(derivative), scen::Sce
end
@testset "Derivative value" begin
@test der_out der_true rtol = 1e-3
@testset "Mutation" begin
if ismutable(der_true)
@test der_in der_true rtol = 1e-3
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand Down Expand Up @@ -235,12 +195,6 @@ function DT.test_correctness(ba::AbstractADType, ::typeof(gradient), scen::Scena
@test grad_out2 grad_true rtol = 1e-3
@test grad_out3 grad_true rtol = 1e-3
@test grad_out4 grad_true rtol = 1e-3
@testset "Mutation" begin
if ismutable(grad_true)
@test grad_in2 grad_true rtol = 1e-3
@test grad_in4 grad_true rtol = 1e-3
end
end
end
return test_scen_intact(new_scen, scen)
end
Expand Down
11 changes: 0 additions & 11 deletions ext/DifferentiationInterfaceForwardDiffExt/allocating.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
function DI.value_and_pushforward!!(
f::F, dy, ::AutoForwardDiff, x, dx, extras::Nothing
) where {F}
T = tag_type(f, x)
xdual = make_dual(T, x, dx)
ydual = f(xdual)
y = my_value(T, ydual)
dy = my_derivative!!(T, dy, ydual)
return y, dy
end

function DI.value_and_pushforward(f::F, ::AutoForwardDiff, x, dx, extras::Nothing) where {F}
T = tag_type(f, x)
xdual = make_dual(T, x, dx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ using PolyesterForwardDiff: threaded_gradient!, threaded_jacobian!

## Pushforward

function DI.value_and_pushforward!!(
f::F, dy, ::AutoPolyesterForwardDiff{C}, x, dx, extras::Nothing
function DI.value_and_pushforward(
f::F, ::AutoPolyesterForwardDiff{C}, x, dx, extras::Nothing
) where {F,C}
return DI.value_and_pushforward!!(
f, dy, AutoForwardDiff{C,Nothing}(nothing), x, dx, extras
)
return DI.value_and_pushforward(f, AutoForwardDiff{C,Nothing}(nothing), x, dx, extras)
end

function DI.value_and_pushforward!!(
Expand All @@ -26,10 +24,4 @@ function DI.value_and_pushforward!!(
)
end

function DI.value_and_pushforward(
f::F, ::AutoPolyesterForwardDiff{C}, x, dx, extras::Nothing
) where {F,C}
return DI.value_and_pushforward(f, AutoForwardDiff{C,Nothing}(nothing), x, dx, extras)
end

end # module
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module DifferentiationInterfaceTapedExt

using ADTypes: ADTypes
using DifferentiationInterface: AutoTaped, myupdate!!
import DifferentiationInterface as DI
using Taped: build_rrule, value_and_pullback!!

DI.supports_mutation(::AutoTaped) = DI.MutationNotSupported()

function DI.value_and_pullback(f::F, ::AutoTaped, x, dy, extras::Nothing) where {F}
rrule = build_rrule(f, x)
# TODO: fix for https://github.com/withbayes/Taped.jl/issues/97
y = f(x)
# TODO:
dy_righttype = convert(typeof(y), dy)
_, (_, dx) = value_and_pullback!!(rrule, dy_righttype, f, x)
return y, dx
end

end
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module DifferentiationInterfaceTrackerExt

using ADTypes: AutoTracker
import DifferentiationInterface as DI
using DifferentiationInterface: myupdate!!
using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgradient

DI.supports_mutation(::AutoTracker) = DI.MutationNotSupported()
Expand All @@ -14,9 +13,4 @@ function DI.value_and_pullback(f::F, ::AutoTracker, x, dy, extras::Nothing) wher
return y, data(only(back(dy)))
end

function DI.value_and_pullback!!(f::F, dx, backend::AutoTracker, x, dy, extras) where {F}
y, new_dx = DI.value_and_pullback(f, backend, x, dy, extras)
return y, myupdate!!(dx, new_dx)
end

end
Loading

0 comments on commit 224f8cf

Please sign in to comment.