Add Enzyme extension #927
Conversation
Force-pushed from caa207d to dd3de87
ext/DiffEqBaseEnzymeExt.jl (outdated)

```julia
res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
    copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(),
    ntuple(arg_copy, Val(length(args)))...; kwargs...)
dres = Enzyme.Compiler.make_zero(RT, IdDict(), res[1])
```
What's this command doing? That's not documented.
It's an internal function that is effectively like deepcopy, but zeroing the memory (so gradients can be accumulated into it). The better solution here would be to construct a zeroed res[1], if you have an API for that.
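A rough sketch of the idea (illustration only; `zero_shadow` is a hypothetical helper, not the real `Enzyme.Compiler.make_zero`, which also tracks aliasing through the `IdDict` and handles arbitrary mutable structs):

```julia
# Hypothetical recursive "deepcopy, but zeroed" helper for simple containers.
zero_shadow(x::AbstractFloat) = zero(x)
zero_shadow(x::AbstractArray) = map(zero_shadow, x)
zero_shadow(x::Tuple) = map(zero_shadow, x)
zero_shadow(x::NamedTuple) = map(zero_shadow, x)

zero_shadow(([1.0, 2.0], (a = 3.0, b = [4.0])))  # -> ([0.0, 0.0], (a = 0.0, b = [0.0]))
```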
And is this why you needed SciML/SciMLBase.jl#496? That's the only place where I could see mutating the prob. The solution, though, is only captured in a subset of the fields, specifically `sol.u` (kind of `sol.t`, maybe `sol.k`, but we can't actually differentiate w.r.t. those, so we just mark them as not differentiable so that using them would error), and the other pieces it's definitely not differentiable w.r.t.
No, the reason it is needed is that otherwise it is not legal to write a rule for the return type ODEProblem.

As an alternative to make_zero, if a good copy function were defined (with, say, a zero operation), that would suffice. However:

```
ERROR: LoadError: type ODESolution has no field sc
Stacktrace:
  [1] getproperty(x::ODESolution{…}, s::Symbol)
    @ SciMLBase ~/.julia/dev/SciMLBase/src/solutions/ode_solutions.jl:52 [inlined]
  [2] copy(VA::ODESolution{…})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/X30HP/src/vector_of_array.jl:351
  [3] #augmented_primal#1
    @ ~/.julia/dev/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl:27
  [4] augmented_primal
    @ ~/.julia/dev/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl:10 [inlined]
```
Could you just use:

```julia
function zerocopy(sol::ODESolution)
    ODESolution{T, N}(zero.(sol.u),
        nothing,
        nothing,
        zero(sol.t),
        nothing,
        nothing,
        nothing,
        nothing,
        false,
        0,
        nothing,
        nothing,
        sol.retcode)
end
```

? Basically just remove any of the things that are non-output values and zero the values?
The type needs to be equivalent to the primal.
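A minimal sketch of what that constraint means, using a toy struct rather than the real ODESolution (`Sol` and `zerocopy` here are made-up names for illustration):

```julia
# The zeroed shadow must have exactly the same concrete type as the primal,
# field types included, so the rule machinery can pair them up.
struct Sol{U, T}
    u::U
    t::T
    retcode::Symbol
end

zerocopy(s::Sol) = Sol(zero.(s.u), zero(s.t), s.retcode)

s = Sol([1.0, 2.0], 0.5, :Success)
@assert typeof(zerocopy(s)) === typeof(s)  # shadow type matches the primal type
```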
How about:

```julia
function zerocopy(sol::ODESolution)
    ODESolution{T, N}(zero.(sol.u),
        sol.u_analytic,
        sol.errors,
        zero(sol.t),
        sol.k,
        sol.prob,
        sol.alg,
        sol.interp,
        sol.dense,
        sol.tslocation,
        sol.stats,
        sol.alg_choice,
        sol.retcode)
end
```
Collectively, the changes outlined here (and the test below) enable that SciML example to run successfully.

```julia
using Enzyme
Enzyme.API.typeWarning!(false)
# SciML Tools
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers
# Standard Libraries
using LinearAlgebra, Statistics
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()
# Set a random seed for reproducible behaviour
const rng = StableRNG(1111)
begin
function lotka!(du, u, p, t)
α, β, γ, δ = p
du[1] = α * u[1] - β * u[2] * u[1]
du[2] = γ * u[1] * u[2] - δ * u[2]
end
# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)
# Add noise in terms of the mean
const X = Array(solution)
const t = solution.t
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
const Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))
plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
û = U(u, p, st)[1] # Network prediction
du[1] = p_true[1] * u[1] + û[1]
du[2] = -p_true[4] * u[2] + û[2]
end
# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
const prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
function predict(θ, X, T)
_prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
Array(solve(_prob, Vern7(), saveat = T,
abstol = 1e-6, reltol = 1e-6))
end
function loss(θ)
X̂ = predict(θ, Xₙ[:, 1], t)
mean(abs2, Xₙ .- X̂)
end
end
losses = Float64[]
callback = function (p, l)
push!(losses, l)
if length(losses) % 5 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
return false
end
adtype = Optimization.AutoEnzyme()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])") |
Is SciML/SciMLBase.jl#496 still necessary with the deepcopy?

That is necessary to register a rule here at all, independent of the content of the rule.

Why would that matter? ODEProblem isn't being differentiated?

The problem is that the type is illegal to be used as an argument or return of a reverse-mode custom rule otherwise.

Even as a const argument?

Const (or Duplicated in forward mode) are both fine.

The purpose of the
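For context on the activity annotations discussed above, a generic sketch (not the DiffEqBase rule itself; `loss`, `data`, and `p` are made up for illustration): passing something as `Const` tells Enzyme's activity analysis not to differentiate with respect to it, while `Duplicated` pairs a value with a shadow that the gradient is accumulated into.

```julia
using Enzyme

loss(data, p) = sum(abs2, p .* data)

data = [1.0, 2.0, 3.0]   # treated as non-differentiable problem data
p = [0.5, 1.5, 2.5]      # parameters we differentiate with respect to
dp = zero(p)             # shadow storage for dloss/dp

Enzyme.autodiff(Reverse, loss, Active, Const(data), Duplicated(p, dp))
dp  # ≈ 2 .* p .* data .^ 2
```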
ext/DiffEqBaseEnzymeExt.jl (outdated)

```julia
using ChainRulesCore
using EnzymeCore

function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg, u0, p, args...; kwargs...) where RT
```
Suggested change:

```julia
# current
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg, u0, p, args...; kwargs...) where RT
# suggested
function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob::Const, sensealg::Const, u0, p, args...; kwargs...) where RT
```
is that sufficient?
Nope, the type annotations there do not assert the types that are actually passed in. Enzyme's activity analysis deduced that the prob object contained differentiable data, and therefore correctly requested a Duplicated first argument.

Instead, the argument needs to be marked inactive in activity analysis, similar to EnzymeRules.inactive, but as a variant that applies only to particular arguments.
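For reference, the existing whole-function mechanism looks roughly like the sketch below (`log_status` is a hypothetical function; the per-argument variant being asked for is not an existing API):

```julia
using EnzymeCore

# A call whose result should never be differentiated.
log_status(prob, msg) = (println(msg); nothing)

# Marks the entire call as inactive for Enzyme's activity analysis.
# The request above is for a finer-grained variant that marks only particular
# arguments (e.g. `prob`) as inactive.
EnzymeCore.EnzymeRules.inactive(::typeof(log_status), args...) = nothing
```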
@vchuravy you and someone else (I forget who) had a use case for a "this arg is inactive" annotation that I made activity analysis parse some metadata for.

Did you have thoughts on what the API design for that should be?
Force-pushed from e6ddcca to 5196756

Force-pushed from 5196756 to 84a8d6d
```julia
using Enzyme
Enzyme.API.typeWarning!(false)
# SciML Tools
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers
# Standard Libraries
using LinearAlgebra, Statistics
# External Libraries
using ComponentArrays, Lux, Zygote, StableRNGs, Plots
# Set a random seed for reproducible behaviour
const rng = StableRNG(1111)
rbf(x) = exp.(-(x .^ 2))
begin
function lotka!(du, u, p, t)
α, β, γ, δ = p
du[1] = α * u[1] - β * u[2] * u[1]
du[2] = γ * u[1] * u[2] - δ * u[2]
end
# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)
# Add noise in terms of the mean
const X = Array(solution)
const t = solution.t
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
const Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))
plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
û = U(u, p, st)[1] # Network prediction
du[1] = p_true[1] * u[1] + û[1]
du[2] = -p_true[4] * u[2] + û[2]
end
# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
const prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
function predict(θ, X, T)
_prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
Array(solve(_prob, Vern7(), saveat = T,
abstol = 1e-6, reltol = 1e-6))
end
function loss(θ)
X̂ = predict(θ, Xₙ[:, 1], t)
mean(abs2, Xₙ .- X̂)
end
end
losses = Float64[]
callback = function (p, l)
push!(losses, l)
if length(losses) % 5 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
return false
end
adtype = Optimization.AutoEnzyme()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
using Optim, OptimizationOptimJL
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback = callback, maxiters = 1000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
# Rename the best candidate
p_trained = res2.u
ts = first(solution.t):(mean(diff(solution.t)) / 2):last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red,
label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
savefig("Plot.png") |
Requires SciML/SciMLBase.jl#496 and EnzymeAD/Enzyme.jl#1056