-
-
Notifications
You must be signed in to change notification settings - Fork 100
Add correct solve and init dispatches #1052
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
+146
−4
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,147 @@ function Base.showerror(io::IO, e::OptimizerMissingError) | |
print(e.alg) | ||
end | ||
|
||
""" | ||
```julia | ||
solve(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, | ||
args...; kwargs...)::OptimizationSolution | ||
``` | ||
|
||
For information about the returned solution object, refer to the documentation for [`OptimizationSolution`](@ref) | ||
|
||
## Keyword Arguments | ||
|
||
The arguments to `solve` are common across all of the optimizers. | ||
These common arguments are: | ||
|
||
- `maxiters`: the maximum number of iterations | ||
- `maxtime`: the maximum amount of time (typically in seconds) the optimization runs for | ||
- `abstol`: absolute tolerance in changes of the objective value | ||
- `reltol`: relative tolerance in changes of the objective value | ||
- `callback`: a callback function | ||
|
||
Some optimizer algorithms have special keyword arguments documented in the | ||
solver portion of the documentation and their respective documentation. | ||
These arguments can be passed as `kwargs...` to `solve`. Similarly, the special | ||
keyword arguments for the `local_method` of a global optimizer are passed as a | ||
`NamedTuple` to `local_options`. | ||
|
||
Over time, we hope to cover more of these keyword arguments under the common interface. | ||
|
||
A warning will be shown if a common argument is not implemented for an optimizer. | ||
|
||
## Callback Functions | ||
|
||
The callback function `callback` is a function that is called after every optimizer | ||
step. Its signature is: | ||
|
||
```julia | ||
callback = (state, loss_val) -> false | ||
``` | ||
|
||
where `state` is an `OptimizationState` and stores information for the current | ||
iteration of the solver and `loss_val` is loss/objective value. For more | ||
information about the fields of the `state` look at the `OptimizationState` | ||
documentation. The callback should return a Boolean value, and the default | ||
should be `false`, so the optimization stops if it returns `true`. | ||
|
||
### Callback Example | ||
|
||
Here we show an example of a callback function that plots the prediction at the current value of the optimization variables. | ||
For a visualization callback, we would need the prediction at the current parameters i.e. the solution of the `ODEProblem` `prob`. | ||
So we call the `predict` function within the callback again. | ||
|
||
```julia | ||
function predict(u) | ||
Array(solve(prob, Tsit5(), p = u)) | ||
end | ||
|
||
function loss(u, p) | ||
pred = predict(u) | ||
sum(abs2, batch .- pred) | ||
end | ||
|
||
callback = function (state, l; doplot = false) #callback function to observe training | ||
display(l) | ||
# plot current prediction against data | ||
if doplot | ||
pred = predict(state.u) | ||
pl = scatter(t, ode_data[1, :], label = "data") | ||
scatter!(pl, t, pred[1, :], label = "prediction") | ||
display(plot(pl)) | ||
end | ||
return false | ||
end | ||
``` | ||
|
||
If the chosen method is a global optimizer that employs a local optimization | ||
method, a similar set of common local optimizer arguments exists. Look at `MLSL` or `AUGLAG` | ||
from NLopt for an example. The common local optimizer arguments are: | ||
|
||
- `local_method`: optimizer used for local optimization in global method | ||
- `local_maxiters`: the maximum number of iterations | ||
- `local_maxtime`: the maximum amount of time (in seconds) the optimization runs for | ||
- `local_abstol`: absolute tolerance in changes of the objective value | ||
- `local_reltol`: relative tolerance in changes of the objective value | ||
- `local_options`: `NamedTuple` of keyword arguments for local optimizer | ||
""" | ||
function SciMLBase.solve(prob::OptimizationProblem, alg, args...; | ||
kwargs...)::SciMLBase.AbstractOptimizationSolution | ||
if supports_opt_cache_interface(alg) | ||
SciMLBase.solve!(init(prob, alg, args...; kwargs...)) | ||
else | ||
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0)) | ||
throw(NonConcreteEltypeError(eltype(prob.u0))) | ||
end | ||
_check_opt_alg(prob, alg; kwargs...) | ||
SciMLBase.__solve(prob, alg, args...; kwargs...) | ||
end | ||
end | ||
|
||
""" | ||
```julia | ||
init(prob::OptimizationProblem, alg::AbstractOptimizationAlgorithm, args...; kwargs...) | ||
``` | ||
|
||
## Keyword Arguments | ||
|
||
The arguments to `init` are the same as to `solve` and common across all of the optimizers. | ||
These common arguments are: | ||
|
||
- `maxiters` (the maximum number of iterations) | ||
- `maxtime` (the maximum of time the optimization runs for) | ||
- `abstol` (absolute tolerance in changes of the objective value) | ||
- `reltol` (relative tolerance in changes of the objective value) | ||
- `callback` (a callback function) | ||
|
||
Some optimizer algorithms have special keyword arguments documented in the | ||
solver portion of the documentation and their respective documentation. | ||
These arguments can be passed as `kwargs...` to `init`. | ||
|
||
See also [`solve(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref) | ||
""" | ||
function SciMLBase.init(prob::OptimizationProblem, alg, args...; kwargs...)::SciMLBase.AbstractOptimizationCache | ||
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0)) | ||
throw(NonConcreteEltypeError(eltype(prob.u0))) | ||
end | ||
_check_opt_alg(prob::OptimizationProblem, alg; kwargs...) | ||
cache = SciMLBase.__init(prob, alg, args...; prob.kwargs..., kwargs...) | ||
return cache | ||
end | ||
|
||
""" | ||
```julia | ||
solve!(cache::AbstractOptimizationCache) | ||
``` | ||
|
||
Solves the given optimization cache. | ||
|
||
See also [`init(prob::OptimizationProblem, alg, args...; kwargs...)`](@ref) | ||
""" | ||
function SciMLBase.solve!(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution | ||
SciMLBase.__solve(cache) | ||
end | ||
|
||
# Algorithm compatibility checking function | ||
function _check_opt_alg(prob::SciMLBase.OptimizationProblem, alg; kwargs...) | ||
!SciMLBase.allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) && | ||
|
@@ -64,15 +205,15 @@ end | |
# Base solver dispatch functions (these will be extended by specific solver packages) | ||
supports_opt_cache_interface(alg) = false | ||
|
||
function __solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution | ||
function SciMLBase.__solve(cache::SciMLBase.AbstractOptimizationCache)::SciMLBase.AbstractOptimizationSolution | ||
throw(OptimizerMissingError(cache.opt)) | ||
end | ||
|
||
function __init(prob::SciMLBase.OptimizationProblem, alg, args...; | ||
function SciMLBase.__init(prob::SciMLBase.OptimizationProblem, alg, args...; | ||
kwargs...)::SciMLBase.AbstractOptimizationCache | ||
throw(OptimizerMissingError(alg)) | ||
end | ||
|
||
function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...) | ||
function SciMLBase.__solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. piracy, don't do indirect. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What should I do here instead? Isn't this how other __solve dispatches are set up? |
||
throw(OptimizerMissingError(alg)) | ||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's to run CI on SciMLBase master, since master has the solve dispatches for OptimizationProblem removed, but latest release still has them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just say that. Let's do this right, none of this weird workaround stuff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought doing a release as it is would break Optimization.jl?