Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix type stability #228

Draft
wants to merge 3 commits into
base: fix-adjoints
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ODINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ using Reexport
@reexport using Huginn # imports Muninn and Sleipnir

using Statistics, LinearAlgebra, Random, Polynomials
using Enzyme
Enzyme.API.runtimeActivity!(true) # This reduces performance but fixes AD issues
using JLD2
using OrdinaryDiffEq
using SciMLSensitivity
Expand Down
2 changes: 1 addition & 1 deletion src/models/machine_learning/MLmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function NN(params::Sleipnir.Parameters;
end

# Build the simulation parameters based on input values
ft = params.simulation.float_type
ft = Sleipnir.Float
neural_net = NN{ft}(architecture, NN_f, θ)

return neural_net
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,12 @@ function simulate_iceflow_UDE!(
apply_UDE_parametrization!(θ, simulation, nothing, batch_id)
SIA2D_UDE_closure(H, θ, t) = SIA2D_UDE(H, θ, t, simulation, batch_id)

iceflow_prob = ODEProblem(SIA2D_UDE_closure, model.iceflow[batch_id].H, params.simulation.tspan, tstops=params.solver.tstops, θ)
tstops = Enzyme.Const(params.solver.tstops) # Not clear if we need to make tstops constant, try to remove Enzyme.Const once AD is working
iceflow_prob = ODEProblem(SIA2D_UDE_closure, model.iceflow[batch_id].H, params.simulation.tspan, tstops=tstops, θ)
iceflow_sol = solve(iceflow_prob,
params.solver.solver,
callback=cb,
tstops=params.solver.tstops,
tstops=tstops,
u0=model.iceflow[batch_id].H₀,
p=θ,
sensealg=params.UDE.sensealg,
Expand Down
12 changes: 6 additions & 6 deletions src/simulations/inversions/inversion_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ function invert_iceflow_transient(glacier_idx::Int, simulation::Inversion)
inversion_metrics = InversionResults(glacier.rgi_id,optimized_A, optimized_n, optimized_C,
H_pred, H_obs, H_diff, V_pred, V_obs, V_diff,
MSE_V, #change
simulation.glaciers[glacier_idx].Δx,
simulation.glaciers[glacier_idx].Δx,
simulation.glaciers[glacier_idx].Δy,
Sleipnir.safe_getproperty(simulation.glaciers[glacier_idx].gdir, :cenlon),
Sleipnir.safe_getproperty(simulation.glaciers[glacier_idx].gdir, :cenlat))
simulation.glaciers[glacier_idx].cenlon,
simulation.glaciers[glacier_idx].cenlat)

return inversion_metrics
end
Expand Down Expand Up @@ -270,8 +270,8 @@ function invert_iceflow_ss(glacier_idx::Int, simulation::Inversion)
MSE,
simulation.glaciers[glacier_idx].Δx,
simulation.glaciers[glacier_idx].Δy,
Sleipnir.safe_getproperty(simulation.glaciers[glacier_idx].gdir, :cenlon),
Sleipnir.safe_getproperty(simulation.glaciers[glacier_idx].gdir, :cenlat)
simulation.glaciers[glacier_idx].cenlon,
simulation.glaciers[glacier_idx].cenlat,
)

return inversion_metrics
Expand Down Expand Up @@ -358,7 +358,7 @@ function H_from_V(V::Matrix{<:Real}, C::Matrix{<:Real}, simulation::SIM) where {
params::Sleipnir.Parameters = simulation.parameters

iceflow_model = simulation.model.iceflow
glacier::Sleipnir.Glacier2D = simulation.glaciers[iceflow_model.glacier_idx[]]
glacier = simulation.glaciers[iceflow_model.glacier_idx[]]
B = glacier.B
Δx = glacier.Δx
Δy = glacier.Δy
Expand Down
2 changes: 1 addition & 1 deletion test/PDE_UDE_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function ude_solve_test(atol; MB=false, fast=true)
## Retrieving simulation data for the following glaciers
## Fast version includes less glacier to reduce the amount of downloaded files and computation time on GitHub CI
if fast
rgi_ids = ["RGI60-11.03638", "RGI60-11.01450"]#, "RGI60-08.00213", "RGI60-04.04351"]
rgi_ids = ["RGI60-11.03638"]#, "RGI60-11.01450"]#, "RGI60-08.00213", "RGI60-04.04351"]
else
rgi_ids = ["RGI60-11.03638", "RGI60-11.01450", "RGI60-08.00213", "RGI60-04.04351", "RGI60-01.02170",
"RGI60-02.05098", "RGI60-01.01104", "RGI60-01.09162", "RGI60-01.00570", "RGI60-04.07051",
Expand Down
Binary file modified test/data/PDE_refs_MB.jld2
Binary file not shown.
2 changes: 0 additions & 2 deletions test/params_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ function params_constructor_specified(save_refs::Bool = false)
plots = false,
velocities = false,
overwrite_climate = false,
float_type = Float64,
int_type = Int64,
tspan = (2010.0,2015.0),
multiprocessing = false,
workers = 10,
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import Pkg
Pkg.activate(dirname(Base.current_project()))

using Revise
using Enzyme
Enzyme.API.runtimeActivity!(true) # This reduces performance but fixes AD issues
using ODINN
using Test
using JLD2
Expand Down
Loading