Commit 68861b8

No public description

PiperOrigin-RevId: 576947921
fabianp authored and JAXopt authors committed Oct 26, 2023
1 parent 7afa914 commit 68861b8
Showing 2 changed files with 105 additions and 35 deletions.
94 changes: 69 additions & 25 deletions jaxopt/_src/polyak_sgd.py
@@ -14,14 +14,12 @@

"""SGD solver with Polyak step size."""

import dataclasses
from typing import Any
from typing import Callable
from typing import NamedTuple
from typing import Optional

import jax
import jax.numpy as jnp

from jaxopt._src import base
@@ -49,23 +47,50 @@ class PolyakSGDState(NamedTuple):

@dataclasses.dataclass(eq=False)
class PolyakSGD(base.StochasticSolver):
"""SGD with Polyak step size.
r"""SGD with Polyak step size.
The stochastic Polyak step size is a simple and efficient step-size rule for SGD.
Although this algorithm does not require setting a step-size parameter, it does
require knowledge of a lower bound on the objective function (see below).
Furthermore, some variants accept other hyperparameters.
.. warning::
This method requires knowledge of an approximate value of the objective function
minimum, passed through the ``fun_min`` argument. For overparametrized models, this can be
set to 0 (the default value). Failing to set an appropriate value for ``fun_min`` can lead to
a model that diverges or converges to a suboptimal solution.
This class implements two variants of the stochastic Polyak step-size method: ``SPS_max``
and ``SPS+``. The ``SPS_max`` variant from Loizou et al. (2021) accepts the hyperparameters
``max_stepsize`` and ``delta`` and sets the current step size :math:`\gamma` as
.. math::
\gamma = \min\left\{\frac{\text{fun}(x) - \text{fun}(x^\star)}{\|\nabla \text{fun}(x)\|^2 + \text{delta}}, \text{max_stepsize}\right\}
The ``SPS_max`` rule is adaptive: if the Polyak ratio computed at a given iteration
is smaller than ``max_stepsize`` it is accepted, and otherwise ``max_stepsize`` is
used, which prevents the solver from taking over-confident steps.
For the ``SPS+`` variant, the step size is instead given by
.. math::
\gamma = \max\left\{0, \frac{\text{fun}(x) - \text{fun}(x^\star)}{\|\nabla \text{fun}(x)\|^2}\right\}
and the step-size is zero whenever :math:`\|\nabla \text{fun}(x)\|^2` is zero.
In all cases, the step size is then used in the update
.. math::
v_{t+1} &= \text{momentum} v_t - \gamma \nabla \text{fun}(x) \\
x_{t+1} &= x_t + v_{t+1}
This implementation assumes that the interpolation property holds:
the global optimum over D must also be a global optimum for any finite sample of D.
This is typically achieved by overparametrized models (e.g., neural networks)
in classification tasks with separable classes, or in regression tasks without noise.
Attributes:
fun: a function of the form ``fun(params, *args, **kwargs)``, where
``params`` are parameters of the model,
``*args`` and ``**kwargs`` are additional arguments.
value_and_grad: whether ``fun`` just returns the value (False) or both the
value and gradient (True).
has_aux: whether ``fun`` outputs auxiliary data or not.
If ``has_aux`` is False, ``fun`` is expected to be
scalar-valued.
@@ -77,8 +102,11 @@ class PolyakSGD(base.StochasticSolver):
``(value, aux), grad = fun(...)``.
At each iteration of the algorithm, the auxiliary outputs are stored
in ``state.aux``.
fun_min: a lower bound on the objective function ``fun`` (an estimate of its minimum value).
variant: which version of the stochastic Polyak step-size is implemented.
Can be one of "SPS_max" or "SPS+".
max_stepsize: a maximum step size to use. Only used when variant="SPS_max".
delta: a value to add in the denominator of the update (default: 0).
momentum: momentum parameter, 0 corresponding to no momentum.
pre_update: a function to execute before the solver's update.
Expand All @@ -98,21 +126,22 @@ class PolyakSGD(base.StochasticSolver):
References:
Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan.
`"Training neural networks for and by interpolation" <https://arxiv.org/abs/1906.05661>`_.
International Conference on Machine Learning, 2020.

Loizou, Nicolas and Vaswani, Sharan and Laradji, Issam Hadj and
Lacoste-Julien, Simon.
`"Stochastic polyak step-size for sgd: An adaptive learning rate for fast
convergence" <https://arxiv.org/abs/2002.10542>`_.
International Conference on Artificial Intelligence and Statistics, 2021.
"""
fun: Callable
value_and_grad: bool = False
has_aux: bool = False
fun_min: float = 0.0

variant: str = "SPS_max"
max_stepsize: float = 1.0
delta: float = 0.0
momentum: float = 0.0
@@ -182,8 +211,20 @@ def update(self,
(value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)

grad_sqnorm = tree_l2_norm(grad, squared=True)
if self.variant == "SPS_max":
stepsize = jnp.minimum(
(value - self.fun_min) / (grad_sqnorm + self.delta), self.max_stepsize
)
elif self.variant == "SPS+":
# if grad_sqnorm is smaller than machine epsilon, we set the stepsize to 0
stepsize = jnp.where(
grad_sqnorm <= jnp.finfo(dtype).eps,
0.0,
jnp.maximum((value - self.fun_min) / grad_sqnorm, 0),
)
else:
raise NotImplementedError(f"Unknown variant {self.variant}")

stepsize = stepsize.astype(state.stepsize.dtype)

if self.momentum == 0:
@@ -218,9 +259,12 @@ def __hash__(self):
def __post_init__(self):
super().__post_init__()

self._fun, self._grad_fun, self._value_and_grad_fun = (
base._make_funs_with_aux(
fun=self.fun,
value_and_grad=self.value_and_grad,
has_aux=self.has_aux,
)
)

self.reference_signature = self.fun
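
To make the new step-size rules concrete, here is a small self-contained sketch (not part of the commit) that mirrors the ``SPS_max`` and ``SPS+`` branches added to ``update`` above; the scalar values are invented for illustration.

import jax.numpy as jnp

# Placeholder quantities; in the solver these come from the current iterate.
value = jnp.asarray(0.8)        # objective value fun(x)
fun_min = 0.0                   # estimate of the objective minimum fun(x*)
grad_sqnorm = jnp.asarray(4.0)  # squared gradient norm ||grad fun(x)||^2
delta = 0.0
max_stepsize = 1.0

# SPS_max: Polyak ratio capped at max_stepsize.
stepsize_sps_max = jnp.minimum((value - fun_min) / (grad_sqnorm + delta), max_stepsize)

# SPS+: Polyak ratio clipped at zero, with a zero step when the gradient vanishes.
stepsize_sps_plus = jnp.where(
    grad_sqnorm <= jnp.finfo(grad_sqnorm.dtype).eps,
    0.0,
    jnp.maximum((value - fun_min) / grad_sqnorm, 0.0),
)

print(stepsize_sps_max, stepsize_sps_plus)  # both 0.2 for these values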
46 changes: 36 additions & 10 deletions tests/polyak_sgd_test.py
@@ -28,29 +28,55 @@

class PolyakSgdTest(test_util.JaxoptTestCase):

@parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
def test_logreg_overparameterized(self, momentum, sps_variant):
# Test SPS on an over-parameterized logistic regression problem.
# The loss' infimum is zero and SPS should converge to a minimizer.
data = datasets.make_classification(
n_samples=10, n_features=10, random_state=0
)
# fun(params, l2reg, data)
fun = objective.l2_multiclass_logreg_with_intercept
n_classes = len(jnp.unique(data[1]))

w_init = jnp.zeros((data[0].shape[1], n_classes))
b_init = jnp.zeros(n_classes)
params = (w_init, b_init)

opt = PolyakSGD(fun=fun, fun_min=0, momentum=momentum, variant=sps_variant)
error_init = opt.l2_optimality_error(params, l2reg=0, data=data)
params, _ = opt.run(params, l2reg=0., data=data)

# Check optimality conditions.
error = opt.l2_optimality_error(params, l2reg=0., data=data)
self.assertLessEqual(error / error_init, 0.01)

@parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
def test_logreg_with_intercept_manual_loop(self, momentum, sps_variant):
x, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
n_informative=3, random_state=0)
data = (x, y)
l2reg = 0.1
# fun(params, l2reg, data)
fun = objective.l2_multiclass_logreg_with_intercept
n_classes = len(jnp.unique(y))

w_init = jnp.zeros((x.shape[1], n_classes))
b_init = jnp.zeros(n_classes)
params = (w_init, b_init)

opt = PolyakSGD(
fun=fun, fun_min=0.6975, momentum=momentum, variant=sps_variant
)
error_init = opt.l2_optimality_error(params, l2reg=l2reg, data=data)

state = opt.init_state(params, l2reg=l2reg, data=data)
for _ in range(200):
params, state = opt.update(params, state, l2reg=l2reg, data=data)

# Check optimality conditions.
error = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
self.assertLessEqual(error / error_init, 0.02)

@parameterized.product(has_aux=[True, False])
def test_logreg_with_intercept_run(self, has_aux):
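
As a complement to the tests above, here is a minimal usage sketch of the new ``fun_min`` and ``variant`` arguments on a noise-free least-squares problem, where the interpolation assumption holds and ``fun_min=0`` is exact; the objective and data below are invented for illustration and are not part of the commit.

import jax.numpy as jnp
from jaxopt import PolyakSGD

def fun(params, data):
  # Mean-squared error; its minimum is 0 because y is generated without noise.
  X, y = data
  return jnp.mean((X @ params - y) ** 2)

X = jnp.arange(30.0).reshape(10, 3) / 10.0  # toy design matrix
true_params = jnp.array([1.0, -2.0, 0.5])
y = X @ true_params                         # noise-free targets
params = jnp.zeros(3)

opt = PolyakSGD(fun=fun, fun_min=0.0, variant="SPS+")
state = opt.init_state(params, data=(X, y))
for _ in range(200):
  params, state = opt.update(params, state, data=(X, y))

print(fun(params, data=(X, y)))  # final objective value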
