diff --git a/Project.toml b/Project.toml index 312e7964b..787aa4e6d 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" [sources] +SciMLBase = {url = "https://github.com/SciML/SciMLBase.jl", rev = "master"} OptimizationBase = {path = "lib/OptimizationBase"} OptimizationLBFGSB = {path = "lib/OptimizationLBFGSB"} OptimizationMOI = {path = "lib/OptimizationMOI"} diff --git a/lib/OptimizationBase/src/OptimizationBase.jl b/lib/OptimizationBase/src/OptimizationBase.jl index 244f086ce..c478202d9 100644 --- a/lib/OptimizationBase/src/OptimizationBase.jl +++ b/lib/OptimizationBase/src/OptimizationBase.jl @@ -7,7 +7,7 @@ using Reexport using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra import SciMLBase: OptimizationProblem, OptimizationFunction, ObjSense, - MaxSense, MinSense, OptimizationStats + MaxSense, MinSense, OptimizationStats, NonConcreteEltypeError export ObjSense, MaxSense, MinSense using FastClosures diff --git a/lib/OptimizationBase/src/solve.jl b/lib/OptimizationBase/src/solve.jl index 70688bb54..580db338e 100644 --- a/lib/OptimizationBase/src/solve.jl +++ b/lib/OptimizationBase/src/solve.jl @@ -29,6 +29,147 @@ function Base.showerror(io::IO, e::OptimizerMissingError) print(e.alg) end +""" +```julia +solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, + args...; kwargs...)::OptimizationSolution +``` + +For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref) + +## Keyword Arguments + +The arguments to `solve` are common across all of the optimizers. 
+These common arguments are: + + - `maxiters`: the maximum number of iterations + - `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for + - `abstol`: absolute tolerance in changes of the objective value + - `reltol`: relative tolerance in changes of the objective value + - `callback`: a callback function + +Some optimizer algorithms have special keyword arguments documented in the +solver portion of the documentation and their respective documentation. +These arguments can be passed as `kwargs...` to `solve`. Similarly, the special +keyword arguments for the `local_method` of a global optimizer are passed as a +`NamedTuple` to `local_options`. + +Over time, we hope to cover more of these keyword arguments under the common interface. + +A warning will be shown if a common argument is not implemented for an optimizer. + +## Callback Functions + +The callback function `callback` is a function that is called after every optimizer +step. Its signature is: + +```julia +callback = (state, loss_val) -> false +``` + +where `state` is an `OptimizationState` and stores information for the current +iteration of the solver and `loss_val` is loss/objective value. For more +information about the fields of the `state` look at the `OptimizationState` +documentation. The callback should return a Boolean value, and the default +should be `false`, so the optimization stops if it returns `true`. + +### Callback Example + +Here we show an example of a callback function that plots the prediction at the current value of the optimization variables. +For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`. +So we call the `predict` function within the callback again. 
+ +```julia +function predict(u) + Array(solve(prob, Tsit5(), p = u)) +end + +function loss(u, p) + pred = predict(u) + sum(abs2, batch .- pred) +end + +callback = function (state, l; doplot = false) #callback function to observe training + display(l) + # plot current prediction against data + if doplot + pred = predict(state.u) + pl = scatter(t, ode_data[1, :], label = "data") + scatter!(pl, t, pred[1, :], label = "prediction") + display(plot(pl)) + end + return false +end +``` + +If the chosen method is a global optimizer that employs a local optimization +method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG` +from NLopt for an example. The common local optimizer arguments are: + + - `local_method`: optimizer used for local optimization in global method + - `local_maxiters`: the maximum number of iterations + - `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for + - `local_abstol`: absolute tolerance in changes of the objective value + - `local_reltol`: relative tolerance in changes of the objective value + - `local_options`: `NamedTuple` of keyword arguments for local optimizer +""" +function SciMLBase.solve(prob::OptimizationProblem, alg, args...; + kwargs...)::SciMLBase.AbstractOptimizationSolution + if supports_opt_cache_interface(alg) + SciMLBase.solve!(init(prob, alg, args...; kwargs...)) + else + if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0)) + throw(NonConcreteEltypeError(eltype(prob.u0))) + end + _check_opt_alg(prob, alg; kwargs...) + SciMLBase.__solve(prob, alg, args...; kwargs...) + end +end + +""" +```julia +init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...) +``` + +## Keyword Arguments + +The arguments to `init` are the same as to `solve` and common across all of the optimizers. 
+These common arguments are:
+
+ - `maxiters` (the maximum number of iterations)
+ - `maxtime` (the maximum amount of time the optimization runs for)
+ - `abstol` (absolute tolerance in changes of the objective value)
+ - `reltol` (relative tolerance in changes of the objective value)
+ - `callback` (a callback function)
+
+Some optimizer algorithms have special keyword arguments documented in the
+solver portion of the documentation and their respective documentation.
+These arguments can be passed as `kwargs...` to `init`.
+
+See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
+"""
+function SciMLBase.init(prob::OptimizationProblem, alg, args...; kwargs...)::SciMLBase.AbstractOptimizationCache
+    if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
+        throw(NonConcreteEltypeError(eltype(prob.u0)))
+    end
+    _check_opt_alg(prob::OptimizationProblem, alg; kwargs...)
+    cache = SciMLBase.__init(prob, alg, args...; prob.kwargs..., kwargs...)
+    return cache
+end
+
+"""
+```julia
+solve!(cache::AbstractOptimizationCache)
+```
+
+Solves the given optimization cache.
+
+See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
+"""
+function SciMLBase.solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
+    SciMLBase.__solve(cache)
+end
+
 # Algorithm compatibility checking function
 function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
!SciMLBase.allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) && @@ -64,15 +205,15 @@ end # Base solver dispatch functions (these will be extended by specific solver packages) supports_opt_cache_interface(alg) = false -function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution +function SciMLBase.__solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution throw(OptimizerMissingError(cache.opt)) end -function __init(prob::SciMLBase.OptimizationProblem, alg, args...; +function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)::SciMLBase.AbstractOptimizationCache throw(OptimizerMissingError(alg)) end -function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...) +function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...) throw(OptimizerMissingError(alg)) end \ No newline at end of file