Skip to content

Commit

Permalink
brent_newton method for lsqnonneg_gcv;
Browse files Browse the repository at this point in the history
fix failing mdp tests
  • Loading branch information
jondeuce committed Mar 29, 2024
1 parent 0bd21eb commit 200abb9
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 57 deletions.
105 changes: 65 additions & 40 deletions src/lsqnonneg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,17 +252,17 @@ end
increment_cache_index!(work::NNLSTikhonovRegProblemCache{N}) where {N} = (work.idx[] = mod1(work.idx[] + 1, N))
set_cache_index!(work::NNLSTikhonovRegProblemCache{N}, i) where {N} = (work.idx[] = mod1(i, N))
reset_cache!(work::NNLSTikhonovRegProblemCache{N}) where {N} = foreach(w -> mu!(w, NaN), work.cache)
get_cache(work::NNLSTikhonovRegProblemCache) = work.cache[work.idx[]]
Base.getindex(work::NNLSTikhonovRegProblemCache) = work.cache[work.idx[]]

function solve!(work::NNLSTikhonovRegProblemCache, μ::Real)
i = findfirst(w -> μ == mu(w), work.cache)
if i === nothing
increment_cache_index!(work)
solve!(get_cache(work), μ)
solve!(work[], μ)
else
set_cache_index!(work, i)
end
return solution(get_cache(work))
return solution(work[])
end

