Skip to content

Commit

Permalink
[ADD] svd solver
Browse files Browse the repository at this point in the history
  • Loading branch information
mathdugre committed Oct 21, 2023
1 parent 5194d6f commit bebf23b
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions dask_glm/algorithms.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -504,10 +504,43 @@ def proximal_grad(
return beta return beta




@normalize
def svd(X, y, alpha=1., **kwargs):
    """Closed-form (ridge) least-squares solve via singular value decomposition.

    Solves ``min ||X beta - y||^2 + alpha * ||beta||^2`` using the SVD of
    ``X``, following scikit-learn's ``_solve_svd`` for ``Ridge``.

    Parameters
    ----------
    X : dask array, shape (n_samples, n_features)
        Design matrix.
    y : dask array, shape (n_samples,) or (n_samples, n_targets)
        Target values; a 1-d ``y`` is treated as a single target.
    alpha : float or array-like, optional
        L2 penalty — either one scalar shared by all targets, or one value
        per target.
    **kwargs
        Ignored; accepted so the solver matches the common solver signature.

    Returns
    -------
    numpy.ndarray
        Coefficients with shape (n_features,) when ``y`` was 1-d, otherwise
        (n_targets, n_features).

    Raises
    ------
    ValueError
        If ``alpha`` has neither 1 nor ``n_targets`` entries.
    """
    # This SVD algorithm expects y to be a 2d-array.
    ravel = False
    if y.ndim == 1:
        y = y.reshape(-1, 1)
        ravel = True

    # There should be either 1 or n_targets penalties
    _, n_targets = y.shape
    alpha = da.asarray(alpha, dtype=X.dtype).ravel()
    if alpha.size not in (1, n_targets):
        raise ValueError(
            "Number of targets and number of penalties do not correspond: %d != %d"
            % (alpha.size, n_targets)
        )

    # REF: https://github.com/scikit-learn/scikit-learn/commit/814ad9ba14ea1f53da353368d34daf89061cf92e#diff-e7a1a6bac747c5273cc1e858fb418e7a86ceb453dc8a9ba92839da3d69932ba9
    U, s, Vt = da.linalg.svd(X)
    # Singular values at or below 1e-15 (same default cutoff as
    # scipy.linalg.pinv) are treated as exactly zero so a rank-deficient X
    # is handled gracefully.
    nonzero = (s > 1e-15)[:, np.newaxis]
    UTy = dot(U.T, y)
    # d = s / (s**2 + alpha) on the retained singular values, 0 elsewhere.
    # Masking with da.where (rather than assigning s[idx] = 0 and dividing)
    # avoids 0/0 -> NaN when alpha == 0, and avoids dask item assignment,
    # which older dask versions do not support. Broadcasting: (k, 1) against
    # alpha of size 1 or (n_targets,) yields d of shape (k, n_targets).
    d = da.where(
        nonzero,
        s[:, np.newaxis] / (s[:, np.newaxis] ** 2 + alpha),
        0.,
    )
    d_UT_y = d * UTy
    coef = dot(Vt.T, d_UT_y).T

    if ravel:
        # When y was passed as a 1d-array, we flatten the coefficients.
        coef = coef.ravel()
    return coef.compute()


# Registry mapping solver names (as accepted by the public API) to their
# implementations. (Reconstructed: the side-by-side diff rendering had
# duplicated every line, which is not valid Python.)
_solvers = {
    "admm": admm,
    "gradient_descent": gradient_descent,
    "newton": newton,
    "lbfgs": lbfgs,
    "proximal_grad": proximal_grad,
    "svd": svd,
}

0 comments on commit bebf23b

Please sign in to comment.