HamiltonianProblem not supported #53
As an update, we now have certain mixings of ForwardDiff and Zygote working, as in JuliaDiff/SparseDiffTools.jl@26a3fc0. You have to be careful with tags, but then it works as long as you avoid the Zygote issues related to JuliaLang/julia#265. The adjoint is almost set up to use Zygote for the vjps (SciML/SciMLSensitivity.jl#71); if that lands, the Zygote adjoint may just fix this.
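To show why the tags matter, here is a minimal sketch (not from this thread) of the classic perturbation-confusion check in ForwardDiff. The inner derivative is taken with respect to `y` only, so `d/dy (x + y) = 1` and the outer function reduces to `x * 1`, whose derivative is 1. If the inner and outer perturbations were not kept apart by distinct tags, the inner derivative would also pick up the outer perturbation of `x` and the result would come out as 2.

```julia
using ForwardDiff

# Outer derivative with respect to x of x * (d/dy (x + y)) evaluated at y = 1.0.
# ForwardDiff assigns each derivative call its own tag, so the perturbation of
# x (outer) is not confused with the perturbation of y (inner).
nested = ForwardDiff.derivative(x -> x * ForwardDiff.derivative(y -> x + y, 1.0), 1.0)
println(nested)
```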
If you want to work around it manually, you can do something like:

```julia
using DiffEqFlux, Flux, Zygote, Random, Plots, OrdinaryDiffEq, DiffEqSensitivity
Random.seed!(42)

# NN for the Hamiltonian. Tiny for debugging
neural_H = Chain(
    Dense(2, 1, tanh)
)
# Flat parameter vector p2 and a function re that rebuilds the network from it
p2, re = Flux.destructure(neural_H)
ps = Flux.params(p2)

function neural_hamiltonian!(du, u, p, t)
    # Nested gradient: derivative of the learned Hamiltonian with respect to the state
    H = Zygote.gradient(x -> re(p)(x)[1], u)[1]
    # Swapping in the commented line below (no inner gradient) trains as normal
    #du[1] = re(p)(u)[1]
    # Signs kept from the debugging example. For u = (q, p), Hamilton's equations
    # would be du[1] = H[2] and du[2] = -H[1].
    du[1] = -H[1]
    du[2] = H[2]
end

tspan = (0.0f0, 10.0f0)
dsize = 10
t = range(tspan[1], tspan[2], length=dsize)
u0 = Float32[1.0, 1.0]
neural_ode_prob = ODEProblem(neural_hamiltonian!, u0, tspan, p2)

function predict_adjoint()
    Array(concrete_solve(neural_ode_prob, Tsit5(), u0, p2, saveat=t,
                         sensealg=InterpolatingAdjoint(autojacvec=false)))
end

# ode_data is the dataset I'm trying to fit.
ode_data = ones(2, 10)
loss() = sum((ode_data .- predict_adjoint()).^2)

cb = function ()
    #println(ps)
    display(loss())
end

dummydata = Iterators.repeated((), 10)
opt = ADAM(0.03)
Flux.train!(loss, ps, dummydata, opt, cb=cb)
```
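For comparison, here is a minimal sketch of building the right-hand side from a Hamiltonian in canonical form with ForwardDiff computing the inner gradient, then differentiating that right-hand side with respect to the parameters through a second, outer ForwardDiff call. The Hamiltonian `H`, the functions `hamiltonian_rhs!` and `rhs_of_params`, and the parameter values are hypothetical, chosen only so the results can be checked by hand; the sketch assumes the state is ordered as `u = [q, p]`.

```julia
using ForwardDiff

# Hypothetical Hamiltonian: harmonic oscillator with stiffness θ[1] and mass θ[2].
H(q, p, θ) = p^2 / (2 * θ[2]) + θ[1] * q^2 / 2

# In-place right-hand side with state u = [q, p]. ForwardDiff computes the
# gradient of H with respect to the state; Hamilton's equations are
# dq/dt = ∂H/∂p and dp/dt = -∂H/∂q.
function hamiltonian_rhs!(du, u, θ, t)
    dH = ForwardDiff.gradient(v -> H(v[1], v[2], θ), u)
    du[1] = dH[2]
    du[2] = -dH[1]
    return nothing
end

u = [1.0, 2.0]   # q = 1, p = 2
θ = [3.0, 0.5]   # stiffness 3, mass 0.5
du = similar(u)
hamiltonian_rhs!(du, u, θ, 0.0)
println(du)      # [p / m, -k * q]

# Nested AD: Jacobian of the right-hand side with respect to θ. The outer
# ForwardDiff call passes dual numbers in θ, which the inner gradient closes over.
function rhs_of_params(θ)
    d = zeros(eltype(θ), 2)
    hamiltonian_rhs!(d, u, θ, 0.0)
    return d
end
J = ForwardDiff.jacobian(rhs_of_params, θ)
println(J)
```

The in-place signature `(du, u, θ, t)` is the one `ODEProblem` expects, so a function like this could stand in for `neural_hamiltonian!`; whether a given `sensealg` can differentiate through the nested call is the question this issue is about.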
Hi, thanks for the awesome package! Unfortunately, I'm having trouble using it for my work. I want to integrate a set of equations of motion defined by a Hamiltonian `H(p, x)` that is too complex to differentiate by hand. The code below sets up the problem:

Now I want to differentiate solutions of the equations of motion with respect to the parameter vector. The function below sets this up and tries to run the ODE solver:
When I run `testsolve()`, I get an error followed by a massive stack trace.
It seems like there's a conflict between the datatypes used to define the equations of motion and the parameters tracked by Flux. Is there a way to fix this?
I was also not able to get this to work using Zygote instead of ForwardDiff to differentiate `H`. Hopefully I'm missing something about how to use nested automatic differentiation?
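For reference, here is the standard math behind the question, assuming `x` is the position, `p` the momentum, and θ the parameter vector. Writing the state as $z = (x, p)$, Hamilton's equations give the right-hand side $f$, and the forward sensitivity $S = \partial z / \partial \theta$ obeys a linear ODE whose coefficient is the Hessian of $H$. This is why differentiating solutions with respect to θ requires nested differentiation of $H$:

```math
\dot{x} = \frac{\partial H}{\partial p}(p, x; \theta), \qquad
\dot{p} = -\frac{\partial H}{\partial x}(p, x; \theta), \qquad
f(z, \theta) = \mathcal{J}\,\nabla_z H(z; \theta), \quad
\mathcal{J} = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix},
```

```math
\dot{S} = \mathcal{J}\,\nabla_z^2 H(z; \theta)\, S + \mathcal{J}\,\frac{\partial \nabla_z H(z; \theta)}{\partial \theta},
\qquad S(0) = \frac{\partial z(0)}{\partial \theta}.
```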