Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent f32 overflow #4

Merged
merged 1 commit into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
fail-fast: false

matrix:
version: ['1.5', '1.6', '1.7', 'nightly']
version: ['1.6', '1.7']
os: [ubuntu-latest]

include:
Expand Down
2 changes: 1 addition & 1 deletion src/PARSDMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ for i=1:maxit #main loop
end #end Q-update timer

if i==maxit
println("PARSDMM reached maxit")
constr_log("PARSDMM reached maxit")
(TD_OP,AtA,log_PARSDMM) = output_check_PARSDMM(x,TD_OP,AtA,log_PARSDMM,i,counter)
end

Expand Down
4 changes: 2 additions & 2 deletions src/PARSDMM_initialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ function PARSDMM_initialize(
end
end
if maximum(feasibility_initial)<options.feas_tol #accept input as feasible and return
println("input to PARSDMM is feasible, returning")
constr_log("input to PARSDMM is feasible, returning")
stop = true
end

# if one of the sets is non-convex, use different lambda and rho update frequency, don't update gamma and set a different fixed gamma
for ii=1:pp
if set_Prop.ncvx[ii] == true
println("non-convex set(s) involved, using special settings")
constr_log("non-convex set(s) involved, using special settings")
rho_update_frequency = 3;
adjust_gamma = false
gamma_ini = TF(0.75)
Expand Down
3 changes: 3 additions & 0 deletions src/SetIntersectionProjection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ using TimerOutputs

export log_type_PARSDMM, set_properties, PARSDMM_options, set_definitions

const _verbose = false
constr_log(msg...) = _verbose ? nothing : println(msg...)

#main scripts
include("PARSDMM.jl")
include("PARSDMM_multi_level.jl")
Expand Down
18 changes: 9 additions & 9 deletions src/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ function cg(A::Function,b::Vector{TF}; tol::Real=1e-2,maxIter::Integer=100,M::Fu


if out==2
println("=== cg ===")
println(@sprintf("%4s\t%7s","iter","relres"))
constr_log("=== cg ===")
constr_log(@sprintf("%4s\t%7s","iter","relres"))
end

resvec = zeros(TF,maxIter)
Expand Down Expand Up @@ -99,7 +99,7 @@ function cg(A::Function,b::Vector{TF}; tol::Real=1e-2,maxIter::Integer=100,M::Fu
#resvec[iter] = BLAS.nrm2(n, r, 1) / nr0#
resvec[iter] = norm(r)/nr0
if out==2
println(iter,resvec[iter])
constr_log(iter,resvec[iter])
end
if resvec[iter] <= tol
flag = 0; break
Expand All @@ -116,12 +116,12 @@ function cg(A::Function,b::Vector{TF}; tol::Real=1e-2,maxIter::Integer=100,M::Fu

if out>=0
if flag==-1
println("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
constr_log("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
maxIter,resvec[lastIter],tol)
elseif flag==-2
println("Matrix A in cg has to be positive definite.")
constr_log("Matrix A in cg has to be positive definite.")
elseif flag==0 && out>=1
println("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
constr_log("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
end
end
return x,flag,resvec[lastIter],lastIter,resvec[1:lastIter]
Expand Down Expand Up @@ -194,12 +194,12 @@ end
#
# if out>=0
# if flag==-1
# println("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
# constr_log("cg iterated maxIter (=%d) times but reached only residual norm %1.2e instead of tol=%1.2e.",
# maxIter,resvec[lastIter],tol)
# elseif flag==-2
# println("Matrix A in cg has to be positive definite.")
# constr_log("Matrix A in cg has to be positive definite.")
# elseif flag==0 && out>=1
# println("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
# constr_log("cg achieved desired tolerance at iteration %d. Residual norm is %1.2e.",lastIter,resvec[lastIter])
# end
# end

Expand Down
4 changes: 3 additions & 1 deletion src/default_PARSDMM_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export default_PARSDMM_options
"""
Returns a set of default options for the PARSDMM solver
"""
function default_PARSDMM_options(options,TF)
function default_PARSDMM_options(options,TF; verbose=false)

if TF == Float64
TI = Int64
Expand All @@ -28,5 +28,7 @@ function default_PARSDMM_options(options,TF)
options.parallel = false #comput proximal mappings, multiplier updates, rho and gamma updates in parallel
options.zero_ini_guess = true #zero initial guess for primal, auxilliary, and multipliers
Minkowski = false #the intersection of sets includes a Minkowski set

_verbose = verbose
return options
end
17 changes: 6 additions & 11 deletions src/projectors/project_l1_Duchi!.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function project_l1_Duchi!(v::Union{Vector{TF},Vector{Complex{TF}}}, b::TF) wher
u = similar(v)
sv = Vector{TF}(undef, lv)

#use RadixSort for Float32 (short keywords)
# use RadixSort for Float32 (short keywords)
copyto!(u, v)
u .= abs.(u)
u = convert(Vector{TF},u)
Expand All @@ -35,19 +35,14 @@ function project_l1_Duchi!(v::Union{Vector{TF},Vector{Complex{TF}}}, b::TF) wher
else
u = sort!(u, rev=true, alg=QuickSort)
end


# if TF==Float32
# u = sort!(abs.(u), rev=true, alg=RadixSort)
# else
# u = sort!(abs.(u), rev=true, alg=QuickSort)
# end

cumsum!(sv, u)

# Thresholding level
temp = TF(1.0):TF(1.0):TF(lv)
rho = max(1, min(lv, findlast(u .> ((sv.-b) ./ temp ) ) ))::Int
rho = 0
while u[rho+1] > ((sv[rho+1] - b)/(rho+1)) && (rho+1) < lv
rho += 1
end
rho = max(1, rho)
theta = max.(TF(0) , (sv[rho] .- b) ./ rho)::TF

# Projection as soft thresholding
Expand Down
2 changes: 1 addition & 1 deletion src/setup_multi_level_PARSDMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ for i=2:n_levels
constraint_level = constraint2coarse(constraint_level,comp_grid_levels[i],coarsening_factor)

#set up constraints on new level
println(TF)
constr_log(TF)
(P_sub_l,TD_OP_l,set_Prop_l) = setup_constraints(constraint_level,comp_grid_levels[i],TF)
(TD_OP_l,AtA_l,dummy1,dummy2) = PARSDMM_precompute_distribute(TD_OP_l,set_Prop_l,comp_grid_levels[i],options)

Expand Down
8 changes: 4 additions & 4 deletions src/stop_PARSDMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ function stop_PARSDMM(

#stop if objective value does not change and x is sufficiently feasible for all sets
if i>6 && maximum(log_PARSDMM.set_feasibility[counter-1,:])<feas_tol && maximum(abs.( (log_PARSDMM.obj[i-5:i]-log_PARSDMM.obj[i-1-5:i-1])./log_PARSDMM.obj[i-1-5:i-1] )) < obj_tol
println("stationary objective and reached feasibility, exiting PARSDMM (iteration ",i,")")
constr_log("stationary objective and reached feasibility, exiting PARSDMM (iteration ",i,")")
stop=true;
end

#stop if x doesn't change significantly anyjore
if i>5 && maximum(log_PARSDMM.evol_x[i-5:i])<evol_rel_tol
println("relative evolution to small, exiting PARSDMM (iteration ",i,")")
constr_log("relative evolution to small, exiting PARSDMM (iteration ",i,")")
stop=true;
end

# fix rho to ensure regular ADMM convergence if primal residual does not decrease over a 20 iteration window
if i>20 && adjust_rho==true && log_PARSDMM.r_pri_total[i]>maximum(log_PARSDMM.r_pri_total[(i-1):-1:max((i-50),1)])
println("no primal residual reduction, fixing PARSDMM rho & gamma (iteration ",i,")")
constr_log("no primal residual reduction, fixing PARSDMM rho & gamma (iteration ",i,")")
adjust_rho = false;
adjust_feasibility_rho = false;
adjust_gamma = false;
Expand All @@ -47,7 +47,7 @@ function stop_PARSDMM(

#if rho is fixed and still no decrease in primal residual is observed over a window, we give up
if adjust_rho==false && i>(ind_ref+25) && log_PARSDMM.r_pri_total[i]>maximum(log_PARSDMM.r_pri_total[(i-1):-1:max(ind_ref,max((i-50),1))])
println("no primal residual reduction, exiting PARSDMM (iteration ",i,")")
constr_log("no primal residual reduction, exiting PARSDMM (iteration ",i,")")
stop = true;
end
return stop,adjust_rho,adjust_gamma,adjust_feasibility_rho,ind_ref
Expand Down