I wrote this so I can try training some transformers on smaller datasets, hopefully without running into too many generalization issues. I didn't like how SAM polluted the top level of my notebooks whenever I wrote Jupyter cells that fit models in Haiku, and it's functionality I fully intend to use many, many times, so I pulled it out into this repository. I think it could be a part of optax.
Designing a way for it to be consumed was more than a little hairy, so I'd like some feedback before opening a PR that attempts to contribute these changes:
1. Is this functionality beyond the scope of optax?
2. If it is a good candidate to build into optax, should I change the interface?
3. How should I test it? My current method of checking whether it diverges on a toy problem seems inadequate.
It's ambiguous whether SAM is better described as a gradient transform in its own right or as something that "wraps" a gradient transform, so I just made it a transform, with the intent that it be placed in a sequence with `chain`, prior to something like SGD and any associated postprocessing steps, e.g. clipping the gradient (see the source, the demo, and the sketch below).
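Concretely, I picture it being consumed something like this. This is only a sketch of the intended placement: `sam` and its `rho` argument stand in for whatever the transform and its radius hyperparameter end up being called, and the other values are arbitrary.

```python
import optax

# `sam(rho)` stands in for the SAM transform from this repo; the name and
# signature here are illustrative, not a settled interface.
optimizer = optax.chain(
    sam(rho=0.05),                   # SAM's perturbation/ascent step
    optax.clip_by_global_norm(1.0),  # example postprocessing step
    optax.sgd(learning_rate=1e-3),   # the base first-order optimizer
)
```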
That may not be the best choice, particularly since I'm not certain it's consistent with the original, which may feed the output of a first-order optimizer like Adam into the ascent step rather than the raw gradients of the objective with respect to the model parameters as computed by autodifferentiation.
The distinction matters because adaptive methods like Adam do not simply apply the raw gradients during descent, so performing the ascent with gradients derived from their output is a meaningful change. Third-party implementations, like the PyTorch implementation I followed, do not seem to have read the pseudocode in Algorithm 1 of the paper as closing over the first-order optimizer during ascent, but I couldn't tell whether my own repo is consistent with the authors just from their pseudocode, and... I'm kind of bad at reading their Flax. Aside from my toy training script, there is no code that tests it.
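For reference, under the reading I followed, one SAM gradient looks roughly like this; `loss_fn`, `params`, and `rho` are placeholder names, and this is a sketch of Algorithm 1 from the paper, not my actual implementation:

```python
import jax
import optax

def sam_gradient(loss_fn, params, rho):
    """One SAM gradient under the 'raw autodiff' reading of Algorithm 1."""
    grads = jax.grad(loss_fn)(params)
    # Ascent step: normalized first-order gradients straight from autodiff,
    # not the output of Adam or any other transformation.
    scale = rho / (optax.global_norm(grads) + 1e-12)
    perturbed = jax.tree_util.tree_map(lambda p, g: p + scale * g, params, grads)
    # Re-evaluate the gradient at the perturbed point; this is what the base
    # optimizer (SGD, Adam, ...) then consumes in place of the usual gradient.
    return jax.grad(loss_fn)(perturbed)
```

Under the other reading, `grads` would first pass through the base optimizer's transformation (e.g. Adam's preconditioned update) before being normalized into the perturbation, which is exactly the behaviour I can't rule out from the pseudocode.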
Thoughts on how to proceed?