Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to use control variate in Jacobian of perturbed_argmax. #418

Merged
merged 1 commit into from
Apr 5, 2023

Conversation

LawrenceMMStewart
Copy link
Contributor

I have closed and re-opened my previous PR in order to fix an email issue with the CLA agreement.

The Jacobian of the perturbed argmax Proposition 3.1 Berthet et al. explodes as sigma tends to 0. It is possible to avoid this problem by replacing y(theta + sigma Z) with y(theta + sigma Z) - y(theta) in the monte-carlo estimator, as described in Le Lidec 2021 and Bach 2023.

This PR implements this minor change by adding a boolean variable control_variate to perturbations.make_perturbed_argmax, which changes how the Monte-Carlo estimate for the Jacobian is calculated. Relevant unit tests have been added.

I have provided an example notebook, which demonstrates the functionality of the PR on a simple differentiable ranking examples similar to that of the existing jaxopt example gallery.

@q-berthet
Copy link
Collaborator

Thanks Lawrence! LGTM

Copy link
Collaborator

@mblondel mblondel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you @LawrenceMMStewart!

@copybara-service copybara-service bot merged commit 674a992 into google:main Apr 5, 2023
vroulet added a commit to vroulet/jaxopt that referenced this pull request Jun 13, 2023
commit 3614817
Author: Vincent Roulet <vroulet@google.com>
Date:   Tue Jun 13 15:07:47 2023 -0700

    added approx wolfe condition to fix bug at high precision

commit d2211bc
Author: Vincent Roulet <vroulet@google.com>
Date:   Mon Jun 12 13:43:05 2023 -0700

    fix attempt for failed test correctness lbfgsb

commit 024bca8
Merge: c21479e 257b673
Author: Vincent Roulet <vroulet@google.com>
Date:   Mon Jun 12 09:06:42 2023 -0700

    Merge branch 'main' into zoom_linesearch_as_iterative_linesearch_solver

commit 257b673
Merge: d40b6d7 0cfa882
Author: JAXopt authors <no-reply@google.com>
Date:   Mon Jun 12 08:20:07 2023 -0700

    Merge pull request google#440 from mblondel:lbfb_failure

    PiperOrigin-RevId: 539658221

commit 0cfa882
Author: Mathieu Blondel <mblondel@google.com>
Date:   Mon Jun 12 16:28:14 2023 +0200

    Drop Python 3.7 support.

commit c21479e
Merge: 675ae9d d108ddf
Author: Vincent Roulet <vroulet@google.com>
Date:   Fri Jun 9 15:33:21 2023 -0700

    merging with main branch

commit d108ddf
Author: Srinivas Vasudevan <srvasude@google.com>
Date:   Tue Jun 6 13:49:34 2023 -0700

    Internal change

    PiperOrigin-RevId: 538281622

commit 414b5b9
Author: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date:   Sat May 27 09:14:40 2023 +0200

    Fix typos

commit dcae685
Author: Mathieu Blondel <mblondel@google.com>
Date:   Fri May 26 23:26:05 2023 +0200

    Release v0.7.

commit 83c1370
Author: Chansoo Lee <chansoo@google.com>
Date:   Fri May 26 09:21:31 2023 -0700

    Internal change

    PiperOrigin-RevId: 535636577

commit af48e7b
Author: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>
Date:   Mon May 8 20:19:37 2023 +0200

    fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation

commit 5392a81
Author: Fabian Pedregosa <pedregosa@google.com>
Date:   Wed Feb 15 15:33:31 2023 +0100

    Misc improvements in resnet_flax example.

    * Better data augmentation, leading to 88% accuracy (from 70%)
    * Plots showing the data augmentation in action.
    * Options use the same format as distributed training examples.
    * Changed solver from Adam to SGD for better accuracy.

commit a260cfb
Author: Emily Fertig <emilyaf@google.com>
Date:   Tue Apr 25 08:47:41 2023 -0700

    Internal change

    PiperOrigin-RevId: 526978774

commit e593c89
Author: Vincent Roulet <vroulet@google.com>
Date:   Tue Apr 4 17:54:20 2023 -0700

    Fixed prox to handle pytrees

    Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters

    Added tests

commit 675ae9d
Author: Vincent Roulet <vroulet@google.com>
Date:   Fri Jun 9 13:27:24 2023 -0700

    integrated new zoom linesearch in all solvers, simplifying them

commit 8da43d3
Author: Vincent Roulet <vroulet@google.com>
Date:   Wed Jun 7 14:20:41 2023 -0700

    minor edit

commit 416e687
Author: Vincent Roulet <vroulet@google.com>
Date:   Wed Jun 7 14:10:16 2023 -0700

    fix copyright year

