Merge pull request #206 from flatironinstitute/update_converge
fix convergence tests
BalzaniEdoardo authored Aug 12, 2024
2 parents 90d7ffc + 49d8577 commit c3dbe3d
Showing 2 changed files with 54 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/nemos/regularizer.py
@@ -226,7 +226,7 @@ def get_proximal_operator(
term is not regularized.
"""

- def prox_op(params, l2reg, scaling=0.5):
+ def prox_op(params, l2reg, scaling=1.0):
Ws, bs = params
l2reg /= bs.shape[0]
return jaxopt.prox.prox_ridge(Ws, l2reg, scaling=scaling), bs
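
For reference, a minimal sketch (illustrative, not part of this commit) of the closed form that `jaxopt.prox.prox_ridge` computes, assuming jaxopt's documented convention that `scaling` multiplies the penalty. With the default left at 1.0 the operator reduces to the textbook ridge prox, and `ProximalGradient` supplies the actual step size at call time.

```python
# Illustrative check, assuming jaxopt's convention:
#   prox_ridge(x, l2reg, scaling) solves
#   argmin_y 0.5 * ||y - x||^2 + scaling * (l2reg / 2) * ||y||^2,
# whose closed form is x / (1 + scaling * l2reg).
import jax.numpy as jnp
import jaxopt

x = jnp.array([1.0, -2.0, 3.0])
l2reg, step = 0.5, 0.1  # `step` plays the role of `scaling` inside ProximalGradient

closed_form = x / (1.0 + step * l2reg)
assert jnp.allclose(jaxopt.prox.prox_ridge(x, l2reg, scaling=step), closed_form)
```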
@@ -447,7 +447,7 @@ def _penalization(
) # this masks the param, (group, feature, neuron)

penalty = jax.numpy.sum(
- jax.numpy.linalg.norm(masked_param, axis=1)
+ jax.numpy.linalg.norm(masked_param, axis=1).T
* jax.numpy.sqrt(self.mask.sum(axis=1))
)
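
A small NumPy sketch (illustration only, not nemos code; shapes taken from the comment above, `(group, feature, neuron)`) of why the transpose matters: the per-group `sqrt(group size)` factor has shape `(n_groups,)`, so it has to line up with the group axis when broadcasting.

```python
# Illustration of the group-size broadcasting; masking of out-of-group
# features is omitted for brevity.
import numpy as np

n_groups, n_features, n_neurons = 2, 3, 4
masked_param = np.random.normal(size=(n_groups, n_features, n_neurons))
mask = np.array([[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])  # (n_groups, n_features)

norms = np.linalg.norm(masked_param, axis=1)   # (n_groups, n_neurons)
group_scale = np.sqrt(mask.sum(axis=1))        # (n_groups,)

# norms.T has shape (n_neurons, n_groups), so each group's norms are scaled
# by that group's sqrt-size; without the transpose the (n_groups,) factor
# would be broadcast against the neuron axis (or fail outright).
penalty = np.sum(norms.T * group_scale)

expected = sum(
    np.sqrt(mask[g].sum()) * np.linalg.norm(masked_param[g], axis=0).sum()
    for g in range(n_groups)
)
assert np.allclose(penalty, expected)
```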

75 changes: 52 additions & 23 deletions tests/test_convergence.py
@@ -13,6 +13,8 @@ def test_unregularized_convergence():
Assert that solution found when using GradientDescent vs ProximalGradient with an
unregularized GLM is the same.
"""
jax.config.update("jax_enable_x64", True)

# generate toy data
np.random.seed(111)
# random design tensor. Shape (n_time_points, n_features).
@@ -30,22 +32,26 @@ def test_unregularized_convergence():
y = np.random.poisson(rate)

# instantiate and fit unregularized GLM with GradientDescent
- model_GD = nmo.glm.GLM()
+ model_GD = nmo.glm.GLM(solver_kwargs=dict(tol=10**-12))
model_GD.fit(X, y)

# instantiate and fit unregularized GLM with ProximalGradient
- model_PG = nmo.glm.GLM(solver_name="ProximalGradient")
+ model_PG = nmo.glm.GLM(
+     solver_name="ProximalGradient", solver_kwargs=dict(tol=10**-12)
+ )
model_PG.fit(X, y)

# assert weights are the same
- assert np.allclose(np.round(model_GD.coef_, 2), np.round(model_PG.coef_, 2))
+ assert np.allclose(model_GD.coef_, model_PG.coef_)
+ assert np.allclose(model_GD.intercept_, model_PG.intercept_)


def test_ridge_convergence():
"""
Assert that solution found when using GradientDescent vs ProximalGradient with a
ridge GLM is the same.
"""
jax.config.update("jax_enable_x64", True)
# generate toy data
np.random.seed(111)
# random design tensor. Shape (n_time_points, n_features).
@@ -63,30 +69,40 @@ def test_ridge_convergence():
y = np.random.poisson(rate)

# instantiate and fit ridge GLM with GradientDescent
model_GD = nmo.glm.GLM(regularizer="Ridge")
model_GD = nmo.glm.GLM(regularizer="Ridge", solver_kwargs=dict(tol=10**-12))
model_GD.fit(X, y)

# instantiate and fit ridge GLM with ProximalGradient
model_PG = nmo.glm.GLM(regularizer="Ridge", solver_name="ProximalGradient")
model_PG = nmo.glm.GLM(
regularizer="Ridge",
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
)
model_PG.fit(X, y)

# assert weights are the same
- assert np.allclose(np.round(model_GD.coef_, 2), np.round(model_PG.coef_, 2))
+ assert np.allclose(model_GD.coef_, model_PG.coef_)
+ assert np.allclose(model_GD.intercept_, model_PG.intercept_)


def test_lasso_convergence():
"""
Assert that solution found when using ProximalGradient versus Nelder-Mead method using
lasso GLM is the same.
"""
jax.config.update("jax_enable_x64", True)
# generate toy data
- num_samples, num_features, num_groups = 1000, 5, 3
+ num_samples, num_features, num_groups = 1000, 1, 3
X = np.random.normal(size=(num_samples, num_features)) # design matrix
- w = [0, 0.5, 1, 0, -0.5]  # define some weights
+ w = [0.5]  # define some weights
y = np.random.poisson(np.exp(X.dot(w))) # observed counts

# instantiate and fit GLM with ProximalGradient
model_PG = nmo.glm.GLM(regularizer="Lasso", solver_name="ProximalGradient")
model_PG = nmo.glm.GLM(
regularizer="Lasso",
solver_name="ProximalGradient",
solver_kwargs=dict(tol=10**-12),
)
model_PG.regularizer_strength = 0.1
model_PG.fit(X, y)

@@ -103,31 +119,37 @@ def test_lasso_convergence():
x,
y,
)
- res = minimize(penalized_loss, [0] + w, args=(X, y), method="Nelder-Mead")
+ res = minimize(
+     penalized_loss, [0] + w, args=(X, y), method="Nelder-Mead", tol=10**-12
+ )

- # assert absolute difference between the weights is less than 0.1
- a = np.abs(np.subtract(np.round(res.x[1:], 2), np.round(model_PG.coef_, 2))) < 1e-1
- assert a.all()
+ # assert weights are the same
+ assert np.allclose(res.x[1:], model_PG.coef_)
+ assert np.allclose(res.x[:1], model_PG.intercept_)
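
The `penalized_loss` helper fed to Nelder-Mead is defined just above this hunk and is not shown in the diff. A stand-alone sketch of such an objective for a Poisson GLM with a lasso penalty might look like the following (hypothetical helper, not nemos's internals; the exact normalization nemos uses may differ):

```python
# Hypothetical objective: mean Poisson negative log-likelihood plus an L1
# penalty on the weights (intercept left unpenalized).
import numpy as np
from scipy.optimize import minimize

def penalized_loss_sketch(params, X, y, strength=0.1):
    intercept, weights = params[0], np.asarray(params[1:])
    log_rate = X @ weights + intercept
    nll = np.mean(np.exp(log_rate) - y * log_rate)
    return nll + strength * np.abs(weights).sum()

# e.g. res = minimize(penalized_loss_sketch, x0=np.zeros(X.shape[1] + 1),
#                     args=(X, y), method="Nelder-Mead", tol=1e-12)
```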


def test_group_lasso_convergence():
"""
Assert that solution found when using ProximalGradient versus Nelder-Mead method using
group lasso GLM is the same.
"""
jax.config.update("jax_enable_x64", True)
# generate toy data
- num_samples, num_features, num_groups = 1000, 5, 3
+ num_samples, num_features, num_groups = 1000, 3, 2
X = np.random.normal(size=(num_samples, num_features)) # design matrix
- w = [0, 0.5, 1, 0, -0.5]  # define some weights
+ w = [-0.5, 0.25, 0.5]  # define some weights
y = np.random.poisson(np.exp(X.dot(w))) # observed counts

mask = np.zeros((num_groups, num_features))
- mask[0] = [1, 0, 0, 1, 0]  # Group 0 includes features 0 and 3
- mask[1] = [0, 1, 0, 0, 0]  # Group 1 includes feature 1
- mask[2] = [0, 0, 1, 0, 1]  # Group 2 includes features 2 and 4
+ mask[0] = [1, 1, 0]  # Group 0 includes features 0 and 1
+ mask[1] = [0, 0, 1]  # Group 1 includes feature 2

# instantiate and fit GLM with ProximalGradient
- model_PG = nmo.glm.GLM(regularizer=nmo.regularizer.GroupLasso(mask=mask))
+ model_PG = nmo.glm.GLM(
+     regularizer=nmo.regularizer.GroupLasso(mask=mask),
+     solver_kwargs=dict(tol=10**-14, maxiter=10000),
+     regularizer_strength=0.2,
+ )
model_PG.fit(X, y)

# use the penalized loss function to solve optimization via Nelder-Mead
@@ -144,8 +166,15 @@ def test_group_lasso_convergence():
y,
)

- res = minimize(penalized_loss, [0] + w, args=(X, y), method="Nelder-Mead")
+ res = minimize(
+     penalized_loss,
+     [0] + w,
+     args=(X, y),
+     method="Nelder-Mead",
+     tol=10**-12,
+     options=dict(maxiter=1000),
+ )

- # assert absolute difference between the weights is less than 0.5
- a = np.abs(np.subtract(np.round(res.x[1:], 2), np.round(model_PG.coef_, 2))) < 0.5
- assert a.all()
+ # assert weights are the same
+ assert np.allclose(res.x[1:], model_PG.coef_)
+ assert np.allclose(res.x[:1], model_PG.intercept_)
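
All four tests follow the same recipe: enable float64, tighten the solver tolerance, and then require full `np.allclose` agreement instead of comparing rounded coefficients. A self-contained sketch of that recipe using jaxopt directly (hypothetical data and loss, not the nemos test code):

```python
# Sketch: with float64 and a tight tol, GradientDescent and ProximalGradient
# (with a no-op prox) should agree on an unregularized Poisson GLM.
import jax
import jax.numpy as jnp
import jaxopt
import numpy as np

jax.config.update("jax_enable_x64", True)

np.random.seed(111)
X = np.random.normal(size=(500, 3))
y = np.random.poisson(np.exp(X @ np.array([0.5, -0.3, 1.0]) + 0.2))

def poisson_nll(params, X, y):
    w, b = params
    log_rate = X @ w + b
    return jnp.mean(jnp.exp(log_rate) - y * log_rate)

init = (jnp.zeros(3), jnp.array(0.0))
gd = jaxopt.GradientDescent(fun=poisson_nll, tol=1e-12, maxiter=10_000)
pg = jaxopt.ProximalGradient(fun=poisson_nll, prox=jaxopt.prox.prox_none,
                             tol=1e-12, maxiter=10_000)

params_gd, _ = gd.run(init, X=X, y=y)
params_pg, _ = pg.run(init, hyperparams_prox=None, X=X, y=y)

# Both weight vectors and intercepts should match to allclose precision.
assert np.allclose(params_gd[0], params_pg[0])
assert np.allclose(params_gd[1], params_pg[1])
```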
