Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove split reverse mode for mutating functions #143

Merged
merged 3 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/docs/src/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ This means the Hessian is obtained as the sparse Jacobian of the gradient.
### Split reverse mode

Many reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
We make this available for everyone with the following operators:
We make this available for allocating functions only, with the following operators:

| out-of-place | in-place (or not) |
| ---------------------------------- | ------------------------------------ |
Expand Down
9 changes: 5 additions & 4 deletions DifferentiationInterface/src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,14 @@ function value_and_jacobian_aux!!(
x::AbstractArray,
extras::PullbackJacobianExtras,
)
y, pullbackfunc!! = value_and_pullback!!_split(
f!, y, backend, x, extras.pullback_extras
)
for (k, i) in enumerate(CartesianIndices(y))
dy_i = basis(backend, y, i)
jac_row_i_old = reshape(view(jac, k, :), size(x))
jac_row_i_new = pullbackfunc!!(y, jac_row_i_old, dy_i)
jac_row_i_new = last(
value_and_pullback!!(
f!, y, jac_row_i_old, backend, x, dy_i, extras.pullback_extras
),
)
# this allocates
copyto!(jac_row_i_old, jac_row_i_new)
end
Expand Down
23 changes: 0 additions & 23 deletions DifferentiationInterface/src/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,3 @@ function value_and_pullback!!_split(
pullbackfunc!!(dx, dy) = pullback!!(f, dx, backend, x, dy, extras)
return f(x), pullbackfunc!!
end

"""
value_and_pullback!!_split(f!, y, backend, x, [extras])

Apply split reverse mode autodiff.

Returns a tuple `(y, pullbackfunc!!)` where the second element is a function (closure) with the following signature:

pullbackfunc!!(y, dx, dy) -> dx
"""
function value_and_pullback!!_split(
f!,
y,
backend::AbstractADType,
x,
extras::PullbackExtras=prepare_pullback(f!, backend, y, x),
)
function pullbackfunc!!(y, dx, dy)
return value_and_pullback!!(f!, y, dx, backend, x, dy, extras)[2]
end
f!(y, x)
return y, pullbackfunc!!
end
3 changes: 0 additions & 3 deletions DifferentiationInterfaceTest/src/tests/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,10 @@ function run_benchmark!(
(; f, x, y, dy) = deepcopy(scen)
f! = f
extras = prepare_pullback(f!, ba, y, x)
_, pullbackfunc!! = value_and_pullback!!_split(f!, y, ba, x, extras)
bench1 = @be (mysimilar(y), mysimilar(x)) value_and_pullback!!(
f!, _[1], _[2], ba, x, dy, extras
)
bench2 = @be (mysimilar(y), mysimilar(x)) pullbackfunc!!(_[1], _[2], dy)
record!(data, ba, value_and_pullback!!, scen, bench1)
record!(data, ba, pullbackfunc!!, scen, bench2)
return nothing
end

Expand Down
7 changes: 0 additions & 7 deletions DifferentiationInterfaceTest/src/tests/correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,13 @@ function test_correctness(
y10 = mysimilar(y)
y1, dx1 = value_and_pullback!!(f!, y10, mysimilar(x), ba, x, dy, extras)

y20 = mysimilar(y)
y2, pullbackfunc!! = value_and_pullback!!_split(f!, y20, ba, x, extras)
dx2 = pullbackfunc!!(y20, mysimilar(x), dy)

let (≈)(x, y) = isapprox(x, y; atol, rtol)
@testset "Primal value" begin
@test y10 ≈ y
@test y20 ≈ y
@test y1 ≈ y
@test y2 ≈ y
end
@testset "Cotangent value" begin
@test dx1 ≈ dx_true
@test dx2 ≈ dx_true
end
end
test_scen_intact(new_scen, scen)
Expand Down
3 changes: 0 additions & 3 deletions DifferentiationInterfaceTest/src/tests/type_stability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,8 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{true};)
y_in = mysimilar(y)
dx_in = mysimilar(x)

_, pullbackfunc!! = value_and_pullback!!_split(f!, y, ba, x, extras)

if Bool(pullback_performance(ba))
@test_opt value_and_pullback!!(f!, y_in, dx_in, ba, x, dy, extras)
@test_opt pullbackfunc!!(y_in, dx_in, dy)
end
return nothing
end
Expand Down
Loading