commit 8f2d2c2
Author: Vincent Roulet <vroulet@google.com>
Date:   Wed Jun 7 11:29:03 2023 -0700

    fixed dtypes

commit 8ca6e67
Author: Vincent Roulet <vroulet@google.com>
Date:   Tue Jun 6 22:19:02 2023 -0700

    minor edits

commit 4c847ea
Author: Vincent Roulet <vroulet@google.com>
Date:   Tue Jun 6 21:41:53 2023 -0700

    convert zoom_linesearch into an IterativeLineSearchSolver

commit d40b6d7
Author: Srinivas Vasudevan <srvasude@google.com>
Date:   Tue Jun 6 13:49:34 2023 -0700

    Internal change

    PiperOrigin-RevId: 538281622

commit e87b9b9
Merge: 58ce7cb 3ccb6b9
Author: JAXopt authors <no-reply@google.com>
Date:   Sat May 27 08:28:11 2023 -0700

    Merge pull request google#435 from gdalle:patch-1

    PiperOrigin-RevId: 535860339

commit 3ccb6b9
Author: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date:   Sat May 27 09:14:40 2023 +0200

    Fix typos

commit 58ce7cb
Merge: 541bbaa 7cf0567
Author: JAXopt authors <no-reply@google.com>
Date:   Fri May 26 14:54:44 2023 -0700

    Merge pull request google#434 from mblondel:release_0.7

    PiperOrigin-RevId: 535720960

commit 7cf0567
Author: Mathieu Blondel <mblondel@google.com>
Date:   Fri May 26 23:26:05 2023 +0200

    Release v0.7.

commit 541bbaa
Author: Chansoo Lee <chansoo@google.com>
Date:   Fri May 26 09:21:31 2023 -0700

    Internal change

    PiperOrigin-RevId: 535636577

commit b934387
Merge: b3b6a0d f6c0ca0
Author: JAXopt authors <no-reply@google.com>
Date:   Mon May 15 07:04:49 2023 -0700

    Merge pull request google#425 from zaccharieramzi:fix-maml-example

    PiperOrigin-RevId: 532098566

commit f6c0ca0
Author: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>
Date:   Mon May 8 20:19:37 2023 +0200

    fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation

commit b3b6a0d
Merge: 4aa9bc9 fff693f
Author: JAXopt authors <no-reply@google.com>
Date:   Thu Apr 27 07:24:22 2023 -0700

    Merge pull request google#401 from fabianp:resnet_flax

    PiperOrigin-RevId: 527570739

commit 4aa9bc9
Author: Emily Fertig <emilyaf@google.com>
Date:   Tue Apr 25 08:47:41 2023 -0700

    Internal change

    PiperOrigin-RevId: 526978774

commit fff693f
Author: Fabian Pedregosa <pedregosa@google.com>
Date:   Wed Feb 15 15:33:31 2023 +0100

    Misc improvements in resnet_flax example.

    * Better data augmentation, leading to 88% accuracy (from 70%)
    * Plots showing the data augmentation in action.
    * Options use the same format as distributed training examples.
    * Changed solver from Adam to SGD for better accuracy.

commit 4edd8ac
Merge: 674a992 7da12ec
Author: JAXopt authors <no-reply@google.com>
Date:   Wed Apr 12 14:00:46 2023 -0700

    Merge pull request google#420 from vroulet:fix_prox_pytree

    PiperOrigin-RevId: 523798658

commit 7da12ec
Author: Vincent Roulet <vroulet@google.com>
Date:   Tue Apr 4 17:54:20 2023 -0700

    Fixed prox to handle pytrees

    Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters

    Added tests

commit 674a992
Merge: 1019f7b 18c4bd3
Author: JAXopt authors <no-reply@google.com>
Date:   Wed Apr 5 01:39:26 2023 -0700

    Merge pull request google#418 from LawrenceMMStewart:main

    PiperOrigin-RevId: 521986072

commit 18c4bd3
Author: LawrenceMMStewart <lmmstewart@proton.me>
Date:   Fri Mar 31 12:12:57 2023 +0200

    added control variate to make_perturbed_argmax

commit 1019f7b
Merge: 36d7a0d 7f54e31
Author: JAXopt authors <no-reply@google.com>
Date:   Thu Mar 23 18:16:45 2023 -0700

    Merge pull request google#382 from aymgal:pr-hess_inv

    PiperOrigin-RevId: 519014528

commit 36d7a0d
Author: Quentin Berthet <qberthet@google.com>
Date:   Tue Mar 21 06:08:00 2023 -0700

    Internal change

    PiperOrigin-RevId: 518250976

commit 7f54e31
Merge: a4f3956 ea8e0f1
Author: Aymeric Galan <aymeric.galan@gmail.com>
Date:   Thu Mar 16 15:13:49 2023 +0100

    Merge remote-tracking branch 'upstream/main' into pr-hess_inv

