Skip to content

Commit

Permalink
ODE Interface fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 24, 2024
1 parent 094e02e commit 61dda98
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 89 deletions.
2 changes: 1 addition & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ version = "0.6.42"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "SciMLStructures", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "0e4fca3dd5de4d4a82c0ffae1e51ab6234af4df0"
git-tree-sha1 = "3fbd6a361ec965a89f1ec64320a8b8a80b7409c3"
repo-rev = "ap/nlls_bvp"
repo-url = "https://github.com/SciML/SciMLBase.jl.git"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand Down
104 changes: 60 additions & 44 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ import ODEInterface: colnew
import FastClosures: @closure
import ForwardDiff

function _test_bvpm2_bvpsol_colnew_problem_criteria(
_, ::SciMLBase.StandardBVProblem, alg::Symbol)
throw(ArgumentError("$(alg) does not support standard BVProblem. Only TwoPointBVProblem is supported."))
end
function _test_bvpm2_bvpsol_colnew_problem_criteria(prob, ::TwoPointBVProblem, alg::Symbol)
@assert isinplace(prob) "$(alg) only supports inplace TwoPointBVProblem!"
end

#------
# BVPM2
#------
Expand Down Expand Up @@ -54,7 +46,7 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3,
if prob.u0 isa Function
guess_function = @closure (x, y) -> (y .= vec(__initial_guess(prob.u0, prob.p, x)))
bvpm2_init(obj, no_odes, no_left_bc, mesh, guess_function,
eltype(u0_)[], alg.max_num_subintervals, prob.u0)
eltype(u0_)[], alg.max_num_subintervals)
else
u0 = __flatten_initial_guess(prob.u0)
bvpm2_init(
Expand Down Expand Up @@ -98,7 +90,7 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3,
bvpm2_destroy(obj)
bvpm2_destroy(sol)

return SciMLBase.build_solution(prob, ivpsol, nothing)
return ivpsol
end

#-------
Expand Down Expand Up @@ -171,9 +163,9 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000,
elseif retcode == -5
@warn "Given initial values inconsistent with separable linear bc"
elseif retcode == -6
@warn """Iterative refinement faild to converge for `sol_method=0`
Termination since multiple shooting condition or
condition of Jacobian is too bad for `sol_method=1`"""
@warn "Iterative refinement faild to converge for `sol_method=0` \
Termination since multiple shooting condition or \
condition of Jacobian is too bad for `sol_method=1`"
elseif retcode == -8
@warn "Condensing algorithm for linear block system fails, try `sol_method=1`"
elseif retcode == -9
Expand All @@ -187,68 +179,92 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000,

ivpsol = SciMLBase.build_solution(prob, alg, sol_t,
map(x -> reshape(convert(Vector{eltype(u0_)}, x), u0_size), eachcol(sol_x));
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure, stats,
original = (sol_t, sol_x, retcode, stats))
retcode = retcode 0 ? ReturnCode.Success : ReturnCode.Failure,
stats, original = (sol_t, sol_x, retcode, stats))

return SciMLBase.build_solution(prob, ivpsol, nothing)
return ivpsol
end

#-------
# COLNEW
#-------
#= TODO: FIX this
function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol=1e-4, dt = 0.0, verbose = true, kwargs...)
_test_bvpm2_bvpsol_colnew_problem_criteria(prob, prob.problem_type, :COLNEW)
has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
dt ≤ 0 && throw(ArgumentError("dt must be positive"))
no_odes, n, u0 = if has_initial_guess
length(first(prob.u0)), (length(prob.u0) - 1), reduce(hcat, prob.u0)
else
length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)), prob.u0
function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000,
reltol = 1e-3, dt = 0.0, verbose = true, kwargs...)
# FIXME: COLNEW does support MP-BVPs but in a very clunky way
if !(prob.problem_type isa TwoPointBVProblem)
throw(ArgumentError("`COLNEW` only supports `TwoPointBVProblem!`"))
end

dt 0 && throw(ArgumentError("`dt` must be positive"))

t₀, t₁ = prob.tspan
u0_ = __extract_u0(prob.u0, prob.p, t₀)
u0_size = size(u0_)
n = __initial_guess_length(prob.u0)

u0 = __flatten_initial_guess(prob.u0)
mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n - 1))
if u0 === nothing
# initial_guess function was provided
u0 = mapreduce(@closure(t->vec(__initial_guess(prob.u0, prob.p, t))), hcat, mesh)
end

no_odes = length(u0_)

