Add Enzyme extension #927

Merged 4 commits into SciML:master on Sep 23, 2023

Conversation

@wsmoses (Contributor) commented Sep 16, 2023

@wsmoses force-pushed the master branch 2 times, most recently from caa207d to dd3de87, on September 16, 2023 at 22:23
res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...;
kwargs...)

dres = Enzyme.Compiler.make_zero(RT, IdDict(), res[1])
Member:
What's this command doing? That's not documented.

Contributor Author:
It's an internal function that is effectively like deepcopy, but it zeroes the memory (to be accumulated into). The better solution here would be to construct a zeroed res[1], if you have an API for that.
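
For context, a minimal sketch of that "deepcopy, but zeroed" behavior on a plain struct of numbers and arrays could look like the hypothetical zero_like below; this is purely illustrative and is not Enzyme's implementation (make_zero additionally takes an IdDict, presumably to track already-seen objects, and supports many more types):

# Illustrative only: return a structurally identical copy of x with every
# numeric leaf set to zero, so gradients can be accumulated into it.
zero_like(x::Number) = zero(x)
zero_like(x::AbstractArray{<:Number}) = zero(x)
zero_like(x::AbstractArray) = map(zero_like, x)
function zero_like(x::T) where {T}
    # Rebuild the struct field by field, zeroing each leaf.
    T(ntuple(i -> zero_like(getfield(x, i)), fieldcount(T))...)
end

struct Point
    x::Float64
    ys::Vector{Float64}
end

zero_like(Point(1.0, [2.0, 3.0]))  # Point(0.0, [0.0, 0.0])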

Member:
And is this why you needed SciML/SciMLBase.jl#496? That's the only place where I could see the prob being mutated. The solution, though, is only captured in a subset of the fields: mainly sol.u (kind of sol.t, and maybe sol.k, but we can't actually differentiate w.r.t. those, so we just mark them as not differentiable and using them would error); it's definitely not differentiable w.r.t. the other pieces.

Contributor Author:
No, the reason it is needed is that otherwise it is not legal to write a rule for the return type ODEProblem.

As an alternative to make_zero, a good copy function (with, say, a zero operation) would suffice. However:

ERROR: LoadError: type ODESolution has no field sc
Stacktrace:
  [1] getproperty(x::ODESolution{…}, s::Symbol)
    @ SciMLBase ~/.julia/dev/SciMLBase/src/solutions/ode_solutions.jl:52 [inlined]
  [2] copy(VA::ODESolution{…})
    @ RecursiveArrayTools ~/.julia/packages/RecursiveArrayTools/X30HP/src/vector_of_array.jl:351
  [3] #augmented_primal#1
    @ ~/.julia/dev/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl:27
  [4] augmented_primal
    @ ~/.julia/dev/DiffEqBase/ext/DiffEqBaseEnzymeExt.jl:10 [inlined]

Member:
Could you just use:

function zerocopy(sol::ODESolution)
    ODESolution{T, N}(zero.(sol.u),
        nothing,
        nothing,
        zero(sol.t),
        nothing,
        nothing,
        nothing,
        nothing,
        false,
        0,
        nothing,
        nothing,
        sol.retcode)
end

? Basically, just strip out anything that is not an output value and zero the values?

Contributor Author:
The type needs to be equivalent to the primal.
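
For illustration, the same constraint shows up in Enzyme's Duplicated annotation: the shadow must have exactly the same type as the primal, which is why a stripped-down solution object would not work as the zeroed copy. A small sketch (not from this PR):

using Enzyme

x  = [1.0, 2.0]
dx = zero(x)        # shadow: same type and shape as the primal, zero-initialized
Duplicated(x, dx)   # ok: both sides are Vector{Float64}

# Duplicated(x, Float32[0, 0])  # fails: the shadow type differs from the primal type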

Member:
How about:

function zerocopy(sol::ODESolution)
    ODESolution{T, N}(zero.(sol.u),
        sol.u_analytic,
        sol.errors,
        zero(sol.t),
        sol.k,
        sol.prob,
        sol.alg,
        sol.interp,
        sol.dense,
        sol.tslocation,
        sol.stats,
        sol.alg_choice,
        sol.retcode)
end

@wsmoses (Contributor Author) commented Sep 16, 2023

Collectively, the changes outlined here (and the test below) enable that SciML example to run successfully.

using Enzyme

Enzyme.API.typeWarning!(false)

# SciML Tools
using OrdinaryDiffEq,  SciMLSensitivity
using Optimization, OptimizationOptimisers

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

# Set a random seed for reproducible behaviour
const rng = StableRNG(1111)

begin

	function lotka!(du, u, p, t)
	    α, β, γ, δ = p
	    du[1] = α * u[1] - β * u[2] * u[1]
	    du[2] = γ * u[1] * u[2] - δ * u[2]
	end

	# Define the experimental parameter
	tspan = (0.0, 5.0)
	u0 = 5.0f0 * rand(rng, 2)
	p_ = [1.3, 0.9, 0.8, 1.8]
	prob = ODEProblem(lotka!, u0, tspan, p_)
	solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

	# Add noise in terms of the mean
	const X = Array(solution)
	const t = solution.t

	x̄ = mean(X, dims = 2)
	noise_magnitude = 5e-3
	const Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

	plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
	scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

	rbf(x) = exp.(-(x .^ 2))

	# Multilayer FeedForward
	U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
	              Lux.Dense(5, 2))
	# Get the initial parameters and state variables of the model
	p, st = Lux.setup(rng, U)

	# Define the hybrid model
	function ude_dynamics!(du, u, p, t, p_true)
	    û = U(u, p, st)[1] # Network prediction
	    du[1] = p_true[1] * u[1] + û[1]
	    du[2] = -p_true[4] * u[2] + û[2]
	end

	# Closure with the known parameter
	nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
	# Define the problem

	const prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

	function predict(θ, X, T)
	    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
	    Array(solve(_prob, Vern7(), saveat = T,
	                abstol = 1e-6, reltol = 1e-6))
	end

	function loss(θ)
	    X̂ = predict(θ, Xₙ[:, 1], t)
		mean(abs2, Xₙ .- X̂)
	end
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 5 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoEnzyme()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

@ChrisRackauckas (Member):
Is SciML/SciMLBase.jl#496 still necessary with the deepcopy?

@wsmoses (Contributor Author) commented Sep 17, 2023

That is necessary to register a rule here at all, independent of the content of the rule.

@ChrisRackauckas (Member):
Why would that matter? ODEProblem isn't being differentiated?

@wsmoses (Contributor Author) commented Sep 17, 2023

The problem is that the type is illegal to be used as an argument or return of a reverse-mode custom rule otherwise.

@ChrisRackauckas (Member):
Even as a const argument?

@wsmoses (Contributor Author) commented Sep 17, 2023

Const (or Duplicated in forward mode) are both fine.

@ChrisRackauckas (Member):
The purpose of the solve_up function is to pull u0 and p out of the prob, since those are the two arguments being differentiated, and pass them as separate arguments. That makes prob a const w.r.t. differentiation of u0 and p, since the function then only uses the two other pieces. This was required for ChainRules to work, so we could just make use of that here.

using ChainRulesCore
using EnzymeCore

function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg, u0, p, args...; kwargs...) where RT
Member:
Suggested change:
- function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg, u0, p, args...; kwargs...) where RT
+ function EnzymeCore.EnzymeRules.augmented_primal(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob::Const, sensealg::Const, u0, p, args...; kwargs...) where RT

Is that sufficient?

Contributor Author:
Nope, the types in that signature do not assert the types of the values that are passed in. Enzyme's activity analysis deduced that the prob object contains differentiable data, and therefore correctly requested a Duplicated first argument.

Instead, the argument needs to be marked inactive in activity analysis, similar to EnzymeRules.inactive, but as a variant that applies only to particular arguments.
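
For reference, the existing mechanism marks an entire function as inactive; the per-argument variant discussed here would be an analogous declaration scoped to a single argument. A minimal sketch of the whole-function form, assuming EnzymeCore.EnzymeRules.inactive behaves as documented (log_progress is a made-up example function):

using EnzymeCore

# A function whose result never carries derivative information.
log_progress(iter, loss) = println("iter $iter: loss = $loss")

# Tell Enzyme's activity analysis that calls to log_progress are inactive,
# so it will not attempt to differentiate through them.
EnzymeCore.EnzymeRules.inactive(::typeof(log_progress), args...) = nothing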

Contributor Author:
@vchuravy you and someone else (I forget who) had a use case for a "this arg is inactive" marker, which I made activity analysis parse some metadata for.

Did you have thoughts on what the API design for that should be?

@ChrisRackauckas (Member):

using Enzyme

Enzyme.API.typeWarning!(false)

# SciML Tools
using OrdinaryDiffEq,  SciMLSensitivity
using Optimization, OptimizationOptimisers

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, StableRNGs, Plots

# Set a random seed for reproducible behaviour
const rng = StableRNG(1111)

rbf(x) = exp.(-(x .^ 2))

begin

	function lotka!(du, u, p, t)
	    α, β, γ, δ = p
	    du[1] = α * u[1] - β * u[2] * u[1]
	    du[2] = γ * u[1] * u[2] - δ * u[2]
	end

	# Define the experimental parameter
	tspan = (0.0, 5.0)
	u0 = 5.0f0 * rand(rng, 2)
	p_ = [1.3, 0.9, 0.8, 1.8]
	prob = ODEProblem(lotka!, u0, tspan, p_)
	solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

	# Add noise in terms of the mean
	const X = Array(solution)
	const t = solution.t

	x̄ = mean(X, dims = 2)
	noise_magnitude = 5e-3
	const Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

	plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
	scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

	# Multilayer FeedForward
	U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
	              Lux.Dense(5, 2))
	# Get the initial parameters and state variables of the model
	p, st = Lux.setup(rng, U)

	# Define the hybrid model
	function ude_dynamics!(du, u, p, t, p_true)
	    û = U(u, p, st)[1] # Network prediction
	    du[1] = p_true[1] * u[1] + û[1]
	    du[2] = -p_true[4] * u[2] + û[2]
	end

	# Closure with the known parameter
	nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
	# Define the problem

	const prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

	function predict(θ, X, T)
	    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
	    Array(solve(_prob, Vern7(), saveat = T,
	                abstol = 1e-6, reltol = 1e-6))
	end

	function loss(θ)
	    X̂ = predict(θ, Xₙ[:, 1], t)
		mean(abs2, Xₙ .- X̂)
	end
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 5 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoEnzyme()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)

# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 5000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

using Optim, OptimizationOptimJL

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback = callback, maxiters = 1000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Rename the best candidate
p_trained = res2.u

ts = first(solution.t):(mean(diff(solution.t)) / 2):last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red,
                     label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
savefig("Plot.png")

[Plot.png: UDE approximation vs. noisy measurements]

@ChrisRackauckas merged commit 7e3eff9 into SciML:master on Sep 23, 2023