Skip to content

Commit

Permalink
Merge pull request #921 from mblondel:projection_doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623416281
  • Loading branch information
OptaxDev committed Apr 10, 2024
2 parents addb322 + 4e089dc commit 437d79c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
48 changes: 48 additions & 0 deletions docs/api/projections.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Projections
===========

.. currentmodule:: optax.projections

Projections can be used to perform constrained optimization.
The Euclidean projection onto a set :math:`\mathcal{C}` is:

.. math::
\text{proj}_{\mathcal{C}}(u) :=
\underset{v}{\text{argmin}} ~ ||u - v||^2_2 \textrm{ subject to } v \in \mathcal{C}.
For instance, here is an example how we can project parameters to the non-negative orthant::

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> num_weights = 2
>>> xs = jnp.array([[-1.8, 2.2], [-2.0, 1.2]])
>>> ys = jnp.array([0.5, 0.8])
>>> optimizer = optax.adam(learning_rate=1e-3)
>>> params = {'w': jnp.zeros(num_weights)}
>>> opt_state = optimizer.init(params)
>>> loss = lambda params, x, y: jnp.mean((params['w'].dot(x) - y) ** 2)
>>> grads = jax.grad(loss)(params, xs, ys)
>>> updates, opt_state = optimizer.update(grads, opt_state)
>>> params = optax.apply_updates(params, updates)
>>> params = optax.projections.projection_non_negative(params)

Available projections
~~~~~~~~~~~~~~~~~~~~~
.. autosummary::
projection_box
projection_hypercube
projection_non_negative

Projection onto a box
~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_box

Projection onto a hypercube
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_hypercube

Projection onto the non-negative orthant
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: projection_non_negative
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ for instructions on installing JAX.
api/optimizer_wrappers
api/optimizer_schedules
api/apply_updates
api/projections
api/losses
api/control_variates
api/stochastic_gradient_estimators
Expand Down

0 comments on commit 437d79c

Please sign in to comment.