Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix convergence tests #206

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/nemos/regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def get_proximal_operator(
term is not regularized.
"""

def prox_op(params, l2reg, scaling=0.5):
def prox_op(params, l2reg, scaling=1.):
Ws, bs = params
l2reg /= bs.shape[0]
return jaxopt.prox.prox_ridge(Ws, l2reg, scaling=scaling), bs
Expand Down Expand Up @@ -447,7 +447,7 @@ def _penalization(
) # this masks the param, (group, feature, neuron)

penalty = jax.numpy.sum(
jax.numpy.linalg.norm(masked_param, axis=1)
jax.numpy.linalg.norm(masked_param, axis=1).T
* jax.numpy.sqrt(self.mask.sum(axis=1))
)

Expand Down
75 changes: 52 additions & 23 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def test_unregularized_convergence():
Assert that solution found when using GradientDescent vs ProximalGradient with an
unregularized GLM is the same.
"""
jax.config.update("jax_enable_x64", True)

# generate toy data
np.random.seed(111)
# random design tensor. Shape (n_time_points, n_features).
Expand All @@ -30,22 +32,26 @@ def test_unregularized_convergence():
y = np.random.poisson(rate)

# instantiate and fit unregularized GLM with GradientDescent
model_GD = nmo.glm.GLM()
model_GD = nmo.glm.GLM(solver_kwargs=dict(tol=10**-12))
model_GD.fit(X, y)

# instantiate and fit unregularized GLM with ProximalGradient
model_PG = nmo.glm.GLM(solver_name="ProximalGradient")
model_PG = nmo.glm.GLM(
solver_name="ProximalGradient", solver_kwargs=dict(tol=10**-12)
)
model_PG.fit(X, y)

# assert weights are the same
assert np.allclose(np.round(model_GD.coef_, 2), np.round(model_PG.coef_, 2))
assert np.allclose(model_GD.coef_, model_PG.coef_)
assert np.allclose(model_GD.intercept_, model_PG.intercept_)


def test_ridge_convergence():
"""
Assert that solution found when using GradientDescent vs ProximalGradient with an
ridge GLM is the same.
"""
jax.config.update("jax_enable_x64", True)
# generate toy data
np.random.seed(111)
# random design tensor. Shape (n_time_points, n_features).
Expand All @@ -63,30 +69,40 @@ def test_ridge_convergence():
y = np.random.poisson(rate)

# instantiate and fit ridge GLM with GradientDescent
model_GD = nmo.glm.GLM(regularizer="Ridge")
model_GD = nmo.glm.GLM(regularizer="Ridge", solver_kwargs=dict(tol=10**-12))
model_GD.fit(X, y)

# instantiate and fit ridge GLM with ProximalGradient
model_PG = nmo.glm.GLM(regularizer="Ridge", solver_name="ProximalGradient")
model_PG = nmo.glm.GLM(
regularizer="Ridge",
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
)
model_PG.fit(X, y)

# assert weights are the same
assert np.allclose(np.round(model_GD.coef_, 2), np.round(model_PG.coef_, 2))
assert np.allclose(model_GD.coef_, model_PG.coef_)
assert np.allclose(model_GD.intercept_, model_PG.intercept_)


def test_lasso_convergence():
"""
Assert that solution found when using ProximalGradient versus Nelder-Mead method using
lasso GLM is the same.
"""
jax.config.update("jax_enable_x64", True)
# generate toy data
num_samples, num_features, num_groups = 1000, 5, 3
num_samples, num_features, num_groups = 1000, 1, 3
X = np.random.normal(size=(num_samples, num_features)) # design matrix
w = [0, 0.5, 1, 0, -0.5] # define some weights
w = [0.5] # define some weights
y = np.random.poisson(np.exp(X.dot(w))) # observed counts

# instantiate and fit GLM with ProximalGradient
model_PG = nmo.glm.GLM(regularizer="Lasso", solver_name="ProximalGradient")
model_PG = nmo.glm.GLM(
regularizer="Lasso",
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
)
model_PG.regularizer_strength = 0.1
model_PG.fit(X, y)

Expand All @@ -103,31 +119,37 @@ def test_lasso_convergence():
x,
y,
)
res = minimize(penalized_loss, [0] + w, args=(X, y), method="Nelder-Mead")
res = minimize(
penalized_loss, [0] + w, args=(X, y), method="Nelder-Mead", tol=10**-12
)

# assert absolute difference between the weights is less than 0.1
a = np.abs(np.subtract(np.round(res.x[1:], 2), np.round(model_PG.coef_, 2))) < 1e-1
assert a.all()
# assert weights are the same
assert np.allclose(res.x[1:], model_PG.coef_)
assert np.allclose(res.x[:1], model_PG.intercept_)


def test_group_lasso_convergence():
"""
Assert that solution found when using ProximalGradient versus Nelder-Mead method using
group lasso GLM is the same.
"""
jax.config.update("jax_enable_x64", True)
# generate toy data
num_samples, num_features, num_groups = 1000, 5, 3
num_samples, num_features, num_groups = 1000, 3, 2
X = np.random.normal(size=(num_samples, num_features)) # design matrix
w = [0, 0.5, 1, 0, -0.5] # define some weights
w = [-0.5, 0.25, 0.5] # define some weights
y = np.random.poisson(np.exp(X.dot(w))) # observed counts

mask = np.zeros((num_groups, num_features))
mask[0] = [1, 0, 0, 1, 0] # Group 0 includes features 0 and 3
mask[1] = [0, 1, 0, 0, 0] # Group 1 includes features 1
mask[2] = [0, 0, 1, 0, 1] # Group 2 includes features 2 and 4
mask[0] = [1, 1, 0] # Group 0 includes features 0 and 1
mask[1] = [0, 0, 1] # Group 1 includes features 1

# instantiate and fit GLM with ProximalGradient
model_PG = nmo.glm.GLM(regularizer=nmo.regularizer.GroupLasso(mask=mask))
model_PG = nmo.glm.GLM(
regularizer=nmo.regularizer.GroupLasso(mask=mask),
solver_kwargs=dict(tol=10**-14, maxiter=10000),
regularizer_strength=0.2,
)
model_PG.fit(X, y)

# use the penalized loss function to solve optimization via Nelder-Mead
Expand All @@ -144,8 +166,15 @@ def test_group_lasso_convergence():
y,
)

res = minimize(penalized_loss, [0] + w, args=(X, y), method="Nelder-Mead")
res = minimize(
penalized_loss,
[0] + w,
args=(X, y),
method="Nelder-Mead",
tol=10**-12,
options=dict(maxiter=1000),
)

# assert absolute difference between the weights is less than 0.5
a = np.abs(np.subtract(np.round(res.x[1:], 2), np.round(model_PG.coef_, 2))) < 0.5
assert a.all()
# assert weights are the same
assert np.allclose(res.x[1:], model_PG.coef_)
assert np.allclose(res.x[:1], model_PG.intercept_)
Loading