Skip to content

Commit

Permalink
Merge pull request #603 from chriscoey/moresteppercleanup
Browse files Browse the repository at this point in the history
more stepper cleanup and options
  • Loading branch information
chriscoey authored Oct 18, 2020
2 parents d2b9b85 + ad81997 commit 02c609c
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 47 deletions.
2 changes: 1 addition & 1 deletion examples/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ solve_time_limit = 1.2 * optimizer_time_limit
setup_time_limit = optimizer_time_limit

num_threads = Threads.nthreads()
blas_num_threads = LinearAlgebra.BLAS.get_num_threads()
blas_num_threads = LinearAlgebra.BLAS.get_num_threads() # requires Julia 1.6
@show num_threads
@show blas_num_threads
println()
Expand Down
79 changes: 51 additions & 28 deletions src/Solvers/steppers/heurcomb.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ combined directions stepper
=#

mutable struct HeurCombStepper{T <: Real} <: Stepper{T}
gamma_fun::Function
prev_pred_alpha::T
prev_alpha::T
prev_gamma::T
Expand All @@ -11,10 +12,21 @@ mutable struct HeurCombStepper{T <: Real} <: Stepper{T}
dir::Point{T}
res::Point{T}
dir_temp::Vector{T}
dir_cent::Vector{T}
dir_centcorr::Vector{T}
dir_pred::Vector{T}
dir_predcorr::Vector{T}

line_searcher::LineSearcher{T}

HeurCombStepper{T}() where {T <: Real} = new{T}()
function HeurCombStepper{T}(;
gamma_fun::Function = (a::T -> (1 - a)),
# gamma_fun::Function = (a -> abs2(1 - a)),
) where {T <: Real}
stepper = new{T}()
stepper.gamma_fun = gamma_fun
return stepper
end
end

# create the stepper cache
Expand All @@ -27,7 +39,12 @@ function load(stepper::HeurCombStepper{T}, solver::Solver{T}) where {T <: Real}
stepper.rhs = Point(model)
stepper.dir = Point(model)
stepper.res = Point(model)
stepper.dir_temp = zeros(T, length(stepper.rhs.vec))
dim = length(stepper.rhs.vec)
stepper.dir_temp = zeros(T, dim)
stepper.dir_cent = zeros(T, dim)
stepper.dir_centcorr = zeros(T, dim)
stepper.dir_pred = zeros(T, dim)
stepper.dir_predcorr = zeros(T, dim)

stepper.line_searcher = LineSearcher{T}(model)

Expand All @@ -38,52 +55,58 @@ end
function step(stepper::HeurCombStepper{T}, solver::Solver{T}) where {T <: Real}
point = solver.point
model = solver.model

# update linear system solver factorization and helpers
# Cones.grad.(model.cones)
rhs = stepper.rhs
dir = stepper.dir
dir_cent = stepper.dir_cent
dir_centcorr = stepper.dir_centcorr
dir_pred = stepper.dir_pred
dir_predcorr = stepper.dir_predcorr

# update linear system solver factorization
update_lhs(solver.system_solver, solver)

# calculate centering direction and keep in dir_cent
update_rhs_cent(solver, stepper.rhs)
# calculate centering direction and correction
update_rhs_cent(solver, rhs)
get_directions(stepper, solver, false, iter_ref_steps = 3)
dir_cent = copy(stepper.dir.vec) # TODO
update_rhs_centcorr(solver, stepper.rhs, stepper.dir, add = false)
copyto!(dir_cent, dir.vec)
update_rhs_centcorr(solver, rhs, dir, add = false)
get_directions(stepper, solver, false, iter_ref_steps = 3)
dir_centcorr = copy(stepper.dir.vec) # TODO
copyto!(dir_centcorr, dir.vec)

# calculate affine/prediction direction and keep in dir
update_rhs_pred(solver, stepper.rhs)
# calculate affine/prediction direction and correction
update_rhs_pred(solver, rhs)
get_directions(stepper, solver, true, iter_ref_steps = 3)
dir_pred = copy(stepper.dir.vec) # TODO
update_rhs_predcorr(solver, stepper.rhs, stepper.dir, add = false)
copyto!(dir_pred, dir.vec)
update_rhs_predcorr(solver, rhs, dir, add = false)
get_directions(stepper, solver, true, iter_ref_steps = 3)
dir_predcorr = copy(stepper.dir.vec) # TODO
copyto!(dir_predcorr, dir.vec)