# has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
# dt ≤ 0 && throw(ArgumentError("dt must be positive"))
# no_odes, n, u0 = if has_initial_guess
# length(first(prob.u0)), (length(prob.u0) - 1), reduce(hcat, prob.u0)
# else
# length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)), prob.u0
# end

T = eltype(u0)
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
opt = OptionsODE(
OPT_BVPCLASS => alg.bvpclass, OPT_COLLOCATIONPTS => alg.collocationpts,
# mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
opt = OptionsODE(OPT_BVPCLASS => alg.bvpclass, OPT_COLLOCATIONPTS => alg.collocationpts,
OPT_MAXSTEPS => maxiters, OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_MAXSUBINTERVALS => alg.max_num_subintervals, OPT_RTOL => reltol)
orders = ones(Int, no_odes)
_tspan = [prob.tspan[1], prob.tspan[2]]
iip = SciMLBase.isinplace(prob)

rhs(t, u, du) =
rhs = @closure (t, u, du) -> begin
if iip
prob.f(du, u, prob.p, t)
else
(du .= prob.f(u, prob.p, t))
end
end

if prob.f.jac === nothing
if iip
jac = function (df, u, p, t)
jac = (df, u, p, t) -> begin
_du = similar(u)
prob.f(_du, u, p, t)
_f = (du, u) -> prob.f(du, u, p, t)
_f = @closure (du, u) -> prob.f(du, u, p, t)
ForwardDiff.jacobian!(df, _f, _du, u)
return
end
else
jac = function (df, u, p, t)
jac = (df, u, p, t) -> begin
_du = prob.f(u, p, t)
_f = (du, u) -> (du .= prob.f(u, p, t))
_f = @closure (du, u) -> (du .= prob.f(u, p, t))
ForwardDiff.jacobian!(df, _f, _du, u)
return
end
end
else
jac = prob.f.jac
end
Drhs(t, u, df) = jac(df, u, prob.p, t)
Drhs = @closure (t, u, df) -> jac(df, u, prob.p, t)

#TODO: Fix bc and bcjac for multi-points BVP
bcresid_prototype, _ = BoundaryValueDiffEq.__get_bcresid_prototype(
prob.problem_type, prob, u0)

n_bc_a = length(first(prob.f.bcresid_prototype.x))
n_bc_b = length(last(prob.f.bcresid_prototype.x))
n_bc_a = length(first(bcresid_prototype))
n_bc_b = length(last(bcresid_prototype))
zeta = vcat(fill(first(prob.tspan), n_bc_a), fill(last(prob.tspan), n_bc_b))
bc = function (i, z, resid)
bc = @closure (i, z, resid) -> begin
tmpa = copy(z)
tmpb = copy(z)
tmp_resid_a = zeros(T, n_bc_a)
Expand All @@ -268,7 +284,7 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol
end
end

Dbc = function (i, z, dbc)
Dbc = @closure (i, z, dbc) -> begin
for j in 1:n_bc_a
if i == j
dbc[i] = 1.0
Expand All @@ -287,7 +303,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol
if retcode == 0
@warn "Collocation matrix is singular"
elseif retcode == -1
@warn "The expected no. of subintervals exceeds storage(try to increase `OPT_MAXSUBINTERVALS`)"
@warn "The expected no. of subintervals exceeds storage(try to increase \
`OPT_MAXSUBINTERVALS`)"
elseif retcode == -2
@warn "The nonlinear iteration has not converged"
elseif retcode == -3
Expand All @@ -299,11 +316,10 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000, reltol
destats = SciMLBase.DEStats(
stats["no_rhs_calls"], 0, 0, 0, stats["no_jac_calls"], 0, 0, 0, 0, 0, 0, 0, 0)

return DiffEqBase.build_solution(prob, alg, mesh,
collect(Vector{eltype(evalsol)}, eachrow(evalsol));
return DiffEqBase.build_solution(
prob, alg, mesh, collect(Vector{eltype(evalsol)}, eachrow(evalsol));
retcode = retcode > 0 ? ReturnCode.Success : ReturnCode.Failure,
stats = destats)
stats = destats, original = (sol, retcode, stats))
end
=#

end
4 changes: 4 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ struct BVPSOL{O} <: BoundaryValueDiffEqAlgorithm
end
end

function BVPSOL(; bvpclass = 2, sol_method = 0, odesolver = nothing)
return BVPSOL(bvpclass, sol_method, odesolver)
end

"""
COLNEW(; bvpclass = 1, collocationpts = 7, diagnostic_output = 1,
max_num_subintervals = 3000)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test, BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random, OrdinaryDiffEq,
RecursiveArrayTools
@testsetup module ODEInterfaceWrapperTestSetup

using BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random, RecursiveArrayTools

# Adaptation of https://github.com/luchr/ODEInterface.jl/blob/958b6023d1dabf775033d0b89c5401b33100bca3/examples/BasicExamples/ex7.jl
function ex7_f!(du, u, p, t)
Expand Down Expand Up @@ -27,8 +28,16 @@ tspan = (-π / 2, π / 2)
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan,
p; bcresid_prototype = (zeros(1), zeros(1)))

@testset "BVPM2" begin
@info "Testing BVPM2"
# Just generate a solution for bvpsol
sol_ms = solve(tpprob, MultipleShooting(10, DP5(), NewtonRaphson());
dt = π / 20, abstol = 1e-5, maxiters = 1000, adaptive = false)

export ex7_f!, ex7_2pbc1!, ex7_2pbc2!, u0, p, tspan, tpprob, sol_ms

end

@testitem "BVPM2" setup=[ODEInterfaceWrapperTestSetup] begin
using ODEInterface, RecursiveArrayTools

sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20)
@test SciMLBase.successful_retcode(sol_bvpm2)
Expand All @@ -38,70 +47,61 @@ tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan,
@test norm(resid_f, Inf) < 1e-6
end

# Just generate a solution for bvpsol
sol_ms = solve(tpprob, MultipleShooting(10, DP5(), NewtonRaphson());
dt = π / 20, abstol = 1e-5, maxiters = 1000,
odesolve_kwargs = (; adaptive = false, dt = 0.01, abstol = 1e-6, maxiters = 1000))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
@testset "BVPSOL" begin
@info "Testing BVPSOL"

@info "BVPSOL with Vector{<:AbstractArray}"
@testitem "BVPSOL" setup=[ODEInterfaceWrapperTestSetup] begin
using ODEInterface, RecursiveArrayTools

initial_u0 = [sol_ms(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]]
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0,
tspan, p; bcresid_prototype = (zeros(1), zeros(1)))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)

@info "BVPSOL with VectorOfArray"
@test sol_bvpsol isa SciMLBase.ODESolution

initial_u0 = VectorOfArray([sol_ms(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]])
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0,
tspan, p; bcresid_prototype = (zeros(1), zeros(1)))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)

@info "BVPSOL with DiffEqArray"
@test sol_bvpsol isa SciMLBase.ODESolution

ts = collect(tspan[1]:/ 20):tspan[2])
initial_u0 = DiffEqArray([sol_ms(t) .+ rand() for t in ts], ts)
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0,
tspan, p; bcresid_prototype = (zeros(1), zeros(1)))

sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)

@info "BVPSOL with initial guess function"
@test sol_bvpsol isa SciMLBase.ODESolution

initial_u0 = (p, t) -> sol_ms(t) .+ rand()
# FIXME: Upstream fix
# tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p;
# bcresid_prototype = (zeros(1), zeros(1)))
# sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)
end

#=
@info "COLNEW"
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0,
tspan, p; bcresid_prototype = (zeros(1), zeros(1)))
sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20)

function f!(du, u, p, t)
du[1] = u[2]
du[2] = u[1]
end
function bca!(resid_a, u_a, p)
resid_a[1] = u_a[1] - 1
end
function bcb!(resid_b, u_b, p)
resid_b[1] = u_b[1]
@test sol_bvpsol isa SciMLBase.ODESolution
end

fun = BVPFunction(
f!, (bca!, bcb!), bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
tspan = (0.0, 1.0)
prob = TwoPointBVProblem(fun, [1.0, 0.0], tspan)
sol_colnew = solve(prob, COLNEW(), dt = 0.01)
@test SciMLBase.successful_retcode(sol_colnew)
=#
@testitem "COLNEW" setup=[ODEInterfaceWrapperTestSetup] begin
using ODEInterface, RecursiveArrayTools

function f!(du, u, p, t)
du[1] = u[2]
du[2] = u[1]
end
function bca!(resid_a, u_a, p)
resid_a[1] = u_a[1] - 1
end
function bcb!(resid_b, u_b, p)
resid_b[1] = u_b[1]
end

fun = BVPFunction(
f!, (bca!, bcb!), bcresid_prototype = (zeros(1), zeros(1)), twopoint = Val(true))
tspan = (0.0, 1.0)

prob = TwoPointBVProblem(fun, [1.0, 0.0], tspan)
sol_colnew = solve(prob, COLNEW(), dt = 0.01)
@test SciMLBase.successful_retcode(sol_colnew)
end

0 comments on commit 61dda98

Please sign in to comment.