1 change: 1 addition & 0 deletions Project.toml
@@ -19,6 +19,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"

[sources]
SciMLBase = {url = "https://github.com/SciML/SciMLBase.jl", rev = "master"}
Member: ?

Member Author: It's to run CI on SciMLBase master, since master has the solve dispatches for OptimizationProblem removed, but the latest release still has them.

Member: Just say that. Let's do this right, none of this weird workaround stuff.

Member Author: I thought doing a release as it is would break Optimization.jl?

OptimizationBase = {path = "lib/OptimizationBase"}
OptimizationLBFGSB = {path = "lib/OptimizationLBFGSB"}
OptimizationMOI = {path = "lib/OptimizationMOI"}
2 changes: 1 addition & 1 deletion lib/OptimizationBase/src/OptimizationBase.jl
@@ -7,7 +7,7 @@ using Reexport
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
import SciMLBase: OptimizationProblem,
OptimizationFunction, ObjSense,
-MaxSense, MinSense, OptimizationStats
+MaxSense, MinSense, OptimizationStats, NonConcreteEltypeError
export ObjSense, MaxSense, MinSense

using FastClosures
147 changes: 144 additions & 3 deletions lib/OptimizationBase/src/solve.jl
@@ -29,6 +29,147 @@ function Base.showerror(io::IO, e::OptimizerMissingError)
print(e.alg)
end

"""
```julia
solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm,
args...; kwargs...)::OptimizationSolution
```

For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref).

## Keyword Arguments

The arguments to `solve` are common across all of the optimizers.
These common arguments are:

- `maxiters`: the maximum number of iterations
- `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
- `abstol`: absolute tolerance in changes of the objective value
- `reltol`: relative tolerance in changes of the objective value
- `callback`: a callback function

Some optimizer algorithms have special keyword arguments, which are documented
in the solver portion of the documentation for each respective solver.
These arguments can be passed as `kwargs...` to `solve`. Similarly, the special
keyword arguments for the `local_method` of a global optimizer are passed as a
`NamedTuple` to `local_options`.

Over time, we hope to cover more of these keyword arguments under the common interface.

A warning will be shown if a common argument is not implemented for an optimizer.
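
For instance, a call passing the common arguments might look like the following sketch (`prob` is assumed to be an already constructed `OptimizationProblem`, and the `Optim.BFGS` wrapper from OptimizationOptimJL is used purely as an illustrative choice of solver):

```julia
using Optimization, OptimizationOptimJL  # assumed to be installed

# `prob` is assumed to be a previously defined OptimizationProblem
sol = solve(prob, Optim.BFGS();
    maxiters = 1000,
    maxtime = 60.0,
    abstol = 1e-8,
    callback = (state, loss_val) -> false)
```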

## Callback Functions

The callback function `callback` is a function that is called after every optimizer
step. Its signature is:

```julia
callback = (state, loss_val) -> false
```

where `state` is an `OptimizationState` that stores information about the current
iteration of the solver, and `loss_val` is the loss/objective value. For more
information about the fields of `state`, see the `OptimizationState`
documentation. The callback should return a Boolean value: returning `true`
stops the optimization, so the default return value should be `false`.

### Callback Example

Here we show an example of a callback function that plots the prediction at the
current value of the optimization variables. For a visualization callback, we need
the prediction at the current parameters, i.e., the solution of the `ODEProblem`
`prob`, so we call the `predict` function within the callback.

```julia
# Assumes an `ODEProblem` `prob`, time points `t`, and data `ode_data` are
# defined, and that OrdinaryDiffEq.jl and Plots.jl are loaded.
function predict(u)
    Array(solve(prob, Tsit5(), p = u))
end

function loss(u, p)
    pred = predict(u)
    sum(abs2, ode_data .- pred)
end

callback = function (state, l; doplot = false) # callback function to observe training
    display(l)
    # plot current prediction against data
    if doplot
        pred = predict(state.u)
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end
```

If the chosen method is a global optimizer that employs a local optimization
method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG`
from NLopt for an example. The common local optimizer arguments are:

- `local_method`: optimizer used for local optimization in global method
- `local_maxiters`: the maximum number of iterations
- `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for
- `local_abstol`: absolute tolerance in changes of the objective value
- `local_reltol`: relative tolerance in changes of the objective value
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
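
As an illustration, a call using NLopt's `MLSL` global method with a local L-BFGS method might look like the following sketch (algorithm names follow the OptimizationNLopt wrapper; `prob` is assumed to be an already constructed, bounded `OptimizationProblem`):

```julia
using Optimization, OptimizationNLopt  # assumed to be installed

# `prob` is assumed to be a previously defined OptimizationProblem with bounds
sol = solve(prob, NLopt.G_MLSL_LDS();
    local_method = NLopt.LD_LBFGS(),
    local_maxiters = 100,
    maxiters = 1000)
```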
"""
function SciMLBase.solve(prob::OptimizationProblem, alg, args...;
kwargs...)::SciMLBase.AbstractOptimizationSolution
if supports_opt_cache_interface(alg)
SciMLBase.solve!(init(prob, alg, args...; kwargs...))
else
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
throw(NonConcreteEltypeError(eltype(prob.u0)))
end
_check_opt_alg(prob, alg; kwargs...)
SciMLBase.__solve(prob, alg, args...; kwargs...)
end
end
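
The concrete-eltype guard above rejects abstractly typed initial values before dispatching to a solver. A minimal sketch of what triggers it (the `solve` call is left as a comment, since it requires a loaded solver package):

```julia
using Optimization

rosenbrock(u, p) = (1.0 - u[1])^2 + 100.0 * (u[2] - u[1]^2)^2

u0 = Real[0.0, 0.0]  # eltype Real is abstract, not concrete
prob = OptimizationProblem(OptimizationFunction(rosenbrock), u0)
# solve(prob, alg) for an `alg` without the cache interface would throw
# NonConcreteEltypeError(Real); use `[0.0, 0.0]` (eltype Float64) instead.
```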

"""
```julia
init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...)
```

## Keyword Arguments

The arguments to `init` are the same as to `solve` and common across all of the optimizers.
These common arguments are:

- `maxiters`: the maximum number of iterations
- `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for
- `abstol`: absolute tolerance in changes of the objective value
- `reltol`: relative tolerance in changes of the objective value
- `callback`: a callback function

Some optimizer algorithms have special keyword arguments, which are documented
in the solver portion of the documentation for each respective solver.
These arguments can be passed as `kwargs...` to `init`.

See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
"""
function SciMLBase.init(prob::OptimizationProblem, alg, args...; kwargs...)::SciMLBase.AbstractOptimizationCache
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
throw(NonConcreteEltypeError(eltype(prob.u0)))
end
_check_opt_alg(prob::OptimizationProblem, alg; kwargs...)
cache = SciMLBase.__init(prob, alg, args...; prob.kwargs..., kwargs...)
return cache
end

"""
```julia
solve!(cache::AbstractOptimizationCache)
```

Solves the given optimization cache.

See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref)
"""
function SciMLBase.solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
SciMLBase.__solve(cache)
end
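
A minimal sketch of the `init`/`solve!` workflow this enables (assuming `prob` is a defined `OptimizationProblem` and `alg` comes from a solver package that supports the cache interface):

```julia
# `prob` and `alg` are assumed to be defined elsewhere
cache = init(prob, alg; maxiters = 100)
sol = solve!(cache)
# equivalent to `sol = solve(prob, alg; maxiters = 100)` for such algorithms
```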

# Algorithm compatibility checking function
function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...)
!SciMLBase.allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
@@ -64,15 +205,15 @@ end
# Base solver dispatch functions (these will be extended by specific solver packages)
supports_opt_cache_interface(alg) = false
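
A downstream solver package would opt into the cache interface roughly as follows (a sketch only; `MyAlg` and `MyCache` are hypothetical names, not part of any real package):

```julia
# Hypothetical solver package code
struct MyAlg end

struct MyCache <: SciMLBase.AbstractOptimizationCache
    prob::SciMLBase.OptimizationProblem
    opt::MyAlg
end

OptimizationBase.supports_opt_cache_interface(::MyAlg) = true

function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, alg::MyAlg, args...; kwargs...)
    return MyCache(prob, alg)
end

function SciMLBase.__solve(cache::MyCache)
    # ... run the optimizer and construct an OptimizationSolution ...
end
```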

-function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
+function SciMLBase.__solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution
throw(OptimizerMissingError(cache.opt))
end

-function __init(prob::SciMLBase.OptimizationProblem, alg, args...;
+function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, alg, args...;
kwargs...)::SciMLBase.AbstractOptimizationCache
throw(OptimizerMissingError(alg))
end

-function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
+function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
Member: piracy, don't do indirect.

Member Author: What should I do here instead? Isn't this how other `__solve` dispatches are set up?

throw(OptimizerMissingError(alg))
end