Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompatHelper: bump compat for OptimizationOptimJL to 0.2, (keep existing compat) #175

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ jobs:
num_threads:
- 1
- 2
include:
- version: '^1.10.0-rc1'
os: ubuntu-latest
arch: x64
num_threads: 1
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Pathfinder"
uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
authors = ["Seth Axen <seth.axen@gmail.com> and contributors"]
version = "0.8.1"
version = "0.8.2"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down Expand Up @@ -47,7 +47,7 @@ LogDensityProblems = "2"
MCMCChains = "5, 6"
Optim = "1.4"
Optimization = "3"
OptimizationOptimJL = "0.1"
OptimizationOptimJL = "0.1, 0.2"
PDMats = "0.11.26"
PSIS = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9"
ProgressLogging = "0.1.4"
Expand Down
70 changes: 63 additions & 7 deletions src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,75 @@ function optimize_with_trace(
xs = typeof(u0)[]
fxs = typeof(fun.f(u0, nothing))[]
∇fxs = typeof(u0)[]
_callback = _make_optimization_callback(
xs, fxs, ∇fxs, ∇f; progress_name, progress_id, maxiters, callback, fail_on_nonfinite
_callback = OptimizationCallback(
xs, fxs, ∇fxs, ∇f, progress_name, progress_id, maxiters, callback, fail_on_nonfinite
)
sol = Optimization.solve(prob, optimizer; callback=_callback, maxiters, kwargs...)
return sol, OptimizationTrace(xs, fxs, ∇fxs)
end

function _make_optimization_callback(
xs, fxs, ∇fxs, ∇f; progress_name, progress_id, maxiters, callback, fail_on_nonfinite
)
return function (x, nfx, args...)
struct OptimizationCallback{X,FX,∇FX,∇F,ID,CB}
xs::X
fxs::FX
∇fxs::∇FX
∇f::∇F
progress_name::String
progress_id::ID
maxiters::Int
callback::CB
fail_on_nonfinite::Bool
end

@static if isdefined(Optimization, :OptimizationState)
# Optimization v3.21.0 and later
function (cb::OptimizationCallback)(state::Optimization.OptimizationState, args...)
@unpack (
xs,
fxs,
∇fxs,
∇f,
progress_name,
progress_id,
maxiters,
callback,
fail_on_nonfinite,
) = cb
ret = callback !== nothing && callback(state, args...)
iteration = state.iter
Base.@logmsg ProgressLogging.ProgressLevel progress_name progress =
iteration / maxiters _id = progress_id

x = copy(state.u)
fx = -state.objective
∇fx = state.grad === nothing ? ∇f(x) : -state.grad

# some backends mutate x, so we must copy it
push!(xs, x)
push!(fxs, fx)
push!(∇fxs, ∇fx)

if fail_on_nonfinite && !ret
ret = (isnan(fx) || fx == Inf || any(!isfinite, ∇fx))::Bool
end

return ret
end
else
# Optimization v3.20.X and earlier
function (cb::OptimizationCallback)(x, nfx, args...)
@unpack (
xs,
fxs,
∇fxs,
∇f,
progress_name,
progress_id,
maxiters,
callback,
fail_on_nonfinite,
) = cb
ret = callback !== nothing && callback(x, nfx, args...)
iteration = length(xs)
iteration = length(cb.xs)
Base.@logmsg ProgressLogging.ProgressLevel progress_name progress =
iteration / maxiters _id = progress_id

Expand Down
28 changes: 20 additions & 8 deletions test/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
@test prob.p === nothing
end

@testset "_make_optimization_callback" begin
@testset "OptimizationCallback" begin
@testset "callback return value" begin
progress_name = "Optimizing"
progress_id = nothing
Expand All @@ -57,22 +57,34 @@ end
g[end] = gval
return g
end
callback = (x, fx, args...) -> cbfail
cb = Pathfinder._make_optimization_callback(
should_fail =
cbfail ||
(fail_on_nonfinite && (isnan(fval) || fval == Inf || !isfinite(gval)))
if isdefined(Optimization, :OptimizationState)
# Optimization v3.21.0 and later
callback = (state, args...) -> cbfail
state = Optimization.OptimizationState(;
iter=0, u=x, objective=-fval, grad=-∇f(x)
)
cb_args = (state, -fval)
else
# Optimization v3.20.X and earlier
callback = (x, fx, args...) -> cbfail
cb_args = (x, -fval)
end
cb = Pathfinder.OptimizationCallback(
xs,
fxs,
∇fxs,
∇f;
∇f,
progress_name,
progress_id,
maxiters,
callback,
fail_on_nonfinite,
)
should_fail =
cbfail ||
(fail_on_nonfinite && (isnan(fval) || fval == Inf || !isfinite(gval)))
@test cb(x, -fval) == should_fail
@test cb isa Pathfinder.OptimizationCallback
@test cb(cb_args...) == should_fail
end
end
end
Expand Down
Loading