####
Expand All @@ -284,8 +284,8 @@ function NNLSChi2RegProblem(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T
return NNLSChi2RegProblem(A, b, m, n, nnls_prob, nnls_prob_smooth_cache)
end

@inline solution(work::NNLSChi2RegProblem) = solution(get_cache(work.nnls_prob_smooth_cache))
@inline ncomponents(work::NNLSChi2RegProblem) = ncomponents(get_cache(work.nnls_prob_smooth_cache))
@inline solution(work::NNLSChi2RegProblem) = solution(work.nnls_prob_smooth_cache[])
@inline ncomponents(work::NNLSChi2RegProblem) = ncomponents(work.nnls_prob_smooth_cache[])

@doc raw"""
lsqnonneg_chi2(A::AbstractMatrix, b::AbstractVector, chi2_target::Real)
Expand Down Expand Up @@ -345,7 +345,7 @@ function lsqnonneg_chi2!(work::NNLSChi2RegProblem{T}, chi2_target::T, legacy::Bo
mu_final, res²_final = chi2_search_from_minimum(res²_min, chi2_target; legacy) do μ
μ == 0 && return res²_min
solve!(work.nnls_prob_smooth_cache, μ)
return resnorm_sq(get_cache(work.nnls_prob_smooth_cache))
return resnorm_sq(work.nnls_prob_smooth_cache[])
end
if mu_final == 0
x_final = x_unreg
Expand All @@ -356,15 +356,15 @@ function lsqnonneg_chi2!(work::NNLSChi2RegProblem{T}, chi2_target::T, legacy::Bo
elseif method === :bisect
f = function (logμ)
solve!(work.nnls_prob_smooth_cache, exp(logμ))
return chi2_relerr!(get_cache(work.nnls_prob_smooth_cache), res²_target, logμ)
return chi2_relerr!(work.nnls_prob_smooth_cache[], res²_target, logμ)
end

# Find bracketing interval containing root, then perform bisection search with slightly higher tolerance to not waste f evals
a, b, fa, fb = bracket_root_monotonic(f, T(-4.0), T(1.0); dilate = T(1.5), mono = +1, maxiters = 6) # maxiters = 100
a, b, fa, fb = bracket_root_monotonic(f, T(-4.0), T(1.0); dilate = T(1.5), mono = +1, maxiters = 6)

if fa * fb < 0
# Bracketing interval found
a, fa, c, fc, b, fb = bisect_root(f, a, b, fa, fb; xatol = T(0.05), xrtol = T(0.0), ftol = (chi2_target - 1) / 100) # maxiters = 100
a, fa, c, fc, b, fb = bisect_root(f, a, b, fa, fb; xatol = T(0.05), xrtol = T(0.0), ftol = T(1e-2) * (chi2_target - 1), maxiters = 100)

# Root of secant line through `(a, fa), (b, fb)` or `(c, fc), (b, fb)` to improve bisection accuracy
tmp = fa * fc < 0 ? root_real_linear(a, c, fa, fc) : fc * fb < 0 ? root_real_linear(c, b, fc, fb) : T(NaN)
Expand All @@ -387,15 +387,15 @@ function lsqnonneg_chi2!(work::NNLSChi2RegProblem{T}, chi2_target::T, legacy::Bo
elseif method === :brent
f = function (logμ)
solve!(work.nnls_prob_smooth_cache, exp(logμ))
return chi2_relerr!(get_cache(work.nnls_prob_smooth_cache), res²_target, logμ)
return chi2_relerr!(work.nnls_prob_smooth_cache[], res²_target, logμ)
end

# Find bracketing interval containing root
a, b, fa, fb = bracket_root_monotonic(f, T(-4.0), T(1.0); dilate = T(1.5), mono = +1, maxiters = 100)

if fa * fb < 0
# Find root using Brent's method
logmu_final, relerr_final = brent_root(f, a, b, fa, fb; xatol = T(0.0), xrtol = T(0.0), ftol = (chi2_target - 1) / 1000, maxiters = 100)
logmu_final, relerr_final = brent_root(f, a, b, fa, fb; xatol = T(0.0), xrtol = T(0.0), ftol = T(1e-3) * (chi2_target - 1), maxiters = 100)
else
# No bracketing interval found; choose point with smallest value of f (note: this branch should never be reached)
logmu_final, relerr_final = !isfinite(fa) ? (b, fb) : !isfinite(fb) ? (a, fa) : abs(fa) < abs(fb) ? (a, fa) : (b, fb)
Expand Down Expand Up @@ -476,8 +476,8 @@ function NNLSMDPRegProblem(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T}
return NNLSMDPRegProblem(A, b, m, n, nnls_prob, nnls_prob_smooth_cache)
end

@inline solution(work::NNLSMDPRegProblem) = solution(get_cache(work.nnls_prob_smooth_cache))
@inline ncomponents(work::NNLSMDPRegProblem) = ncomponents(get_cache(work.nnls_prob_smooth_cache))
@inline solution(work::NNLSMDPRegProblem) = solution(work.nnls_prob_smooth_cache[])
@inline ncomponents(work::NNLSMDPRegProblem) = ncomponents(work.nnls_prob_smooth_cache[])

@doc raw"""
lsqnonneg_mdp(A::AbstractMatrix, b::AbstractVector, δ::Real)
Expand Down Expand Up @@ -539,15 +539,15 @@ function lsqnonneg_mdp!(work::NNLSMDPRegProblem{T}, δ::T) where {T}

function f(logμ)
solve!(work.nnls_prob_smooth_cache, exp(logμ))
return resnorm_sq(get_cache(work.nnls_prob_smooth_cache)) - δ^2
return resnorm_sq(work.nnls_prob_smooth_cache[]) - δ^2
end

# Find bracketing interval containing root
a, b, fa, fb = bracket_root_monotonic(f, T(-4.0), T(1.0); dilate = T(1.5), mono = +1, maxiters = 100)

if fa * fb < 0
# Find root using Brent's method
logmu_final, err_final = brent_root(f, a, b, fa, fb; xatol = T(0.0), xrtol = T(0.0), ftol = δ / 1000, maxiters = 100)
logmu_final, err_final = brent_root(f, a, b, fa, fb; xatol = T(0.0), xrtol = T(0.0), ftol = T(1e-3) * δ, maxiters = 100)
else
# No bracketing interval found; choose point with smallest value of f (note: this branch should never be reached)
logmu_final, err_final = !isfinite(fa) ? (b, fb) : !isfinite(fb) ? (a, fa) : abs(fa) < abs(fb) ? (a, fa) : (b, fb)
Expand Down Expand Up @@ -589,8 +589,8 @@ function NNLSLCurveRegProblem(A::AbstractMatrix{T}, b::AbstractVector{T}) where
return NNLSLCurveRegProblem(A, b, m, n, nnls_prob, nnls_prob_smooth_cache, lsqnonneg_lcurve_fun_cache, lcurve_corner_caches)
end

@inline solution(work::NNLSLCurveRegProblem) = solution(get_cache(work.nnls_prob_smooth_cache))
@inline ncomponents(work::NNLSLCurveRegProblem) = ncomponents(get_cache(work.nnls_prob_smooth_cache))
@inline solution(work::NNLSLCurveRegProblem) = solution(work.nnls_prob_smooth_cache[])
@inline ncomponents(work::NNLSLCurveRegProblem) = ncomponents(work.nnls_prob_smooth_cache[])

@doc raw"""
lsqnonneg_lcurve(A::AbstractMatrix, b::AbstractVector)
Expand Down Expand Up @@ -626,7 +626,7 @@ function lsqnonneg_lcurve(A::AbstractMatrix, b::AbstractVector)
end
lsqnonneg_lcurve_work(A::AbstractMatrix, b::AbstractVector) = NNLSLCurveRegProblem(A, b)

function lsqnonneg_lcurve!(work::NNLSLCurveRegProblem{T, N}) where {T, N}
function lsqnonneg_lcurve!(work::NNLSLCurveRegProblem{T}) where {T}
# Compute the regularization using the L-curve method
reset_cache!(work.nnls_prob_smooth_cache)

Expand All @@ -635,8 +635,8 @@ function lsqnonneg_lcurve!(work::NNLSLCurveRegProblem{T, N}) where {T, N}
# this scales the L-curve, but does not change μ* = argmax C(ξ(μ), η(μ)).
function f_lcurve(logμ)
solve!(work.nnls_prob_smooth_cache, exp(logμ))
ξ = log(resnorm_sq(get_cache(work.nnls_prob_smooth_cache)))
η = log(seminorm_sq(get_cache(work.nnls_prob_smooth_cache)))
ξ = log(resnorm_sq(work.nnls_prob_smooth_cache[]))
η = log(seminorm_sq(work.nnls_prob_smooth_cache[]))
return SA{T}[ξ, η]
end

Expand All @@ -651,7 +651,7 @@ function lsqnonneg_lcurve!(work::NNLSLCurveRegProblem{T, N}) where {T, N}
mu_final = exp(logmu_final)
x_final = solve!(work.nnls_prob_smooth_cache, mu_final)
x_unreg = solve!(work.nnls_prob)
chi2_final = resnorm_sq(get_cache(work.nnls_prob_smooth_cache)) / resnorm_sq(work.nnls_prob)
chi2_final = resnorm_sq(work.nnls_prob_smooth_cache[]) / resnorm_sq(work.nnls_prob)

return (; x = x_final, mu = mu_final, chi2 = chi2_final)
end
Expand Down Expand Up @@ -989,8 +989,8 @@ function NNLSGCVRegProblem(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T}
return NNLSGCVRegProblem(A, b, m, n, γ, svd_work, nnls_prob, nnls_prob_smooth_cache)
end

@inline solution(work::NNLSGCVRegProblem) = solution(get_cache(work.nnls_prob_smooth_cache))
@inline ncomponents(work::NNLSGCVRegProblem) = ncomponents(get_cache(work.nnls_prob_smooth_cache))
@inline solution(work::NNLSGCVRegProblem) = solution(work.nnls_prob_smooth_cache[])
@inline ncomponents(work::NNLSGCVRegProblem) = ncomponents(work.nnls_prob_smooth_cache[])

@doc raw"""
lsqnonneg_gcv(A::AbstractMatrix, b::AbstractVector)
Expand Down Expand Up @@ -1030,22 +1030,35 @@ Details of the GCV method can be found in Hansen (1992)[1].
1. Hansen, P.C., 1992. Analysis of Discrete Ill-Posed Problems by Means of the L-Curve. SIAM Review, 34(4), 561-580, https://doi.org/10.1137/1034115.
"""
function lsqnonneg_gcv(A::AbstractMatrix, b::AbstractVector)
function lsqnonneg_gcv(A::AbstractMatrix, b::AbstractVector; kwargs...)
work = lsqnonneg_gcv_work(A, b)
return lsqnonneg_gcv!(work)
return lsqnonneg_gcv!(work; kwargs...)
end
lsqnonneg_gcv_work(A::AbstractMatrix, b::AbstractVector) = NNLSGCVRegProblem(A, b)

function lsqnonneg_gcv!(work::NNLSGCVRegProblem{T, N}; method = :brent) where {T, N}
function lsqnonneg_gcv!(work::NNLSGCVRegProblem{T}; method = :brent, init = -4.0, bounds = (-8.0, 2.0), rtol = 0.05, atol = 1e-4, maxiters = 10) where {T}
# Find μ by minimizing the function G(μ) (GCV method)
reset_cache!(work.nnls_prob_smooth_cache)
@assert bounds[1] < init < bounds[2] "Initial value must be within bounds"
logμ₋, logμ₊ = T.(bounds)
logμ₀ = T(init)

# Precompute singular values for GCV computation
svdvals!(work.svd_work, work.A)

# Non-zero lower bound for GCV to avoid log(0) in the objective function
gcv_low = gcv_lower_bound(work)

# Objective functions
reset_cache!(work.nnls_prob_smooth_cache)
function log𝒢(logμ)
return log(max(gcv!(work, logμ), gcv_low))
end
function log𝒢_and_∇log𝒢(logμ)
𝒢, ∇𝒢 = gcv_and_∇gcv!(work, logμ)
𝒢 = max(𝒢, gcv_low)
return log(𝒢), ∇𝒢 / 𝒢
end

if method === :nlopt
# alg = :LN_COBYLA # local, gradient-free, linear approximation of objective
alg = :LN_BOBYQA # local, gradient-free, quadratic approximation of objective
Expand All @@ -1054,18 +1067,30 @@ function lsqnonneg_gcv!(work::NNLSGCVRegProblem{T, N}; method = :brent) where {T
# alg = :LN_SBPLX # local, gradient-free, subspace searching simplex method
# alg = :LD_CCSAQ # local, first-order (rough ranking: [:LD_MMA, :LD_SLSQP, :LD_LBFGS, :LD_CCSAQ, :LD_AUGLAG])
opt = NLopt.Opt(alg, 1)
opt.lower_bounds = -8.0
opt.upper_bounds = 2.0
opt.xtol_abs = 1e-4
opt.xtol_rel = 1e-4
opt.lower_bounds = Float64(logμ₋)
opt.upper_bounds = Float64(logμ₊)
opt.xtol_abs = Float64(atol)
opt.xtol_rel = Float64(rtol)
opt.ftol_abs = 0.0
opt.ftol_rel = 0.0
opt.min_objective = (logμ, ∇logμ) -> @inbounds Float64(log(max(gcv!(work, logμ[1]), gcv_low)))
minf, minx, ret = NLopt.optimize(opt, [-4.0])
opt.min_objective = (logμ, ∇logμ) -> @inbounds Float64(log𝒢(T(logμ[1])))
minf, minx, ret = NLopt.optimize(opt, Float64[logμ₀])
logmu_final = @inbounds T(minx[1])
log𝒢_final = T(minf)
elseif method === :brent
logmu_final, _ = brent_minimize(-8.0, 2.0; xrtol = T(0.05), xatol = T(1e-4), maxiters = 10) do logμ
return log(max(gcv!(work, logμ), gcv_low))
logmu_final, log𝒢_final = brent_minimize(log𝒢, logμ₋, logμ₊; xrtol = T(rtol), xatol = T(atol), maxiters)
elseif method === :brent_newton
log𝒢₋, ∇log𝒢₋ = log𝒢_and_∇log𝒢(logμ₋)
log𝒢₊, ∇log𝒢₊ = log𝒢_and_∇log𝒢(logμ₊)
logμ_bdry, log𝒢_bdry = log𝒢₋ < log𝒢₊ ? (logμ₋, log𝒢₋) : (logμ₊, log𝒢₊)
if ∇log𝒢₋ < 0 && ∇log𝒢₊ > 0
log𝒢₀, ∇log𝒢₀ = log𝒢_and_∇log𝒢(logμ₀)
logmu_final, log𝒢_final = brent_newton_minimize(log𝒢_and_∇log𝒢, logμ₋, logμ₊, logμ₀, log𝒢₀, ∇log𝒢₀; xrtol = T(rtol), xatol = T(atol), maxiters)
else
logmu_final, log𝒢_final = logμ_bdry, log𝒢_bdry
end
if log𝒢_bdry < log𝒢_final
logmu_final, log𝒢_final = logμ_bdry, log𝒢_bdry
end
else
error("Unknown minimization method: $method")
Expand All @@ -1075,7 +1100,7 @@ function lsqnonneg_gcv!(work::NNLSGCVRegProblem{T, N}; method = :brent) where {T
mu_final = exp(logmu_final)
x_final = solve!(work.nnls_prob_smooth_cache, mu_final)
x_unreg = solve!(work.nnls_prob)
chi2_final = resnorm_sq(get_cache(work.nnls_prob_smooth_cache)) / resnorm_sq(work.nnls_prob)
chi2_final = resnorm_sq(work.nnls_prob_smooth_cache[]) / resnorm_sq(work.nnls_prob)

return (; x = x_final, mu = mu_final, chi2 = chi2_final)
end
Expand All @@ -1093,7 +1118,7 @@ function gcv!(work::NNLSGCVRegProblem, logμ)
# Solve regularized NNLS problem
μ = exp(logμ)
solve!(work.nnls_prob_smooth_cache, μ)
cache = get_cache(work.nnls_prob_smooth_cache)
cache = work.nnls_prob_smooth_cache[]

# Compute GCV
res² = resnorm_sq(cache) # squared residual norm ||A * x(μ) - b||^2
Expand All @@ -1110,7 +1135,7 @@ function gcv_and_∇gcv!(work::NNLSGCVRegProblem, logμ)
# Solve regularized NNLS problem
μ = exp(logμ)
solve!(work.nnls_prob_smooth_cache, μ)
cache = get_cache(work.nnls_prob_smooth_cache)
cache = work.nnls_prob_smooth_cache[]

# Compute primal
res² = resnorm_sq(cache) # squared residual norm ||A * x(μ) - b||^2
Expand Down Expand Up @@ -1145,12 +1170,12 @@ function gcv!(work::NNLSGCVRegProblem, logμ, ::Val{extract_subproblem} = Val(fa
# Solve regularized NNLS problem and record residual norm ||A * x(μ) - b||^2
μ = exp(logμ)
solve!(work.nnls_prob_smooth_cache, μ)
res² = resnorm_sq(get_cache(work.nnls_prob_smooth_cache))
res² = resnorm_sq(work.nnls_prob_smooth_cache[])
if extract_subproblem
# Extract equivalent unconstrained least squares subproblem from NNLS problem
# by extracting columns of A which correspond to nonzero components of x(μ)
idx = NNLS.components(get_cache(work.nnls_prob_smooth_cache).nnls_prob.nnls_work)
idx = NNLS.components(work.nnls_prob_smooth_cache[].nnls_prob.nnls_work)
n′ = length(idx)
A′ = reshape(view(A_buf, 1:m*n′), m, n′)
At′ = reshape(view(Aᵀ_buf, 1:n′*m), n′, m)
Expand Down
26 changes: 12 additions & 14 deletions test/interactive/compare/compare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ for settings in settings_files
histogram(
log10.(resnorm);
xlabel = "log10(resnorm)", ylabel = "count",
title = "resnorm = $(median(resnorm)) ± $(iqr(resnorm)), log10(resnorm) = $(median(log10.(resnorm))) ± $(iqr(log10.(resnorm)))",
title = "resnorm = $(median(resnorm)) ± $(iqr(resnorm) / 2), log10(resnorm) = $(median(log10.(resnorm))) ± $(iqr(log10.(resnorm)) / 2)",
nbins = 64, vertical = true, height = 10, width = 80,
) |> display
histogram(
cosd.(alpha);
xlabel = "cosd(alpha)", ylabel = "count",
title = "alpha = $(median(alpha)) ± $(iqr(alpha)), cosd(alpha) = $(median(cosd.(alpha))) ± $(iqr(cosd.(alpha)))",
title = "alpha = $(median(alpha)) ± $(iqr(alpha) / 2), cosd(alpha) = $(median(cosd.(alpha))) ± $(iqr(cosd.(alpha)) / 2)",
nbins = 64, vertical = true, height = 10, width = 80,
) |> display
end
Expand All @@ -69,8 +69,8 @@ for settings in settings_files
for file in readdir(outputs[i]; join = false, sort = true)
endswith(file, ".mat") || continue
@assert isfile(joinpath(outputs[i+1], file)) "File $(file) not found in $(outputs[i+1])"
file1 = joinpath(outputs[i], file)
file2 = joinpath(outputs[i+1], file)
global file1 = joinpath(outputs[i], file)
global file2 = joinpath(outputs[i+1], file)

@info "Comparing file: $(file)"
global data1 = matread(file1)
Expand All @@ -86,7 +86,8 @@ for settings in settings_files
I = intersect(I, I2)
end

global x1, x2 = data1[key][I], data2[key][I]
global x1 = data1[key][I]
global x2 = data2[key][I]
if key == "mu"
x1 .= log.(x1)
x2 .= log.(x2)
Expand All @@ -105,16 +106,13 @@ for settings in settings_files
else
@error "$(key => err): relative error is large"
global dx = abs.(x1 .- x2) ./ mean(abs, x1 .- mean(x1))
dx_q99 = quantile(dx, 0.99)
dx_low = 0.0
dx_high = quantile(dx, 0.99)
nz = count(iszero, dx)
n99 = count(>(dx_q99), dx)
dxhist = filter(x -> 0 < x <= dx_q99, dx)
title = "$key: nz = $nz, n99 = $n99, dx_q99 = $dx_q99"
if isempty(dxhist)
@info title
else
histogram(filter(x -> 0 < x <= dx_q99, dx); nbins = 64, vertical = true, height = 10, width = 80, title) |> display
end
n99 = count(>(dx_high), dx)
dxhist = filter(x -> dx_low < x < dx_high, dx)
@info "$key: nz = $nz, n99 = $n99, $dx_low < dx < $dx_high"
!isempty(dxhist) && display(histogram(dxhist; nbins = 64, vertical = true, height = 10, width = 80))
end
end
println()
Expand Down
Loading

0 comments on commit 200abb9

Please sign in to comment.