commit ea8e0f1
Merge: cb6ed9a e196ece
Author: JAXopt authors <no-reply@google.com>
Date:   Wed Mar 15 13:40:56 2023 -0700

    Merge pull request google#412 from froystig:jit-bisect-test

    PiperOrigin-RevId: 516917371

commit e196ece
Author: Roy Frostig <frostig@google.com>
Date:   Wed Mar 15 17:25:40 2023 +0000

    avoid closing over dynamic jax tracers in the bisection solver

    Internally in jaxopt, we (should) try to maintain that the parameters
    to a solver class are "static" from jax's point of view.

    One reason for this is that class attributes might be read by any of
    the class' methods, including `run`. Meanwhile a bound `run` method
    serves as the solver function, which is passed through jaxopt's core
    `custom_root` mechanism in order to set it up with an
    implicit-diff-based custom VJP. Currently, that `custom_root`
    mechanism assumes that the solver function it receives has, in its
    closure, no arrays that are involved in any of jax's differentiation
    or staging. Re-stated using jax-internal jargon: `custom_root` assumes
    that the solver function it receives does not have tracers in its
    closure. But: a bound Python method (e.g. `o.run`) carries its bound
    instance (e.g. `o`) in its closure.

    The code in `bisection_test.py` did not conform to this requirement
    that all class attributes are static (in the jax transformation
    sense). Specifically, it constructed a `Bisection` instance, within a
    jitted function, given parameters (`lower` and `upper`) that depend on
    inputs to the jitted function. This change fixes that by hoisting the
    construction of this `Bisection` out from the jitted function (and
    marking it a static argument).

    Doing this fixes a jax "tracer leak" error raised in the jaxopt CI
    recently. This was not an issue until jax released version 0.4.4, for
    the rather technical reason that jax changed its `jit` implementation
    such that it eagerly stages out its function argument. This in turn
    led jax to encounter "jit tracers" (corresponding to
    `Bisection.{lower,upper}`) within the closure of a solver function
    (`Bisection.run`) in the course of custom-differentiating the solver
    function.

commit cb6ed9a
Author: Emily Fertig <emilyaf@google.com>
Date:   Wed Mar 15 10:43:20 2023 -0700

    Internal change

    PiperOrigin-RevId: 516868349

commit a4f3956
Author: Aymeric Galan <aymeric.galan@gmail.com>
Date:   Thu Mar 9 16:32:48 2023 +0100

    Attempt to fix failing test on python 3.9 regarding 32 vs 64-bits numbers

commit abe44e4
Author: Aymeric Galan <aymeric.galan@gmail.com>
Date:   Thu Jan 19 13:05:54 2023 +0100

    Add inverse hessian approximation to the returned state

    Add custom pytree registration for LbfgsInvHessProduct result

    Fix issue with undefined class

    Add docstring

    Remove drepecated comments

    Fix scipy.optimize module not found

    LbfgsInvHessProductPyTree constructor now compliant with JAX

commit 040c8fc
Author: Yash Katariya <yashkatariya@google.com>
Date:   Tue Feb 21 15:24:39 2023 -0800

    Internal change

    PiperOrigin-RevId: 511320677

commit a51d5ed
Merge: 52d56ab f65001b
Author: JAXopt authors <no-reply@google.com>
Date:   Fri Feb 17 04:34:13 2023 -0800

    Merge pull request google#398 from mblondel:add_isotonic_module

    PiperOrigin-RevId: 510398823

commit 52d56ab
Merge: 0472831 6e6a0ab
Author: JAXopt authors <no-reply@google.com>
Date:   Fri Feb 17 01:08:50 2023 -0800

    Merge pull request google#397 from mblondel:remove_matplotlib

    PiperOrigin-RevId: 510364159

commit f65001b
Author: Mathieu Blondel <mblondel@google.com>
Date:   Thu Feb 16 19:49:06 2023 +0100

    Add isotonic module.

commit 6e6a0ab
Author: Mathieu Blondel <mblondel@google.com>
Date:   Thu Feb 16 19:34:24 2023 +0100

    Update requirements.

commit 0472831
Author: Peter Hawkins <phawkins@google.com>
Date:   Thu Feb 9 09:03:52 2023 -0800

    Internal change

    PiperOrigin-RevId: 508389531

commit e1d8355
Merge: 0c8b25b 730b5a6
Author: JAXopt authors <no-reply@google.com>
Date:   Thu Feb 9 07:43:10 2023 -0800

    Merge pull request google#394 from mblondel:release_0.6

    PiperOrigin-RevId: 508371158

commit 730b5a6
Author: Mathieu Blondel <mblondel@google.com>
Date:   Thu Feb 9 16:02:32 2023 +0100

    Release v0.6.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants