Skip to content

Commit

Permalink
Fixed prox to handle pytrees
Browse files Browse the repository at this point in the history
Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters

Added tests
  • Loading branch information
vroulet committed Jun 9, 2023
1 parent 674a992 commit e593c89
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 7 deletions.
22 changes: 15 additions & 7 deletions jaxopt/_src/prox.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ def prox_lasso(x: Any,
if l1reg is None:
l1reg = 1.0

fun = lambda u, v: jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)
return tree_util.tree_map(fun, x, l1reg)
if type(l1reg) == float:
l1reg = tree_util.tree_map(lambda y: l1reg*jnp.ones_like(y), x)

def fun(u, v): return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)
return tree_util.tree_map(fun, x, l1reg)

def prox_non_negative_lasso(x: Any,
l1reg: Optional[float] = None,
Expand All @@ -95,7 +97,7 @@ def prox_non_negative_lasso(x: Any,
if l1reg is None:
l1reg = 1.0

pytree = tree_util.tree_add(x, -l1reg * scaling)
pytree = tree_util.tree_map(lambda y: y - l1reg*scaling, x)
return tree_util.tree_map(jax.nn.relu, pytree)


Expand Down Expand Up @@ -123,10 +125,16 @@ def prox_elastic_net(x: Any,
if hyperparams is None:
hyperparams = (1.0, 1.0)

prox_l1 = lambda u, lam: jnp.sign(u) * jax.nn.relu(jnp.abs(u) - lam)
fun = lambda u, lam, gamma: (prox_l1(u, scaling * lam) /
(1.0 + scaling * lam * gamma))
return tree_util.tree_map(fun, x, hyperparams[0], hyperparams[1])
lam = tree_util.tree_map(lambda y: hyperparams[0]*jnp.ones_like(
y), x) if type(hyperparams[0]) == float else hyperparams[0]
gam = tree_util.tree_map(lambda y: hyperparams[1]*jnp.ones_like(
y), x) if type(hyperparams[1]) == float else hyperparams[1]

def prox_l1(u, lambd): return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - lambd)

def fun(u, lambd, gamma): return (prox_l1(u, scaling * lambd) /
(1.0 + scaling * lambd * gamma))
return tree_util.tree_map(fun, x, lam, gam)


def prox_group_lasso(x: Any,
Expand Down
31 changes: 31 additions & 0 deletions tests/prox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp

from jaxopt import projection
from jaxopt import prox
from jaxopt import tree_util
from jaxopt._src import test_util

import numpy as onp
Expand Down Expand Up @@ -190,5 +192,34 @@ def test_make_prox_from_projection(self):
proxop = prox.make_prox_from_projection(projection.projection_simplex)
self.assertArraysAllClose(proxop(x), projection.projection_simplex(x))

@parameterized.product(
prox_op=[
prox.prox_none,
prox.prox_lasso,
prox.prox_non_negative_lasso,
prox.prox_ridge,
prox.prox_non_negative_ridge,
prox.prox_elastic_net,
]
)
def test_pytree_comptability(self, prox_op):
rng = onp.random.RandomState(0)
x = dict(a=rng.rand(16, 16), b=rng.rand(16))
got = prox_op(x)
expected = tree_util.tree_map(prox_op, x)
self.assertAllClose(got, expected)
if prox_op is prox.prox_lasso:
l1_reg = tree_util.tree_ones_like(x)
got = prox_op(x, l1_reg)
expected = tree_util.tree_map(prox_op, x, l1_reg)
self.assertAllClose(got, expected)
if prox_op is prox.prox_elastic_net:
hyperparams = [tree_util.tree_ones_like(x), tree_util.tree_ones_like(x)]
got = prox_op(x, hyperparams)
hyperparams_tree = tree_util.tree_map(
lambda y: [jnp.ones_like(y), jnp.ones_like(y)], x)
expected = tree_util.tree_map(prox_op, x, hyperparams_tree)
self.assertAllClose(got, expected)

if __name__ == '__main__':
absltest.main()

0 comments on commit e593c89

Please sign in to comment.