Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-argument support: basic infrastructure #461

Merged
merged 16 commits into from
Sep 14, 2024
Merged

Multi-argument support: basic infrastructure #461

merged 16 commits into from
Sep 14, 2024

Conversation

gdalle
Copy link
Member

@gdalle gdalle commented Sep 8, 2024

Overview

For some backends, users can now pass additional context arguments to differentiation operators, which are forwarded to the function. This helps avoid closures and tap into Enzyme's annotations more efficiently. See also #311 and #403 for discussions.

Note

This PR is not complete, see the follow up issues in #465 and #466.

The following rules apply:

  • There must still be only one "active" (differentiated) argument, called x, which is provided right after the backend. The number of context arguments is arbitrary: f(x) becomes f(x, contexts...) and f!(y, x) becomes f!(y, x, contexts...).
  • Every differentiation operator requires the same number of context arguments as the function itself. However the operator's results only refer to x and not to the context arguments.
  • Preparation requires context too, but the context can change when the operator is applied afterwards.
  • Differentiation operators require wrapping contexts in specific types, while functions themselves don't. At the moment Constant is the only context type, but we will add Cache in the future.

Example

Setup:

using DifferentiationInterface
import ForwardDiff
b = AutoForwardDiff()

f1(x) = sum(x .^ 2)
f2(x, ca) = ca * sum(x .^ 2)
f3(x, ca, cb) = ca *  sum(x .^ cb)

x, ca, cb = [1.0, 2.0], 3.0, 4

Right number of arguments:

julia> gradient(f1, b, x)  # gradient of x -> f1(x)
2-element Vector{Float64}:
 2.0
 4.0

julia> gradient(f2, b, x, Constant(ca))  # gradient of x -> f2(x, ca)
2-element Vector{Float64}:
  6.0
 12.0

julia> gradient(f3, b, x, Constant(ca), Constant(cb))  # gradient of x -> f3(x, ca, cb)
2-element Vector{Float64}:
 12.0
 96.0

Wrong number of arguments:

julia> gradient(f1, b, x, Constant(ca))
ERROR: MethodError: no method matching f1(::Vector{Float64}, ::Float64)

Closest candidates are:
  f1(::Any)
   @ Main REPL[6]:1

julia> gradient(f2, b, x)
ERROR: MethodError: no method matching f2(::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(f2), Float64}, Float64, 2}})

Closest candidates are:
  f2(::Any, ::Any)
   @ Main REPL[7]:1

DI source

  • Implement the Context abstract type with its Constant subtype. Add unwrap and Rewrap to remove or recover wrapper types.
  • Allow contexts in all dense operators.
  • Allow contexts in fallbacks.
  • In second-order operators, be careful to unwrap and rewrap properly inside the closure for gradient or derivative.

DI extensions

  • Allow contexts with ForwardDiff.
  • Allow contexts in sparse operators, but sofar it doesn't work because tracing only accepts a single argument. Probably just takes a closure.

DIT source

  • Include tuple of contexts in Scenario.
  • Adapt function and operator signature in all the tests.
  • Implement insert_context to turn f(x) into f(x, a) = a * f(x) and insert that function in the scenario.

DI and DIT tests

  • Add some tests with contexts but they didn't catch all the bugs so more testing is needed.

@codecov-commenter
Copy link

codecov-commenter commented Sep 8, 2024

Codecov Report

Attention: Patch coverage is 98.72611% with 6 lines in your changes missing coverage. Please review.

Project coverage is 98.51%. Comparing base (b3c9b0a) to head (2735a42).
Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
DifferentiationInterface/src/second_order/hvp.jl 85.18% 4 Missing ⚠️
DifferentiationInterface/src/utils/context.jl 88.23% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #461      +/-   ##
==========================================
- Coverage   98.53%   98.51%   -0.03%     
==========================================
  Files         106      108       +2     
  Lines        4234     4297      +63     
==========================================
+ Hits         4172     4233      +61     
- Misses         62       64       +2     
Flag Coverage Δ
DI 98.57% <97.84%> (-0.12%) ⬇️
DIT 98.38% <100.00%> (+0.16%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@gdalle gdalle added this to the v0.6 milestone Sep 14, 2024
@gdalle gdalle marked this pull request as ready for review September 14, 2024 10:04
@gdalle gdalle merged commit 267023a into main Sep 14, 2024
103 of 105 checks passed
@gdalle gdalle deleted the gd/context branch September 24, 2024 10:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants