avoid closing over dynamic jax tracers in the bisection solver #412
Merged
Internally in jaxopt, we (should) try to maintain that the parameters to a solver class are "static" from jax's point of view. One reason for this is that class attributes might be read by any of the class' methods, including `run`. Meanwhile a bound `run` method serves as the solver function, which is passed through jaxopt's core `custom_root` mechanism in order to set it up with an implicit-diff-based custom VJP. Currently, that `custom_root` mechanism assumes that the solver function it receives has, in its closure, no arrays that are involved in any of jax's differentiation or staging. Re-stated using jax-internal jargon: `custom_root` assumes that the solver function it receives does not have tracers in its closure. But: a bound Python method (e.g. `o.run`) carries its bound instance (e.g. `o`) in its closure.

The code in `bisection_test.py` did not conform to this requirement that all class attributes are static (in the jax transformation sense). Specifically, it constructed a `Bisection` instance, within a jitted function, given parameters (`lower` and `upper`) that depend on inputs to the jitted function. This change fixes that by hoisting the construction of this `Bisection` out from the jitted function (and marking it a static argument).

Doing this fixes a jax "tracer leak" error raised in the jaxopt CI recently. This was not an issue until jax released version 0.4.4, for the rather technical reason that jax changed its `jit` implementation such that it eagerly stages out its function argument. This in turn led jax to encounter "jit tracers" (corresponding to `Bisection.{lower,upper}`) within the closure of a solver function (`Bisection.run`) in the course of custom-differentiating the solver function.
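A minimal sketch of the hoisting pattern described above, not the actual test code. `ToyBisection`, `square_minus`, and `solve` are hypothetical names, and the toy class only mimics the shape of a solver whose `run` method reads instance attributes; it is not jaxopt's `Bisection`. The sketch shows that a bound method keeps a reference to its instance, and that the instance can be built outside the jitted function and passed in as a static argument, so `lower` and `upper` stay concrete Python values rather than tracers.

```python
import functools

import jax
import jax.numpy as jnp


class ToyBisection:
    """Hypothetical stand-in for a solver class (not jaxopt's Bisection)."""

    def __init__(self, fun, lower, upper, maxiter=60):
        self.fun = fun
        self.lower = lower      # plain Python float: static from jax's view
        self.upper = upper      # plain Python float: static from jax's view
        self.maxiter = maxiter

    def run(self, theta):
        # `run` reads instance attributes, so whatever is stored on `self`
        # travels along with the bound method `self.run`.
        def body(_, bounds):
            lo, hi = bounds
            mid = 0.5 * (lo + hi)
            root_is_right = self.fun(mid, theta) < 0
            new_lo = jnp.where(root_is_right, mid, lo)
            new_hi = jnp.where(root_is_right, hi, mid)
            return new_lo, new_hi

        init = (jnp.float32(self.lower), jnp.float32(self.upper))
        lo, hi = jax.lax.fori_loop(0, self.maxiter, body, init)
        return 0.5 * (lo + hi)


def square_minus(x, theta):
    # Increasing in x on [0, 4]; its root in x is sqrt(theta).
    return x ** 2 - theta


# Hoisted: the solver is constructed outside the jitted function,
# from concrete bounds that do not depend on traced inputs.
solver = ToyBisection(square_minus, lower=0.0, upper=4.0)

# A bound method carries its instance along with it.
assert solver.run.__self__ is solver


# The solver is marked static (hashed by identity); only `theta` is traced.
@functools.partial(jax.jit, static_argnums=0)
def solve(solver, theta):
    return solver.run(theta)


root = solve(solver, jnp.float32(2.0))
```

This toy does not go through `custom_root`, so it illustrates only the hoisting and the static marking, not the tracer leak itself.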
mblondel approved these changes on Mar 15, 2023
LGTM. Thank you Roy!
vroulet added a commit to vroulet/jaxopt that referenced this pull request on Jun 13, 2023
commit 3614817 Author: Vincent Roulet <vroulet@google.com> Date: Tue Jun 13 15:07:47 2023 -0700 added approx wolfe condition to fix bug at high precision
commit d2211bc Author: Vincent Roulet <vroulet@google.com> Date: Mon Jun 12 13:43:05 2023 -0700 fix attempt for failed test correctness lbfgsb
commit 024bca8 Merge: c21479e 257b673 Author: Vincent Roulet <vroulet@google.com> Date: Mon Jun 12 09:06:42 2023 -0700 Merge branch 'main' into zoom_linesearch_as_iterative_linesearch_solver
commit 257b673 Merge: d40b6d7 0cfa882 Author: JAXopt authors <no-reply@google.com> Date: Mon Jun 12 08:20:07 2023 -0700 Merge pull request google#440 from mblondel:lbfb_failure PiperOrigin-RevId: 539658221
commit 0cfa882 Author: Mathieu Blondel <mblondel@google.com> Date: Mon Jun 12 16:28:14 2023 +0200 Drop Python 3.7 support.
commit c21479e Merge: 675ae9d d108ddf Author: Vincent Roulet <vroulet@google.com> Date: Fri Jun 9 15:33:21 2023 -0700 merging with main branch
commit d108ddf Author: Srinivas Vasudevan <srvasude@google.com> Date: Tue Jun 6 13:49:34 2023 -0700 Internal change PiperOrigin-RevId: 538281622
commit 414b5b9 Author: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat May 27 09:14:40 2023 +0200 Fix typos
commit dcae685 Author: Mathieu Blondel <mblondel@google.com> Date: Fri May 26 23:26:05 2023 +0200 Release v0.7.
commit 83c1370 Author: Chansoo Lee <chansoo@google.com> Date: Fri May 26 09:21:31 2023 -0700 Internal change PiperOrigin-RevId: 535636577
commit af48e7b Author: Zaccharie Ramzi <zaccharie.ramzi@gmail.com> Date: Mon May 8 20:19:37 2023 +0200 fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation
commit 5392a81 Author: Fabian Pedregosa <pedregosa@google.com> Date: Wed Feb 15 15:33:31 2023 +0100 Misc improvements in resnet_flax example. * Better data augmentation, leading to 88% accuracy (from 70%) * Plots showing the data augmentation in action. * Options use the same format as distributed training examples. * Changed solver from Adam to SGD for better accuracy.
commit a260cfb Author: Emily Fertig <emilyaf@google.com> Date: Tue Apr 25 08:47:41 2023 -0700 Internal change PiperOrigin-RevId: 526978774
commit e593c89 Author: Vincent Roulet <vroulet@google.com> Date: Tue Apr 4 17:54:20 2023 -0700 Fixed prox to handle pytrees Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters Added tests
commit 675ae9d Author: Vincent Roulet <vroulet@google.com> Date: Fri Jun 9 13:27:24 2023 -0700 integrated new zoom linesearch in all solvers, simplifying them
commit 8da43d3 Author: Vincent Roulet <vroulet@google.com> Date: Wed Jun 7 14:20:41 2023 -0700 minor edit
commit 416e687 Author: Vincent Roulet <vroulet@google.com> Date: Wed Jun 7 14:10:16 2023 -0700 fix copyright year
commit 8f2d2c2 Author: Vincent Roulet <vroulet@google.com> Date: Wed Jun 7 11:29:03 2023 -0700 fixed dtypes
commit 8ca6e67 Author: Vincent Roulet <vroulet@google.com> Date: Tue Jun 6 22:19:02 2023 -0700 minor edits
commit 4c847ea Author: Vincent Roulet <vroulet@google.com> Date: Tue Jun 6 21:41:53 2023 -0700 convert zoom_linesearch into an IterativeLineSearchSolver
commit d40b6d7 Author: Srinivas Vasudevan <srvasude@google.com> Date: Tue Jun 6 13:49:34 2023 -0700 Internal change PiperOrigin-RevId: 538281622
commit e87b9b9 Merge: 58ce7cb 3ccb6b9 Author: JAXopt authors <no-reply@google.com> Date: Sat May 27 08:28:11 2023 -0700 Merge pull request google#435 from gdalle:patch-1 PiperOrigin-RevId: 535860339
commit 3ccb6b9 Author: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat May 27 09:14:40 2023 +0200 Fix typos
commit 58ce7cb Merge: 541bbaa 7cf0567 Author: JAXopt authors <no-reply@google.com> Date: Fri May 26 14:54:44 2023 -0700 Merge pull request google#434 from mblondel:release_0.7 PiperOrigin-RevId: 535720960
commit 7cf0567 Author: Mathieu Blondel <mblondel@google.com> Date: Fri May 26 23:26:05 2023 +0200 Release v0.7.
commit 541bbaa Author: Chansoo Lee <chansoo@google.com> Date: Fri May 26 09:21:31 2023 -0700 Internal change PiperOrigin-RevId: 535636577
commit b934387 Merge: b3b6a0d f6c0ca0 Author: JAXopt authors <no-reply@google.com> Date: Mon May 15 07:04:49 2023 -0700 Merge pull request google#425 from zaccharieramzi:fix-maml-example PiperOrigin-RevId: 532098566
commit f6c0ca0 Author: Zaccharie Ramzi <zaccharie.ramzi@gmail.com> Date: Mon May 8 20:19:37 2023 +0200 fixed imaml tutorial (speed and correctness): phase application, data generation, outer loss computation
commit b3b6a0d Merge: 4aa9bc9 fff693f Author: JAXopt authors <no-reply@google.com> Date: Thu Apr 27 07:24:22 2023 -0700 Merge pull request google#401 from fabianp:resnet_flax PiperOrigin-RevId: 527570739
commit 4aa9bc9 Author: Emily Fertig <emilyaf@google.com> Date: Tue Apr 25 08:47:41 2023 -0700 Internal change PiperOrigin-RevId: 526978774
commit fff693f Author: Fabian Pedregosa <pedregosa@google.com> Date: Wed Feb 15 15:33:31 2023 +0100 Misc improvements in resnet_flax example. * Better data augmentation, leading to 88% accuracy (from 70%) * Plots showing the data augmentation in action. * Options use the same format as distributed training examples. * Changed solver from Adam to SGD for better accuracy.
commit 4edd8ac Merge: 674a992 7da12ec Author: JAXopt authors <no-reply@google.com> Date: Wed Apr 12 14:00:46 2023 -0700 Merge pull request google#420 from vroulet:fix_prox_pytree PiperOrigin-RevId: 523798658
commit 7da12ec Author: Vincent Roulet <vroulet@google.com> Date: Tue Apr 4 17:54:20 2023 -0700 Fixed prox to handle pytrees Fixed prox_lasso and prox_elastic_net to handle pytrees as inputs and floats for hyperparameters Added tests
commit 674a992 Merge: 1019f7b 18c4bd3 Author: JAXopt authors <no-reply@google.com> Date: Wed Apr 5 01:39:26 2023 -0700 Merge pull request google#418 from LawrenceMMStewart:main PiperOrigin-RevId: 521986072
commit 18c4bd3 Author: LawrenceMMStewart <lmmstewart@proton.me> Date: Fri Mar 31 12:12:57 2023 +0200 added control variate to make_perturbed_argmax
commit 1019f7b Merge: 36d7a0d 7f54e31 Author: JAXopt authors <no-reply@google.com> Date: Thu Mar 23 18:16:45 2023 -0700 Merge pull request google#382 from aymgal:pr-hess_inv PiperOrigin-RevId: 519014528
commit 36d7a0d Author: Quentin Berthet <qberthet@google.com> Date: Tue Mar 21 06:08:00 2023 -0700 Internal change PiperOrigin-RevId: 518250976
commit 7f54e31 Merge: a4f3956 ea8e0f1 Author: Aymeric Galan <aymeric.galan@gmail.com> Date: Thu Mar 16 15:13:49 2023 +0100 Merge remote-tracking branch 'upstream/main' into pr-hess_inv
commit ea8e0f1 Merge: cb6ed9a e196ece Author: JAXopt authors <no-reply@google.com> Date: Wed Mar 15 13:40:56 2023 -0700 Merge pull request google#412 from froystig:jit-bisect-test PiperOrigin-RevId: 516917371
commit e196ece Author: Roy Frostig <frostig@google.com> Date: Wed Mar 15 17:25:40 2023 +0000 avoid closing over dynamic jax tracers in the bisection solver
commit cb6ed9a Author: Emily Fertig <emilyaf@google.com> Date: Wed Mar 15 10:43:20 2023 -0700 Internal change PiperOrigin-RevId: 516868349
commit a4f3956 Author: Aymeric Galan <aymeric.galan@gmail.com> Date: Thu Mar 9 16:32:48 2023 +0100 Attempt to fix failing test on python 3.9 regarding 32 vs 64-bits numbers
commit abe44e4 Author: Aymeric Galan <aymeric.galan@gmail.com> Date: Thu Jan 19 13:05:54 2023 +0100 Add inverse hessian approximation to the returned state Add custom pytree registration for LbfgsInvHessProduct result Fix issue with undefined class Add docstring Remove drepecated comments Fix scipy.optimize module not found LbfgsInvHessProductPyTree constructor now compliant with JAX
commit 040c8fc Author: Yash Katariya <yashkatariya@google.com> Date: Tue Feb 21 15:24:39 2023 -0800 Internal change PiperOrigin-RevId: 511320677
commit a51d5ed Merge: 52d56ab f65001b Author: JAXopt authors <no-reply@google.com> Date: Fri Feb 17 04:34:13 2023 -0800 Merge pull request google#398 from mblondel:add_isotonic_module PiperOrigin-RevId: 510398823
commit 52d56ab Merge: 0472831 6e6a0ab Author: JAXopt authors <no-reply@google.com> Date: Fri Feb 17 01:08:50 2023 -0800 Merge pull request google#397 from mblondel:remove_matplotlib PiperOrigin-RevId: 510364159
commit f65001b Author: Mathieu Blondel <mblondel@google.com> Date: Thu Feb 16 19:49:06 2023 +0100 Add isotonic module.
commit 6e6a0ab Author: Mathieu Blondel <mblondel@google.com> Date: Thu Feb 16 19:34:24 2023 +0100 Update requirements.
commit 0472831 Author: Peter Hawkins <phawkins@google.com> Date: Thu Feb 9 09:03:52 2023 -0800 Internal change PiperOrigin-RevId: 508389531
commit e1d8355 Merge: 0c8b25b 730b5a6 Author: JAXopt authors <no-reply@google.com> Date: Thu Feb 9 07:43:10 2023 -0800 Merge pull request google#394 from mblondel:release_0.6 PiperOrigin-RevId: 508371158
commit 730b5a6 Author: Mathieu Blondel <mblondel@google.com> Date: Thu Feb 9 16:02:32 2023 +0100 Release v0.6.