Skip to content

Commit

Permalink
compat with latest julia version
Browse files Browse the repository at this point in the history
  • Loading branch information
gszep committed May 29, 2024
1 parent ad027f3 commit 90cb698
Show file tree
Hide file tree
Showing 7 changed files with 365 additions and 343 deletions.
19 changes: 12 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
authors = ["gszep <grisha.szep@gmail.com>"]
name = "BifurcationInference"
uuid = "7fe238d6-d31e-4646-aa16-9d8429fd6da8"
authors = ["gszep <gregory.szep@gmail.com>"]
version = "0.1.3"
version = "0.1.4"

[deps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
BifurcationKit = "0.1.5, 0.1"
Flux = "0.12"
BifurcationKit = "0.3"
Flux = "0.14"
ForwardDiff = "0.10"
InvertedIndices = "1"
LaTeXStrings = "1"
Parameters = "0.12"
Plots = "1"
Setfield = "0.7, 0.8"
SpecialFunctions = "1.5, 2"
StaticArrays = "1.2"
Setfield = "1"
SpecialFunctions = "2"
StaticArrays = "1"
julia = "1"

[extras]
Expand Down
254 changes: 134 additions & 120 deletions src/BifurcationInference.jl
Original file line number Diff line number Diff line change
@@ -1,140 +1,154 @@
module BifurcationInference

using BifurcationKit: ContIterable, newton, ContinuationPar, NewtonPar, DeflationOperator
using BifurcationKit: BorderedArray, AbstractLinearSolver, AbstractEigenSolver, BorderingBLS
using BifurcationKit: ContState, detectBifucation
using BifurcationKit: BifurcationProblem, re_make, PALC, ContIterable, newton, ContinuationPar, NewtonPar, DeflationOperator
using BifurcationKit: BorderedArray, AbstractLinearSolver, AbstractEigenSolver, BorderingBLS
using BifurcationKit: ContState, detect_bifurcation

using ForwardDiff: Dual,tagtype,derivative,gradient,jacobian
using Flux: Momentum,update!
using ForwardDiff: Dual, tagtype, derivative, gradient, jacobian
using Flux: Momentum, update!

using Setfield: @lens,@set,setproperties
using Parameters: @unpack
using Setfield: @lens, @set, setproperties
using Parameters: @unpack

using InvertedIndices: Not
using LinearAlgebra, StaticArrays
using InvertedIndices: Not
using LinearAlgebra, StaticArrays

include("Structures.jl")
include("Utils.jl")
include("Structures.jl")
include("Utils.jl")

include("Objectives.jl")
include("Gradients.jl")
include("Plots.jl")
include("Objectives.jl")
include("Gradients.jl")
include("Plots.jl")

export plot,@unpack,BorderedArray,SizedVector
export StateSpace,deflationContinuation,train!
export getParameters,loss,∇loss,norm
export plot, @unpack, BorderedArray, SizedVector
export StateSpace, deflationContinuation, train!
export getParameters, loss, ∇loss, norm

""" root finding with newton deflation method"""
function findRoots!( f::Function, J::Function, roots::AbstractVector{<:AbstractVector},
parameters::NamedTuple, hyperparameters::ContinuationPar;
maxRoots::Int = 3, maxIter::Int=500, verbosity=0 )
""" root finding with newton deflation method"""
function findRoots!(f::Function, J::Function, roots::AbstractVector{<:AbstractVector},
parameters::NamedTuple, hyperparameters::ContinuationPar;
maxRoots::Int=3, max_iterations::Int=500, verbosity=0)

hyperparameters = @set hyperparameters.newtonOptions = setproperties(
hyperparameters.newtonOptions; maxIter = maxIter, verbose = verbosity )
hyperparameters = @set hyperparameters.newton_options = setproperties(
hyperparameters.newton_options; max_iterations=max_iterations, verbose=verbosity)

# search for roots across parameter range
pRange = range(hyperparameters.pMin,hyperparameters.pMax,length=length(roots))
roots .= findRoots.( Ref(f), Ref(J), roots, pRange, Ref(parameters), Ref(hyperparameters); maxRoots=maxRoots )
end
# search for roots across parameter range
pRange = range(hyperparameters.p_min, hyperparameters.p_max, length=length(roots))
roots .= findRoots.(Ref(f), Ref(J), roots, pRange, Ref(parameters), Ref(hyperparameters); maxRoots=maxRoots)
end

function findRoots(f::Function, J::Function, roots::AbstractVector{V}, p::T,
parameters::NamedTuple, hyperparameters::ContinuationPar{T,S,E}; maxRoots::Int=3, converged=false
) where {T<:Number,V<:AbstractVector{T},S<:AbstractLinearSolver,E<:AbstractEigenSolver}

Zero = zero(first(roots))
inf = Zero .+ Inf

function findRoots( f::Function, J::Function, roots::AbstractVector{V}, p::T,
parameters::NamedTuple, hyperparameters::ContinuationPar{T, S, E}; maxRoots::Int = 3, converged = false
) where { T<:Number, V<:AbstractVector{T}, S<:AbstractLinearSolver, E<:AbstractEigenSolver }
# search for roots at specific parameter value
deflation = DeflationOperator(one(T), dot, one(T), [inf]) # dummy deflation at infinity
parameters = @set parameters.p = p

Zero = zero(first(roots))
inf = Zero .+ Inf
problem = BifurcationProblem(f, roots[begin] .+ hyperparameters.ds, parameters; J=J)
solution = newton(problem, deflation, hyperparameters.newton_options)

# search for roots at specific parameter value
deflation = DeflationOperator(one(T), dot, one(T), [inf] ) # dummy deflation at infinity
parameters = @set parameters.p = p
for u roots # update existing roots
solution = newton(re_make(problem; u0=u .+ hyperparameters.ds), deflation, hyperparameters.newton_options)

for u roots # update existing roots
u, residual, converged, niter = newton( f, J, u.+hyperparameters.ds, parameters,
hyperparameters.newtonOptions, deflation)
i = 0
while any(isnan.(solution.residuals)) & (i < hyperparameters.newton_options.max_iterations)
u .= randn(length(u))

solution = newton(re_make(problem; u0=u .+ hyperparameters.ds), deflation, hyperparameters.newton_options)
i += 1
end

i = 0
while any(isnan.(residual)) & (i<hyperparameters.newtonOptions.maxIter)
u .= randn(length(u))

u, residual, converged, niter = newton( f, J, u.+hyperparameters.ds, parameters,
hyperparameters.newtonOptions, deflation)
i += 1
end
@assert(!any(isnan.(solution.residuals)), "f(u,p) = $(solution.residuals[end]) at u = $(solution.u), p = $(parameters.p), θ = $(parameters.θ)")
if solution.converged
push!(deflation, solution.u)
else
break
end
end

u = Zero
if solution.converged || length(deflation) == 1 # search for new roots
while length(deflation) - 1 < maxRoots

solution = newton(re_make(problem; u0=u .+ hyperparameters.ds), deflation, hyperparameters.newton_options)

# make sure new roots are different from existing
if any(isapprox.(Ref(solution.u), deflation.roots, atol=2 * hyperparameters.ds))
break
end
if solution.converged
push!(deflation, solution.u)
else
break
end
end
end

filter!(root -> root inf, deflation.roots) # remove dummy deflation at infinity
@assert(length(deflation.roots) > 0, "No roots f(u,p)=0 found at p = $(parameters.p), θ = $(parameters.θ); try increasing max_iterations")
return deflation.roots
end

@assert( !any(isnan.(residual)), "f(u,p) = $(residual[end]) at u = $u, p = $(parameters.p), θ = $(parameters.θ)")
if converged push!(deflation,u) else break end
""" deflation continuation method """
function deflationContinuation(f::Function, roots::AbstractVector{<:AbstractVector{V}},
parameters::NamedTuple, hyperparameters::ContinuationPar{T,S,E};
maxRoots::Int=3, max_iterations::Int=500, resolution=400, verbosity=0, kwargs...
) where {T<:Number,V<:AbstractVector{T},S<:AbstractLinearSolver,E<:AbstractEigenSolver}

max_iterationsContinuation, ds = hyperparameters.newton_options.max_iterations, hyperparameters.ds
J(u, p) = jacobian(x -> f(x, p), u)

findRoots!(f, J, roots, parameters, hyperparameters; maxRoots=maxRoots, max_iterations=max_iterations, verbosity=verbosity)
pRange = range(hyperparameters.p_min, hyperparameters.p_max, length=length(roots))
intervals = ([zero(T), step(pRange)], [-step(pRange), zero(T)])

branches = Vector{Branch{V,T}}()
problem = BifurcationProblem(f, roots[begin][begin], parameters, (@lens _.p); J=J)

hyperparameters = @set hyperparameters.newton_options.max_iterations = max_iterationsContinuation
linsolver = BorderingBLS(hyperparameters.newton_options.linsolver)
algorithm = PALC()

for (i, us) enumerate(roots)
for u us # perform continuation for each root

# forwards and backwards branches
for (p_min, p_max) intervals

hyperparameters = setproperties(hyperparameters;
p_min=pRange[i] + p_min, p_max=pRange[i] + p_max,
ds=sign(hyperparameters.ds) * ds)

# main continuation method
branch = Branch{V,T}()
parameters = @set parameters.p = pRange[i] + hyperparameters.ds

try
iterator = ContIterable(re_make(problem; u0=u, params=parameters), algorithm, hyperparameters; verbosity=verbosity)
for state iterator
push!(branch, state)
end

midpoint = sum(s -> s.z.p, branch) / length(branch)
if minimum(pRange) < midpoint < maximum(pRange)
push!(branches, p_min < 0 ? reverse(branch) : branch)
end

catch error
printstyled(color=:red, "Continuation Error at f(u,p)=$(f(u,parameters))\nu=$u, p=$(parameters.p), θ=$(parameters.θ)\n")
rethrow(error)
end
hyperparameters = @set hyperparameters.ds = -hyperparameters.ds
end
end
end

u = Zero
if converged || length(deflation)==1 # search for new roots
while length(deflation)-1 < maxRoots

u, residual, converged, niter = newton( f, J, u.+hyperparameters.ds, parameters,
hyperparameters.newtonOptions, deflation)

# make sure new roots are different from existing
if any( isapprox.( Ref(u), deflation.roots, atol=2*hyperparameters.ds ) ) break end
if converged push!(deflation,u) else break end
end
end

filter!( root->rootinf, deflation.roots ) # remove dummy deflation at infinity
@assert( length(deflation.roots)>0, "No roots f(u,p)=0 found at p = $(parameters.p), θ = $(parameters.θ); try increasing maxIter")
return deflation.roots
end

""" deflation continuation method """
function deflationContinuation( f::Function, roots::AbstractVector{<:AbstractVector{V}},
parameters::NamedTuple, hyperparameters::ContinuationPar{T, S, E};
maxRoots::Int = 3, maxIter::Int=500, resolution=400, verbosity=0, kwargs...
) where {T<:Number, V<:AbstractVector{T}, S<:AbstractLinearSolver, E<:AbstractEigenSolver}

maxIterContinuation,ds = hyperparameters.newtonOptions.maxIter,hyperparameters.ds
J(u,p) = jacobian(x->f(x,p),u)

findRoots!( f, J, roots, parameters, hyperparameters; maxRoots=maxRoots, maxIter=maxIter, verbosity=verbosity)
pRange = range(hyperparameters.pMin,hyperparameters.pMax,length=length(roots))
intervals = ([zero(T),step(pRange)],[-step(pRange),zero(T)])

branches = Vector{Branch{V,T}}()
hyperparameters = @set hyperparameters.newtonOptions.maxIter = maxIterContinuation
linsolver = BorderingBLS(hyperparameters.newtonOptions.linsolver)

for (i,us) enumerate(roots)
for u us # perform continuation for each root

# forwards and backwards branches
for (pMin,pMax) intervals

hyperparameters = setproperties(hyperparameters;
pMin=pRange[i]+pMin, pMax=pRange[i]+pMax,
ds=sign(hyperparameters.ds)*ds)

# main continuation method
branch = Branch{V,T}()
parameters = @set parameters.p = pRange[i]+hyperparameters.ds

try
iterator = ContIterable( f, J, u, parameters, (@lens _.p), hyperparameters, linsolver, verbosity=verbosity)
for state iterator
push!(branch,state)
end

midpoint = sum( s -> s.z.p, branch ) / length(branch)
if minimum(pRange) < midpoint < maximum(pRange)
push!(branches,pMin < 0 ? reverse(branch) : branch) end

catch error
printstyled(color=:red,"Continuation Error at f(u,p)=$(f(u,parameters))\nu=$u, p=$(parameters.p), θ=$(parameters.θ)\n")
rethrow(error)
end
hyperparameters = @set hyperparameters.ds = -hyperparameters.ds
end
end
end

hyperparameters = @set hyperparameters.ds = ds
updateParameters!(hyperparameters,branches;resolution=resolution)
return unique(branches; atol=10*hyperparameters.ds)
end
hyperparameters = @set hyperparameters.ds = ds
updateParameters!(hyperparameters, branches; resolution=resolution)
return unique(branches; atol=10 * hyperparameters.ds)
end
end # module
Loading

0 comments on commit 90cb698

Please sign in to comment.