Skip to content

Commit

Permalink
Merge pull request #4 from slimgroup/wl1
Browse files Browse the repository at this point in the history
add weighted l1 to bregman
  • Loading branch information
mloubout authored Jul 15, 2022
2 parents 23b046f + 352470b commit 9d78d7e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
22 changes: 13 additions & 9 deletions src/bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,26 @@ Options structure for the bregman iteration algorithm
- `λfunc`: a function to calculate threshold value, default is nothing
- `λ`: a pre-set threshold, will only be used if `λfunc` is not defined, default is nothing
- `quantile`: a percentage to calculate the threshold by quantile of the dual variable in 1st iteration, will only be used if neither `λfunc` nor `λ` are defined, default is .95 i.e thresholds 95% of the vector
- `w`: a weight vector that is applied on the threshold element-wise according to relaxation of weighted l1, default is 1 (no weighting)
"""
function bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing)
function bregman_options(;verbose=1, progTol=1e-8, maxIter=20, store_trace=false, antichatter=true, alpha=.5, spg=false, TD=LinearAlgebra.I, quantile=.95, λ=nothing, λfunc=nothing, w=1)
if isnothing(λfunc)
if ~isnothing(λ)
λfunc = z->λ
else
λfunc = z->Statistics.quantile(abs.(z), quantile)
end
end
return BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, alpha, spg, TD, λfunc)
return BregmanParams(verbose, progTol, maxIter, store_trace, antichatter, alpha, spg, TD, z->w.*λfunc(z))
end

"""
bregman(A, x, b, options)
Linearized bregman iteration for the system
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b``
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_{1,w} \\ \\ \\ s.t Ax = b``
For example, for sparsity promoting denoising (i.e LSRTM)
Expand All @@ -61,11 +62,13 @@ For example, for sparsity promoting denoising (i.e LSRTM)
- `b`: observed data
# Optional Arguments
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet), options.w provides the weight vector for the weighted l1
"""
function bregman(A, x::AbstractVector{T1}, b::AbstractVector{T2}, options::BregmanParams=bregman_options(); callback=noop_callback) where {T1<:Number, T2<:Number}
# residual function wrapper
Expand All @@ -89,20 +92,21 @@ end
Linearized bregman iteration for the system
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_1 \\ \\ \\ s.t Ax = b``
``\\frac{1}{2} ||TD \\ x||_2^2 + λ ||TD \\ x||_{1,w} \\ \\ \\ s.t Ax = b``
# Required arguments
- `funobj`: a function that calculates the objective value (`0.5 * norm(Ax-b)^2`) and the gradient (`A'(Ax-b)`)
- `x`: Initial guess
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet)
# Optional Arguments
- `callback` : Callback function. Must take as input a `result` callback(x::result)
# Non-required arguments
- `options`: bregman options, default is bregman_options(); options.TD provides the sparsifying transform (e.g. curvelet), options.w provides the weight vector for the weighted l1
"""
function bregman(funobj::Function, x::AbstractVector{T}, options::BregmanParams=bregman_options(); callback=noop_callback) where {T}
# Output Parameter Settings
Expand Down
6 changes: 2 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,5 @@ function subHv(p::AbstractArray{T}, x::AbstractArray{T}, g::AbstractArray{T}, Hv
end

# THresholding
soft_thresholding(x::AbstractArray{Complex{T}}, λ::T) where {T} = exp.(angle.(x)im) .* max.(abs.(x) .- convert(T, λ), T(0))
soft_thresholding(x::AbstractArray{Complex{T}}, λ::Array{T}) where {T} = exp.(angle.(x)im) .* max.(abs.(x) .- convert(Array{T}, λ), T(0))
soft_thresholding(x::AbstractArray{T}, λ::T) where {T} = sign.(x) .* max.(abs.(x) .- convert(T, λ), T(0))
soft_thresholding(x::AbstractArray{T}, λ::Array{T}) where {T} = sign.(x) .* max.(abs.(x) .- convert(Array{T}, λ), T(0))
soft_thresholding(x::AbstractArray{Complex{T}}, λ::Union{T, Array{T}}) where {T} = exp.(angle.(x)im) .* max.(abs.(x) .- λ, T(0))
soft_thresholding(x::AbstractArray{T}, λ::Union{Array{T}, T}) where {T} = sign.(x) .* max.(abs.(x) .- λ, T(0))
4 changes: 2 additions & 2 deletions test/test_bregman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearAlgebra
N1 = 100
N2 = div(N1, 2) + 5

@testset "Bregman test for type $(T)" for T = [Float32, ComplexF32]
@testset "Bregman test for type $(T) and weighted $(weighted)" for (T, weighted) in [(Float32, true), (ComplexF32, true), (Float32, false), (ComplexF32, false)]

A = randn(T, N1, N2)
x0 = 10 .* randn(T, N2)
Expand All @@ -22,7 +22,7 @@ N2 = div(N1, 2) + 5
return fun, grad
end

opt = bregman_options(maxIter=200, progTol=0, verbose=2, antichatter=T==Float32)
opt = weighted ? bregman_options(maxIter=200, progTol=0, verbose=2, antichatter=T==Float32, w=Float32.(x0.==0)) : bregman_options(maxIter=200, progTol=0, verbose=2, antichatter=T==Float32)
sol = bregman(obj, 1 .+ randn(T, N2), opt)

@show sol.x[inds]
Expand Down

0 comments on commit 9d78d7e

Please sign in to comment.