From 68861b8199decc3f58fcdc46e926e5e37074ef94 Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa
Date: Thu, 26 Oct 2023 12:06:04 -0700
Subject: [PATCH] No public description

PiperOrigin-RevId: 576947921
---
 jaxopt/_src/polyak_sgd.py | 94 ++++++++++++++++++++++++++++-----------
 tests/polyak_sgd_test.py  | 46 ++++++++++++++-----
 2 files changed, 105 insertions(+), 35 deletions(-)

diff --git a/jaxopt/_src/polyak_sgd.py b/jaxopt/_src/polyak_sgd.py
index c8b19a12..49c18442 100644
--- a/jaxopt/_src/polyak_sgd.py
+++ b/jaxopt/_src/polyak_sgd.py
@@ -14,14 +14,12 @@
 """SGD solver with Polyak step size."""
 
+import dataclasses
 from typing import Any
 from typing import Callable
 from typing import NamedTuple
 from typing import Optional
 
-import dataclasses
-
-import jax
 import jax.numpy as jnp
 
 from jaxopt._src import base
@@ -49,23 +47,50 @@ class PolyakSGDState(NamedTuple):
 
 @dataclasses.dataclass(eq=False)
 class PolyakSGD(base.StochasticSolver):
-  """SGD with Polyak step size.
+  r"""SGD with Polyak step size.
+
+  The stochastic Polyak step-size is a simple and efficient step-size for SGD.
+  Although this algorithm does not require setting a step-size parameter, it does
+  require knowledge of a lower bound on the objective function (see below).
+  Furthermore, some variants accept other hyperparameters.
+
+  .. warning::
+    This method requires knowledge of an approximate value of the objective function
+    minimum, passed through the ``fun_min`` argument. For overparametrized models, this can be
+    set to 0 (default value). Failing to set an appropriate value for ``fun_min`` can lead to
+    a model that diverges or converges to a suboptimal solution.
+
+
+  This class implements two different variants of the stochastic Polyak step size method: ``SPS_max``
+  and ``SPS+``. The ``SPS_max`` variant from Loizou et al. (2021) accepts the hyperparameters
+  ``max_stepsize`` and ``delta`` and sets the current step-size :math:`\gamma` as
+
+  .. math::
+
+    \gamma = \min\left\{\frac{\text{fun}(x) - \text{fun}(x^\star)}{\|\nabla \text{fun}(x)\|^2 + \text{delta}}, \text{max\_stepsize}\right\}
 
-  This solver computes step sizes in an adaptive manner. If the computed step
-  size at a given iteration is smaller than ``max_stepsize``, it is accepted.
-  Otherwise, ``max_stepsize`` is used. This ensures that the solver does not
-  take over-confident steps. This is why ``max_stepsize`` is the most important
-  hyper-parameter.
+  while for the ``SPS+`` variant, it is given by
+
+  .. math::
+
+    \gamma = \max\left\{0, \frac{\text{fun}(x) - \text{fun}(x^\star)}{\|\nabla \text{fun}(x)\|^2}\right\}
+
+  and the step-size is zero whenever :math:`\|\nabla \text{fun}(x)\|^2` is zero.
+
+  In all cases, the step size is then used in the update
+
+  .. math::
+
+    v_{t+1} &= \text{momentum} \cdot v_t - \gamma \nabla \text{fun}(x) \\
+    x_{t+1} &= x_t + v_{t+1}
 
-  This implementation assumes that the interpolation property holds:
-  - the global optimum over D must also be a global optimum for any finite sample of D
-  This is typically achieved by overparametrized models (e.g neural networks)
-  in classification tasks with separable classes, or on regression tasks without noise.
 
   Attributes:
     fun: a function of the form ``fun(params, *args, **kwargs)``, where
      ``params`` are parameters of the model, ``*args`` and ``**kwargs`` are
      additional arguments.
+    value_and_grad: whether ``fun`` just returns the value (False) or both the
+      value and gradient (True).
     has_aux: whether ``fun`` outputs auxiliary data or not.
      If ``has_aux`` is False, ``fun`` is expected to be scalar-valued.
@@ -77,8 +102,11 @@ class PolyakSGD(base.StochasticSolver):
       ``(value, aux), grad = fun(...)``. At each iteration of the algorithm,
       the auxiliary outputs are stored in ``state.aux``.
+    fun_min: a lower bound on fun.
-    max_stepsize: a maximum step size to use.
+    variant: which version of the stochastic Polyak step-size is implemented.
+      Can be one of "SPS_max" or "SPS+".
+    max_stepsize: a maximum step size to use. Only used when variant="SPS_max".
     delta: a value to add in the denominator of the update (default: 0).
     momentum: momentum parameter, 0 corresponding to no momentum.
     pre_update: a function to execute before the solver's update.
@@ -98,21 +126,22 @@ class PolyakSGD(base.StochasticSolver):
 
   References:
     Berrada, Leonard and Zisserman, Andrew and Kumar, M Pawan.
-      "Training neural networks for and by interpolation".
+      `"Training neural networks for and by interpolation" <https://arxiv.org/abs/1906.05661>`_.
       International Conference on Machine Learning, 2020.
-      https://arxiv.org/abs/1906.05661
+
     Loizou, Nicolas and Vaswani, Sharan and Laradji, Issam Hadj and Lacoste-Julien, Simon.
-      "Stochastic polyak step-size for sgd: An adaptive learning rate for fast
-      convergence".
+      `"Stochastic polyak step-size for sgd: An adaptive learning rate for fast
+      convergence" <https://arxiv.org/abs/2002.10542>`_.
       International Conference on Artificial Intelligence and Statistics, 2021.
-      https://arxiv.org/abs/2002.10542
   """
 
   fun: Callable
   value_and_grad: bool = False
   has_aux: bool = False
+  fun_min: float = 0.0
+  variant: str = "SPS_max"
   max_stepsize: float = 1.0
   delta: float = 0.0
   momentum: float = 0.0
@@ -182,8 +211,20 @@ def update(self,
       (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
 
     grad_sqnorm = tree_l2_norm(grad, squared=True)
-    stepsize = jnp.minimum(value / (grad_sqnorm + self.delta),
-                           self.max_stepsize)
+    if self.variant == "SPS_max":
+      stepsize = jnp.minimum(
+          (value - self.fun_min) / (grad_sqnorm + self.delta), self.max_stepsize
+      )
+    elif self.variant == "SPS+":
+      # if grad_sqnorm is smaller than machine epsilon, we set the stepsize to 0
+      stepsize = jnp.where(
+          grad_sqnorm <= jnp.finfo(dtype).eps,
+          0.0,
+          jnp.maximum((value - self.fun_min) / grad_sqnorm, 0),
+      )
+    else:
+      raise NotImplementedError(f"Unknown variant {self.variant}")
+
     stepsize = stepsize.astype(state.stepsize.dtype)
 
     if self.momentum == 0:
@@ -218,9 +259,12 @@ def __hash__(self):
 
   def __post_init__(self):
     super().__post_init__()
-    self._fun, self._grad_fun, self._value_and_grad_fun = \
-      base._make_funs_with_aux(fun=self.fun,
-                               value_and_grad=self.value_and_grad,
-                               has_aux=self.has_aux)
+    self._fun, self._grad_fun, self._value_and_grad_fun = (
+        base._make_funs_with_aux(
+            fun=self.fun,
+            value_and_grad=self.value_and_grad,
+            has_aux=self.has_aux,
+        )
+    )
 
     self.reference_signature = self.fun
diff --git a/tests/polyak_sgd_test.py b/tests/polyak_sgd_test.py
index 54d86c9f..a04b863d 100644
--- a/tests/polyak_sgd_test.py
+++ b/tests/polyak_sgd_test.py
@@ -28,21 +28,47 @@ class PolyakSgdTest(test_util.JaxoptTestCase):
 
-  @parameterized.product(momentum=[0.0, 0.9])
-  def test_logreg_with_intercept_manual_loop(self, momentum):
-    X, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
-                                        n_informative=3, random_state=0)
-    data = (X, y)
-    l2reg = 100.0
+  @parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
+  def test_logreg_overparameterized(self, momentum, sps_variant):
+    # Test SPS on an over-parameterized logistic regression problem.
+    # The loss' infimum is zero and SPS should converge to a minimizer.
+    data = datasets.make_classification(
+        n_samples=10, n_features=10, random_state=0
+    )
     # fun(params, data)
     fun = objective.l2_multiclass_logreg_with_intercept
+    n_classes = len(jnp.unique(data[1]))
+
+    w_init = jnp.zeros((data[0].shape[1], n_classes))
+    b_init = jnp.zeros(n_classes)
+    params = (w_init, b_init)
+
+    opt = PolyakSGD(fun=fun, fun_min=0, momentum=momentum, variant=sps_variant)
+    error_init = opt.l2_optimality_error(params, l2reg=0, data=data)
+    params, _ = opt.run(params, l2reg=0., data=data)
+
+    # Check optimality conditions.
+    error = opt.l2_optimality_error(params, l2reg=0., data=data)
+    self.assertLessEqual(error / error_init, 0.01)
+
+  @parameterized.product(momentum=[0.0, 0.9], sps_variant=['SPS_max', 'SPS+'])
+  def test_logreg_with_intercept_manual_loop(self, momentum, sps_variant):
+    x, y = datasets.make_classification(n_samples=10, n_features=5, n_classes=3,
+                                        n_informative=3, random_state=0)
+    data = (x, y)
+    l2reg = 0.1
+    # fun(params, l2reg, data)
+    fun = objective.l2_multiclass_logreg_with_intercept
     n_classes = len(jnp.unique(y))
-    W_init = jnp.zeros((X.shape[1], n_classes))
+    w_init = jnp.zeros((x.shape[1], n_classes))
     b_init = jnp.zeros(n_classes)
-    params = (W_init, b_init)
+    params = (w_init, b_init)
 
-    opt = PolyakSGD(fun=fun, max_stepsize=0.01, momentum=momentum)
+    opt = PolyakSGD(
+        fun=fun, fun_min=0.6975, momentum=momentum, variant=sps_variant
+    )
+    error_init = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
     state = opt.init_state(params, l2reg=l2reg, data=data)
 
     for _ in range(200):
@@ -50,7 +76,7 @@ def test_logreg_with_intercept_manual_loop(self, momentum):
 
     # Check optimality conditions.
     error = opt.l2_optimality_error(params, l2reg=l2reg, data=data)
-    self.assertLessEqual(error, 0.05)
+    self.assertLessEqual(error / error_init, 0.02)
 
   @parameterized.product(has_aux=[True, False])
   def test_logreg_with_intercept_run(self, has_aux):
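
For reference, below is a minimal usage sketch (not part of the commit) that exercises the new ``variant`` and ``fun_min`` arguments, mirroring the over-parameterized test above. The dataset sizes, the choice of ``variant="SPS+"``, and ``momentum=0.9`` are illustrative assumptions, and it assumes the public ``jaxopt.PolyakSGD`` and ``jaxopt.objective`` entry points.

```python
# Illustrative sketch only: over-parameterized multiclass logistic regression,
# where the loss infimum is 0, so the default fun_min=0 is a valid lower bound.
import jax.numpy as jnp
from sklearn import datasets

from jaxopt import PolyakSGD
from jaxopt import objective

X, y = datasets.make_classification(n_samples=20, n_features=20, random_state=0)
n_classes = len(jnp.unique(y))
params = (jnp.zeros((X.shape[1], n_classes)), jnp.zeros(n_classes))

# fun(params, l2reg, data)
fun = objective.l2_multiclass_logreg_with_intercept

opt = PolyakSGD(fun=fun, fun_min=0.0, variant="SPS+", momentum=0.9)
params, state = opt.run(params, l2reg=0.0, data=(X, y))

# Optimality error should be small relative to the error at initialization.
print(opt.l2_optimality_error(params, l2reg=0.0, data=(X, y)))
```

As the docstring warning notes, ``fun_min=0`` is only appropriate when the model can drive the loss to (near) zero; the regularized manual-loop test above instead passes an estimate of the regularized optimum (``fun_min=0.6975``).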