Finishes Adelie integration

PTNobel committed Nov 20, 2024
1 parent 3a24e22 commit 4281e2c

Showing 3 changed files with 54 additions and 17 deletions.
4 changes: 2 additions & 2 deletions examples/adelie_example.py
@@ -37,8 +37,8 @@
 import randalo.adelie_integration as ai
 import torch
 
-alo = ai.get_alo_for_sweep(y, state, torch.nn.MSELoss())
+ld, alo = ai.get_alo_for_sweep(y, state, torch.nn.MSELoss())
 dg = ad.diagnostic.diagnostic(state)
 dg.plot_devs()
-plt.plot(state.lmda, alo)
+plt.plot(-np.log(ld), alo)
 plt.show()
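The example now plots against the lambda path returned by the sweep instead of reaching into state.lmda. For context, a minimal end-to-end sketch of how the updated example fits together; the synthetic data and the ad.grpnet / ad.glm.gaussian calls are assumptions about adelie's Gaussian lasso API, not part of this commit:

import adelie as ad
import matplotlib.pyplot as plt
import numpy as np
import torch

import randalo.adelie_integration as ai

# Hypothetical data: any (n, p) design and response will do.
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 20))
y = X @ rng.standard_normal(20) + 0.5 * rng.standard_normal(100)

# Assumed adelie call: fit a Gaussian lasso path.
state = ad.grpnet(X=np.asfortranarray(X), glm=ad.glm.gaussian(y=y))

# New two-value return: the swept lambda path plus the ALO risk estimates.
ld, alo = ai.get_alo_for_sweep(y, state, torch.nn.MSELoss())
plt.plot(-np.log(ld), alo)
plt.xlabel("-log(lambda)")
plt.ylabel("ALO risk estimate")
plt.show()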
43 changes: 32 additions & 11 deletions randalo/adelie_integration.py
@@ -1,7 +1,30 @@
 import adelie
-import randalo as ra
-import numpy as np
+from dataclasses import dataclass
+import linops as lo
+import numpy as np
+import torch
 
+import randalo as ra
+
+class AdelieOperator(lo.LinearOperator):
+    supports_operator_matrix = True
+
+    def __init__(self, X, adjoint=None, shape=None):
+        if shape is not None:
+            self._shape = shape
+        else:
+            m, n = X.shape
+            self._shape = (m, n)
+        self.X = X
+        self._adjoint = adjoint if adjoint is not None else AdelieOperator(X.T, self, (n, m))
+
+    def _matmul_impl(self, v):
+        return torch.from_numpy(self.X @ v.numpy())
+
+    def __getitem__(self, key):
+        if isinstance(key, tuple):
+            key = tuple(k.numpy() if isinstance(k, torch.Tensor) else k for k in key)
+        return AdelieOperator(self.X[key])
+
 def curry(f, *args0, **kwargs0):
     return lambda *args, **kwargs: f(*args0, *args, **kwargs0, **kwargs)
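AdelieOperator adapts adelie's matrix-free design matrix to the linops interface randalo consumes: forward products go through adelie's native @ and come back as torch tensors, and the adjoint is constructed eagerly from X.T so that X.T @ X products never densify X. A small behavioral sketch, assuming a NumPy array stands in for the adelie matrix (only @, .T, and .shape are used) and that linops dispatches @ through _matmul_impl and exposes the adjoint as .T:

import numpy as np
import torch

from randalo.adelie_integration import AdelieOperator

A = np.random.randn(5, 3)
op = AdelieOperator(A)

v = torch.ones(3, dtype=torch.float64)
u = op @ v    # forward product, returned as a torch.Tensor of shape (5,)
w = op.T @ u  # adjoint product via the prebuilt AdelieOperator(A.T)

# __getitem__ converts torch index tensors to NumPy before slicing, so index
# tensors produced by randalo's masking code work transparently.
cols = op[:, torch.tensor([0, 2])]  # a new 5 x 2 AdelieOperator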
@@ -29,21 +52,19 @@ def adelie_state_to_jacobian(y, state, adelie_state):
     reg = adelie_state.ra_lmda * (ell_1_term + ell_2_2_term)
 
     loss = ra.MSELoss()
-    breakpoint()
     J = ra.Jacobian(
         y,
-        state.X,
-        lambda: (
-            betas[adelie_state.index], # What is the type of this?
-            screen_set[active_set[active_sizes[:adelie_state.index]]]),
+        AdelieOperator(state.X),
+        lambda: state.betas[adelie_state.index],
         loss,
         reg,
+        'minres'
     )
 
     return loss, J
 
-def adelie_state_to_randalo(y, state, adelie_state, loss, J, index, rng):
-    y_hat = state.X @ state.beta[index]
+def adelie_state_to_randalo(y, state, adelie_state, loss, J, index, rng=None):
+    y_hat = (state.X @ state.betas[index].T).squeeze()
     adelie_state.set_index(index)
     randalo = ra.RandALO(
         loss,
@@ -55,7 +76,7 @@ def adelie_state_to_randalo(y, state, adelie_state, loss, J, index, rng):
     return randalo
 
 def get_alo_for_sweep(y, state, risk_fun):
-    L, = state.lmda_path.shape
+    L, _ = state.betas.shape
     adelie_state = AdelieState(state)
     loss, J = adelie_state_to_jacobian(y, state, adelie_state)
 
@@ -65,5 +86,5 @@ def get_alo_for_sweep(y, state, risk_fun):
         randalo = adelie_state_to_randalo(y, state, adelie_state, loss, J, i)
         output[i] = randalo.evaluate(risk_fun)
 
-    return output
+    return state.lmda_path[:L], output
 
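The sweep now sizes itself from state.betas and returns the matching slice of the lambda path. Along the path adelie stores coefficients as a scipy.sparse CSR matrix with one row per lambda, which is why solution_func hands a csr_matrix to the Jacobian (handled in reductions.py below). A toy sketch of that storage convention, with made-up numbers:

import numpy as np
import scipy.sparse

# 3 lambdas along the path, 4 features; row k is the solution at lambda_k.
betas = scipy.sparse.csr_matrix(np.array([
    [0.0, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.0, 0.0],
    [0.7, 0.0, -0.2, 0.0],
]))

row = betas[2]      # still a 1 x 4 csr_matrix
print(row.data)     # nonzero coefficients: [ 0.7 -0.2]
print(row.indices)  # their column indices, i.e. the active set: [0 2]

X = np.random.randn(6, 4)
y_hat = X @ row.toarray().ravel()  # fitted values at lambda_2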
24 changes: 20 additions & 4 deletions randalo/reductions.py
@@ -2,8 +2,10 @@
 from typing import Callable, Literal
 from dataclasses import dataclass, field
 
+import scipy
 import numpy as np
 import linops as lo
+from linops.minres import minres
 import torch
 
 from . import modeling_layer as ml
@@ -38,7 +40,7 @@ def __init__(self, y, X, solution_func, loss, regularizer, inverse_method=None):
         self.regularizer = regularizer
         self.inverse_method = inverse_method
         self.y = utils.to_tensor(y)
-        self.X = utils.to_tensor(X)
+        self.X = lo.aslinearoperator(X)
 
     @property
     def _shape(self):
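Swapping utils.to_tensor for lo.aslinearoperator lets the Jacobian accept either a concrete array or a matrix-free operator such as AdelieOperator. A hedged sketch of the contract assumed here (that linops' aslinearoperator, like scipy's helper of the same name, wraps arrays and passes existing operators through):

import numpy as np
import linops as lo

A = np.random.randn(4, 3)
op = lo.aslinearoperator(A)    # ndarray -> LinearOperator wrapper
op2 = lo.aslinearoperator(op)  # an existing operator should pass through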
@@ -62,8 +64,11 @@ def _matmul_impl(self, rhs):
             needs_squeeze = True
 
         solution = self.solution_func()
-        if isinstance(solution, tuple):
-            beta_hat, mask_0 = solution
+        if isinstance(solution, scipy.sparse.csr_matrix):
+            beta_hat = utils.to_tensor(solution.data)
+            mask_0 = utils.to_tensor(solution.indices)
+            if solution.data.shape == (0,):
+                return torch.zeros_like(rhs).squeeze() if needs_squeeze else torch.zeros_like(rhs)
         X = X[:, mask_0]
         constraints, hessians, mask = \
             self.regularizer.get_constraint_hessian_mask_sparse(
@@ -85,7 +90,18 @@ def _matmul_impl(self, rhs):

         rhs_scaled = -d2loss_dboth[:, None] * rhs
 
-        if constraints is None and hessians is None:
+        # TODO: Split cholesky/minres code paths into separate ones
+        if self.inverse_method == 'minres':
+            if constraints is None and hessians is None:
+                sqrt_d2loss_dy_hat2 = torch.sqrt(d2loss_dy_hat2)[:, None]
+                tilde_X = sqrt_d2loss_dy_hat2 * X_mask
+                return ((
+                    X @ minres(X.T @ X, (X.T @ (rhs_scaled / sqrt_d2loss_dy_hat2)))
+                ) / sqrt_d2loss_dy_hat2).to(rhs.dtype)
+            else:
+                raise NotImplementedError()
+
+        elif constraints is None and hessians is None:
             sqrt_d2loss_dy_hat2 = torch.sqrt(d2loss_dy_hat2)[:, None]
             tilde_X = sqrt_d2loss_dy_hat2 * X_mask
             Q, _ = torch.linalg.qr(tilde_X)
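In the unconstrained, no-Hessian case, applying the Jacobian comes down to a least-squares solve against X. The existing path below builds an explicit QR factorization, which needs X as a dense matrix; the new 'minres' path instead solves the normal equations iteratively, touching X only through matvecs, so it composes with the matrix-free AdelieOperator. A stripped-down sketch of that step (the function name is illustrative, not randalo's):

import torch
import linops as lo
from linops.minres import minres

def apply_projection(X: lo.LinearOperator, r: torch.Tensor) -> torch.Tensor:
    # Apply X (X^T X)^{-1} X^T to r using only products with X and X.T;
    # minres requires the system operator X.T @ X to be symmetric, which it is.
    w = minres(X.T @ X, X.T @ r)
    return X @ w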
