Skip to content

Commit

Permalink
snlls: Fix violation of boundaries due to float-point round-off err…
Browse files Browse the repository at this point in the history
…ors (#188)

* Jacobian: add rounding to avoid float-point round-off errors

* Jacobian: use exception handling to avoid rounding in all cases

* snlls: shield linear LSQ results against float-point errors
  • Loading branch information
luisfabib authored and Luis Fabregas committed Jul 2, 2021
1 parent ccc961d commit a1ec280
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions deerlab/snlls.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def snlls(y, Amodel, par0, lb=None, ub=None, lbl=None, ubl=None, nnlsSolver='cvx
elif nnlsSolver == 'cvx':
linSolver = lambda AtA, Aty: cvxnnls(AtA, Aty, tol=lin_tol, maxiter=lin_maxiter)
parseResult = lambda result: result

# Ensure correct formatting and shield against float-point errors
validateResult = lambda result: np.maximum(lbl,np.minimum(ubl,np.atleast_1d(result)))
# ----------------------------------------------------------

# Containers for alpha-update checks
Expand All @@ -319,9 +322,9 @@ def linear_problem(A,optimize_alpha,alpha):

# Solve the linear least-squares problem
result = linSolver(AtA, Aty)
linfit = parseResult(result)
linfit = np.atleast_1d(linfit)
result = parseResult(result)
linfit = validateResult(result)

return linfit, alpha
#===========================================================================

Expand Down Expand Up @@ -566,4 +569,3 @@ def _plot(subsets,y,yfit,show):
plt.close()
return fig
# ===========================================================================================

2 changes: 1 addition & 1 deletion deerlab/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def Jacobian(fcn, x0, lb, ub):
"""
J = opt._numdiff.approx_derivative(fcn,x0,method='2-point',bounds=(lb,ub))
J = np.atleast_2d(J)
J = np.atleast_2d(J)
return J
#===============================================================================

Expand Down

0 comments on commit a1ec280

Please sign in to comment.