# calculate centering factor gamma by finding distance pred_alpha for stepping in pred direction
copyto!(stepper.dir.vec, dir_pred)
stepper.prev_pred_alpha = pred_alpha = find_max_alpha(point, stepper.dir, stepper.line_searcher, model, prev_alpha = stepper.prev_pred_alpha, min_alpha = T(1e-2), max_nbhd = one(T)) # TODO try max_nbhd = Inf, but careful of cones with no dual feas check

# TODO allow different function (heuristic) as option?
# stepper.prev_gamma = gamma = abs2(1 - pred_alpha)
stepper.prev_gamma = gamma = 1 - pred_alpha
copyto!(dir.vec, dir_pred)
# TODO try max_nbhd = Inf, but careful of cones with no dual feas check
stepper.prev_pred_alpha = pred_alpha = find_max_alpha(point, dir, stepper.line_searcher, model, prev_alpha = stepper.prev_pred_alpha, min_alpha = T(1e-2), max_nbhd = one(T))
stepper.prev_gamma = gamma = stepper.gamma_fun(pred_alpha)

# calculate combined direction and keep in dir
@. stepper.dir.vec = gamma * (dir_cent + pred_alpha * dir_centcorr) + (1 - gamma) * (dir_pred + pred_alpha * dir_predcorr) # TODO
gamma_alpha = gamma * pred_alpha
gamma1 = 1 - gamma
gamma1_alpha = gamma1 * pred_alpha
@. dir.vec = gamma * dir_cent + gamma_alpha * dir_centcorr + gamma1 * dir_pred + gamma1_alpha * dir_predcorr

# find distance alpha for stepping in combined direction
alpha = find_max_alpha(point, stepper.dir, stepper.line_searcher, model, prev_alpha = stepper.prev_alpha, min_alpha = T(1e-3))
alpha = find_max_alpha(point, dir, stepper.line_searcher, model, prev_alpha = stepper.prev_alpha, min_alpha = T(1e-3))

if iszero(alpha)
# could not step far in combined direction, so attempt a pure centering step
solver.verbose && println("performing centering step")
@. stepper.dir.vec = dir_cent + dir_centcorr
@. dir.vec = dir_cent + dir_centcorr

# find distance alpha for stepping in centering direction
alpha = find_max_alpha(point, stepper.dir, stepper.line_searcher, model, prev_alpha = one(T), min_alpha = T(1e-6))
alpha = find_max_alpha(point, dir, stepper.line_searcher, model, prev_alpha = one(T), min_alpha = T(1e-3))

if iszero(alpha)
copyto!(stepper.dir.vec, dir_cent)
alpha = find_max_alpha(point, stepper.dir, stepper.line_searcher, model, prev_alpha = one(T), min_alpha = T(1e-6))
copyto!(dir.vec, dir_cent)
alpha = find_max_alpha(point, dir, stepper.line_searcher, model, prev_alpha = one(T), min_alpha = T(1e-6))
if iszero(alpha)
@warn("numerical failure: could not step in centering direction; terminating")
solver.status = NumericalFailure
Expand All @@ -94,7 +117,7 @@ function step(stepper::HeurCombStepper{T}, solver::Solver{T}) where {T <: Real}
stepper.prev_alpha = alpha

# step
@. point.vec += alpha * stepper.dir.vec
@. point.vec += alpha * dir.vec
calc_mu(solver)

return true
Expand Down
35 changes: 21 additions & 14 deletions src/Solvers/steppers/predorcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ predict or center stepper
=#

mutable struct PredOrCentStepper{T <: Real} <: Stepper{T}
use_correction::Bool
prev_pred_alpha::T
prev_alpha::T
prev_is_pred::Bool
Expand All @@ -15,7 +16,13 @@ mutable struct PredOrCentStepper{T <: Real} <: Stepper{T}

line_searcher::LineSearcher{T}

PredOrCentStepper{T}() where {T <: Real} = new{T}()
function PredOrCentStepper{T}(;
use_correction::Bool = true,
) where {T <: Real}
stepper = new{T}()
stepper.use_correction = use_correction
return stepper
end
end

