Bug fixes and intercept support
PTNobel committed Dec 11, 2024
1 parent 9383f3d · commit 0afa3b9
Showing 3 changed files with 52 additions and 29 deletions.
randalo/adelie_integration.py: 30 changes (21 additions, 9 deletions)
@@ -1,22 +1,28 @@
-import adelie
+import adelie as ad
from dataclasses import dataclass
import linops as lo
+import numpy as np
+import scipy.sparse as sp
import torch

import randalo as ra

class AdelieOperator(lo.LinearOperator):
    supports_operator_matrix = True

-    def __init__(self, X, adjoint=None, shape=None):
+    def __init__(self, X, intercept=False, adjoint=None, shape=None):
+        if intercept:
+            X = ad.matrix.concatenate([X, np.ones(X.shape[0])], axis=1)

        if shape is not None:
            self._shape = shape
        else:
            m, n = X.shape
            self._shape = (m, n)

        self.X = X
-        self._adjoint = adjoint if adjoint is not None else AdelieOperator(X.T, self, (n, m))
+        self._adjoint = adjoint if adjoint is not None else \
+            AdelieOperator(X.T, False, self, (n, m))

    def _matmul_impl(self, v):
        return torch.from_numpy(self.X @ v.numpy())
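
The intercept support in AdelieOperator works by appending a column of ones to the design matrix, so the fitted intercept becomes one extra trailing coefficient and the augmented matrix can be applied like any other linear operator. A minimal NumPy sketch of the underlying algebra (illustration only; the class above uses adelie's matrix type rather than a plain ndarray):

    # [X 1] @ [beta; b0] == X @ beta + b0, so the intercept is just the last coefficient.
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((5, 3))
    beta, b0 = rng.standard_normal(3), 1.5

    X_aug = np.hstack([X, np.ones((X.shape[0], 1))])
    coef_aug = np.concatenate([beta, [b0]])
    assert np.allclose(X_aug @ coef_aug, X @ beta + b0)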
@@ -46,16 +52,22 @@ def adelie_state_to_jacobian(y, state, adelie_state):

    assert p == G, "Group lasso with adelie is not supported."

-    assert not state.intercept
-    ell_1_term = state.alpha * ra.L1Regularizer()
-    ell_2_2_term = (1 - state.alpha) / 2 * ra.SquareRegularizer()
-    reg = adelie_state.ra_lmda * (ell_1_term + ell_2_2_term)
+    if not state.intercept:
+        ell_1_term = state.alpha * ra.L1Regularizer()
+        ell_2_2_term = (1 - state.alpha) / 2 * ra.SquareRegularizer()
+        reg = adelie_state.ra_lmda * (ell_1_term + ell_2_2_term)
+    else:
+        ell_1_term = state.alpha * ra.L1Regularizer(slice(None, -1))
+        ell_2_2_term = (1 - state.alpha) / 2 * ra.SquareRegularizer(slice(None, -1))
+        reg = adelie_state.ra_lmda * (ell_1_term + ell_2_2_term)

    loss = ra.MSELoss()
    J = ra.Jacobian(
        y,
-        AdelieOperator(state.X),
-        lambda: state.betas[adelie_state.index],
+        AdelieOperator(state.X, state.intercept),
+        lambda: sp.hstack((state.betas[adelie_state.index], sp.csr_matrix(np.array([[
+            state.intercepts[adelie_state.index]
+        ]])))),
        loss,
        reg,
        'minres'
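
A fitted intercept should not be penalized. Passing slice(None, -1) to the regularizers restricts the elastic-net penalty to every coefficient except the last one (the intercept column appended above), while the solution callback stacks the fitted intercept onto the sparse coefficient vector so its length matches the augmented operator. A small standalone sketch of the slicing assumption (hypothetical values):

    # slice(None, -1) selects everything but the last entry, so the penalty
    # skips the intercept coefficient stored in the final position.
    import numpy as np

    coef_aug = np.array([0.3, 0.0, -1.2, 4.0])   # last entry plays the role of the intercept
    penalized = coef_aug[slice(None, -1)]
    assert np.array_equal(penalized, np.array([0.3, 0.0, -1.2]))
    l1_penalty = np.abs(penalized).sum()         # the intercept contributes nothing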
randalo/modeling_layer.py: 45 changes (27 additions, 18 deletions)
@@ -30,7 +30,7 @@ def value(self, val):

@dataclass
class Regularizer(ABC):
-    linear: np.ndarray | list[int] = field(default=None)
+    linear: np.ndarray | list[int] | slice = field(default=None)
    scale: float = field(init=False, default=1.0)
    parameter: HyperParameter = field(init=False, default=None)
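
The widened annotation means linear can now be a dense matrix, an explicit index list, or a slice; the latter two select the subset of coefficients the penalty acts on. A hedged sketch of the three forms as they appear to be used elsewhere in this diff (argument passing is illustrative, not a full API reference):

    import numpy as np
    import randalo as ra

    reg_all    = ra.L1Regularizer()                  # linear=None: penalize every coefficient
    reg_subset = ra.L1Regularizer([0, 2, 5])         # list[int]: penalize selected coefficients
    reg_slice  = ra.L1Regularizer(slice(None, -1))   # slice: e.g. skip a trailing intercept
    reg_matrix = ra.L1Regularizer(np.eye(8)[:3])     # ndarray: penalize A @ beta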

@@ -96,8 +96,8 @@ def get_constraint_hessian_mask(self, beta_hat, epsilon=1e-6):
        ie any entry associated with a False in mask is held to always be zeros
        """

-    def get_constraint_hessian_mask_sparse(self, beta_hat, mask, epsilon=1e-6):
-        beta = torch.zeros(mask.shape, device=beta_hat.device, dtype=beta_hat.dtype)
+    def get_constraint_hessian_mask_sparse(self, beta_hat, mask, p, epsilon=1e-6):
+        beta = torch.zeros(p, device=beta_hat.device, dtype=beta_hat.dtype)
        beta[mask] = beta_hat
        return self.get_constraint_hessian_mask(beta, epsilon)
Expand All @@ -115,7 +115,7 @@ def to_cvxpy(self, variable):
    def get_constraint_hessian_mask(self, beta_hat, epsilon=1e-6):
        return self._internal(beta_hat.shape, beta_hat.dtype, beta_hat.device, epsilon)

-    def get_constraint_hessian_mask_sparse(self, beta_hat, mask, epsilon=1e-6):
+    def get_constraint_hessian_mask_sparse(self, beta_hat, mask, p, epsilon=1e-6):
        return self._internal(mask.shape, beta_hat.dtype, beta_hat.device, epsilon)

    def _internal(self, shape, dtype, device, epsilon):
@@ -125,12 +125,12 @@ def _internal(self, shape, dtype, device, epsilon):
        if self.linear is None:
            return None, torch.diag(
                scale * torch.ones(shape, dtype=dtype, device=device)), None
-        elif isinstance(linear, list):
+        elif isinstance(self.linear, (list, slice)):
            diag = torch.zeros(shape, dtype=dtype, device=device)
-            diag[linear] = scale
+            diag[self.linear] = scale
            return None, torch.diag(diag), None
        else:
-            A = utils.to_tensor(linear)
+            A = utils.to_tensor(self.linear)
            return None, torch.diag(scale * (A.mT @ A)), None


@@ -143,24 +143,33 @@ def get_constraint_hessian_mask(self, beta_hat, epsilon=1e-6):
        if self.linear is None:
            mask[torch.abs(beta_hat) <= epsilon] = False
            return None, None, mask
-        elif isinstance(linear, list):
-            mask[linear][torch.abs(beta_hat[linear]) <= epsilon] = False
+        elif isinstance(self.linear, (list, slice)):
+            mask[self.linear][torch.abs(beta_hat[self.linear]) <= epsilon] = False
            return None, None, mask
        else:
-            A = utils.from_numpy(linear)
+            A = utils.from_numpy(self.linear)
            return A[torch.abs(A @ beta_hat) <= epsilon, :], None, None

-    def get_constraint_hessian_mask_sparse(self, beta_hat, mask, epsilon=1e-6):
+    def get_constraint_hessian_mask_sparse(self, beta_hat, mask, p, epsilon=1e-6):
        mask_0 = torch.ones_like(beta_hat, dtype=bool)
        if self.linear is None:
            mask_0[torch.abs(beta_hat) <= epsilon] = False
            return None, None, mask_0
-        elif isinstance(linear, list):
-            idx = torch.cumsum(mask)[linear]
+        elif isinstance(self.linear, slice):
+            if mask.dtype == bool or mask.dtype == torch.bool:
+                raise NotImplementedError()
+                mask_0[mask[self.linear]][torch.abs(beta_hat[idx]) <= epsilon] = False
+            else:
+                start, end, step = self.linear.indices(p)
+                mask_1 = (mask >= start) & (mask < end) & ((mask - start) % step == 0)
+                mask_0[mask_1][torch.abs(beta_hat[mask_1]) <= epsilon] = False
+            return None, None, mask_0
+        elif isinstance(self.linear, (list, slice)):
+            idx = torch.cumsum(mask_0)[self.linear]
            mask_0[idx][torch.abs(beta_hat[idx]) <= epsilon] = False
            return None, None, mask_0
        else:
-            return super().get_constraint_hessian_mask_sparse(beta_hat, mask, epsilon)
+            return super().get_constraint_hessian_mask_sparse(beta_hat, mask, p, epsilon)


class L2Regularizer(Regularizer):
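
For the slice branch above, slice.indices(p) converts the slice into concrete (start, stop, step) bounds so that membership of each stored support index can be tested with comparisons and modular arithmetic instead of materializing the slice. A standalone sketch of that membership test (assumed values):

    import torch

    p = 10
    sl = slice(None, -1, 2)                    # every other coefficient, excluding the last
    start, end, step = sl.indices(p)           # -> (0, 9, 2)

    mask = torch.tensor([0, 3, 4, 9])          # integer support indices of a sparse solution
    in_slice = (mask >= start) & (mask < end) & ((mask - start) % step == 0)
    # in_slice == tensor([ True, False,  True, False])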
@@ -169,7 +178,7 @@ def to_cvxpy(self, variable):

    def get_constraint_hessian_mask(self, beta_hat, epsilon=1e-6):
        linear = self.linear
-        if self.linear is None:
+        if linear is None:
            norm = torch.linalg.norm(beta_hat)
            if norm <= epsilon:
                mask = torch.zeros_like(beta_hat, dtype=bool)
@@ -178,7 +187,7 @@ def get_constraint_hessian_mask(self, beta_hat, epsilon=1e-6):
            tilde_beta_hat_2d = torch.atleast_2d(beta_hat).T / norm
            hessian = self._scale() * (torch.eye(beta_hat.shape) - beta_hat_2d @ beta_hat_2d.T)
            return None, hessian, None
-        elif isinstance(linear, list):
+        elif isinstance(linear, (list, slice)):
            norm = torch.linalg.norm(beta_hat[linear])
            if norm <= epsilon:
                mask = torch.ones_like(beta_hat, dtype=bool)
@@ -244,12 +253,12 @@ def get_constraint_hessian_mask(self, beta_hat, epsilon=1e-6):
        hessians = sum(hessians) if len(hessians) > 0 else None
        return constraints, hessians, mask

-    def get_constraint_hessian_mask_sparse(self, beta_hat, mask_1, epsilon=1e-6):
+    def get_constraint_hessian_mask_sparse(self, beta_hat, mask_1, p, epsilon=1e-6):
        constraints = []
        hessians = []
        mask = torch.ones_like(beta_hat, dtype=bool)
        for reg in self.exprs:
-            cons, hess, m = reg.get_constraint_hessian_mask_sparse(beta_hat, mask_1, epsilon)
+            cons, hess, m = reg.get_constraint_hessian_mask_sparse(beta_hat, mask_1, p, epsilon)
            if cons is not None:
                constraints.append(cons)
            if hess is not None:
randalo/reductions.py: 6 changes (4 additions, 2 deletions)
@@ -42,6 +42,8 @@ def __init__(self, y, X, solution_func, loss, regularizer, inverse_method=None):
        self.y = utils.to_tensor(y)
        self.X = lo.aslinearoperator(X)

+        self._adjoint = self # Not actually symmetric

    @property
    def _shape(self):
        n = self.y.shape[0]
@@ -66,13 +68,13 @@ def _matmul_impl(self, rhs):
        solution = self.solution_func()
        if isinstance(solution, scipy.sparse.csr_matrix):
            beta_hat = utils.to_tensor(solution.data)
-            mask_0 = utils.to_tensor(solution.indices)
+            mask_0 = utils.to_tensor(solution.indices, dtype=torch.int)
            if solution.data.shape == (0,):
                return torch.zeros_like(rhs).squeeze() if needs_squeeze else torch.zeros_like(rhs)
            X = X[:, mask_0]
            constraints, hessians, mask = \
                self.regularizer.get_constraint_hessian_mask_sparse(
-                    beta_hat, mask_0)
+                    beta_hat, mask_0, X.shape[1])
        else:
            beta_hat = utils.to_tensor(solution)

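
On the reductions side, a scipy.sparse CSR solution carries its nonzero values in .data and their column positions in .indices; the operator X is then restricted to those columns before the regularizer call. A minimal sketch of the two attributes being read (toy matrix, not randalo output):

    import numpy as np
    import scipy.sparse as sp

    beta = sp.csr_matrix(np.array([[0.0, 1.5, 0.0, -0.25]]))
    print(beta.data)      # [ 1.5  -0.25]  -> forwarded as beta_hat
    print(beta.indices)   # [1 3]          -> the columns kept when slicing X[:, mask_0]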