# create the stepper cache
Expand All @@ -37,50 +44,50 @@ function load(stepper::PredOrCentStepper{T}, solver::Solver{T}) where {T <: Real
end

function step(stepper::PredOrCentStepper{T}, solver::Solver{T}) where {T <: Real}
point = solver.point
model = solver.model
rhs = stepper.rhs
dir = stepper.dir

# update linear system solver factorization
update_lhs(solver.system_solver, solver)

# TODO option
use_corr = true
# use_corr = false
stepper.prev_is_pred = (stepper.cent_count > 3) || all(Cones.in_neighborhood.(model.cones, sqrt(solver.mu), T(0.05)))

if stepper.prev_is_pred
# predict
stepper.cent_count = 0
update_rhs_pred(solver, stepper.rhs)
update_rhs_pred(solver, rhs)
get_directions(stepper, solver, true, iter_ref_steps = 3)
if use_corr
update_rhs_predcorr(solver, stepper.rhs, stepper.dir)
if stepper.use_correction
update_rhs_predcorr(solver, rhs, dir)
get_directions(stepper, solver, true, iter_ref_steps = 3)
end
else
# center
stepper.cent_count += 1
update_rhs_cent(solver, stepper.rhs)
update_rhs_cent(solver, rhs)
get_directions(stepper, solver, false, iter_ref_steps = 3)
if use_corr
update_rhs_centcorr(solver, stepper.rhs, stepper.dir)
if stepper.use_correction
update_rhs_centcorr(solver, rhs, dir)
get_directions(stepper, solver, false, iter_ref_steps = 3)
end
end

# alpha step length
alpha = find_max_alpha(solver.point, stepper.dir, stepper.line_searcher, model, prev_alpha = one(T), min_alpha = T(1e-3), max_nbhd = T(0.99))
stepper.prev_alpha = alpha = find_max_alpha(point, dir, stepper.line_searcher, model, prev_alpha = one(T), min_alpha = T(1e-3), max_nbhd = T(0.99))

if iszero(alpha)
# TODO attempt recovery
@warn("very small alpha")
solver.status = NumericalFailure
return false
end
stepper.prev_alpha = alpha
if stepper.prev_is_pred
stepper.prev_pred_alpha = alpha
end

# step
@. solver.point.vec += alpha * stepper.dir.vec
@. point.vec += alpha * dir.vec
calc_mu(solver)

return true
Expand Down
2 changes: 1 addition & 1 deletion test/moi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ conic_exclude = String[
function test_moi(T::Type{<:Real}; solver_options...)
optimizer = MOIU.CachingOptimizer(MOIU.UniversalFallback(MOIU.Model{T}()), Hypatia.Optimizer{T}(; solver_options...))

tol = sqrt(sqrt(Float64(eps(T)))) # TODO remove Float64, waiting for https://github.com/jump-dev/MathOptInterface.jl/pull/1176
tol = sqrt(sqrt(Float64(eps(T)))) # TODO remove Float64, waiting for MOI to be tagged after https://github.com/jump-dev/MathOptInterface.jl/pull/1176
config = MOIT.TestConfig{T}(
atol = tol,
rtol = tol,
Expand Down
12 changes: 9 additions & 3 deletions test/runmoitests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,25 @@ include(joinpath(@__DIR__, "moi.jl"))
end
end

default_options = (
# verbose = true,
verbose = false,
default_tol_relax = 3,
)

@testset "MOI.Test tests" begin
println("\nstarting MOI.Test tests")
options = [
(Float64, Solvers.SymIndefSparseSystemSolver, false),
# (Float64, Solvers.QRCholDenseSystemSolver, true), # TODO fails a few
(Float64, Solvers.QRCholDenseSystemSolver, true),
# (Float32, Solvers.QRCholDenseSystemSolver, false), # TODO fails a few
# (BigFloat, Solvers.QRCholDenseSystemSolver, true), # TODO uncomment when https://github.com/jump-dev/MathOptInterface.jl/pull/1175 merged
# (BigFloat, Solvers.QRCholDenseSystemSolver, true), # TODO uncomment when MOI has been tagged
]
for (T, system_solver, use_dense_model) in options
test_info = "$system_solver, $T, $use_dense_model"
@testset "$test_info" begin
println(test_info, " ...")
test_time = @elapsed test_moi(T, use_dense_model = use_dense_model, verbose = false, system_solver = system_solver{T}())
test_time = @elapsed test_moi(T, system_solver = system_solver{T}(), use_dense_model = use_dense_model; default_options...)
@printf("%8.2e seconds\n", test_time)
end
end
Expand Down

0 comments on commit 02c609c

Please sign in to comment.