implement_gsam_jax #4
Conversation
Hi, thank you for the contribution. As stated in the README, we normally do not accept external contributions, but we are happy to make an exception for open-source implementations of published projects developed in big_vision. However, according to the codebase principles, project-specific code should not add complexity to the core library parts, such as the main train loop. Thus, standalone projects are expected to fork the main train loop into a project-specific trainer.
Thanks a lot for the clarification! I will re-format and re-submit later according to the examples.
Hey, we now have an example of a project-specific trainer here: https://github.com/google-research/big_vision/tree/main/big_vision/trainers/proj/distill If you are still interested in submitting gsam (we would like it!), could you sync to head and, instead of modifying the core train loop, move your changes into a project-specific trainer like that one? Sorry for the delay on our side!
Thanks a lot for the example! I have moved all changes to the project-specific gsam/train.py.
Thanks for your patience!
Would it be possible to add an example config? Ideally, one which produces some reference run from the paper. It would live in configs/proj/gsam/whatever.py? You would probably fork it off https://github.com/google-research/big_vision/blob/main/big_vision/configs/vit_i1k.py.
Also, in an ideal world, you would actually run this config, show that it matches a number in the paper, and link the result here or at the top of the config. Is that still possible, or is that no longer an option?
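For concreteness, here is a rough skeleton of what such a config could look like, assuming the usual ml_collections-style get_config() used by big_vision configs. The GSAM field names mirror the config.get(...) calls in this PR; the file name and all numeric values are placeholders, not settings from the paper:

```python
import ml_collections


def get_config():
  """Hypothetical GSAM config skeleton; all values are placeholders."""
  config = ml_collections.ConfigDict()

  # Model / data / optimizer / schedule fields would be forked from
  # big_vision/configs/vit_i1k.py here; they are omitted in this sketch.

  # GSAM-specific knobs, matching the config.get(...) calls in this PR.
  config.rho_max = 0.6
  config.rho_min = 0.1
  config.alpha = 0.05
  config.adaptive_perturbation = False
  config.minimize_fp = True
  return config
```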
# limitations under the License.

"""Training loop example.
This is a basic variant of a training loop, good starting point for fancy ones.
You should probably update this to something like "Trainer that implements SAM/GSAM optimizers"?
if config.get("GSAM", False):
  # Get the current learning rate.
  learning_rate = sched_fns_cpu[0](step)
I highly doubt this is what you want. Note that you're calling a function that's been jit'ed onto the CPU from within a function that's pmap'ed onto GPU/TPU, so we have transfer at every single step happening here.
Why not call sched_fn[0](step) instead?
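For illustration, a minimal self-contained sketch of the difference. The schedule below is a stand-in, and sched_fns / sched_fns_cpu only mirror the names used in the big_vision trainer:

```python
import jax
import jax.numpy as jnp

# Stand-in schedule; the real trainer builds these from the config.
sched_fns = [lambda step: 0.001 * jnp.minimum(1.0, step / 1000.0)]

# CPU-jitted copies are meant for host-side use (e.g. logging the current lr).
sched_fns_cpu = [jax.jit(fn, backend="cpu") for fn in sched_fns]

@jax.pmap
def update_fn(step):
  lr = sched_fns[0](step)        # traced into the device computation
  # lr = sched_fns_cpu[0](step)  # as flagged above: would bounce to the CPU every step
  return lr

print(update_fn(jnp.arange(jax.local_device_count())))
```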
return getattr(u, config.get("loss", "sigmoid_xent"))(
    logits=logits, labels=labels)

if config.get("GSAM", False):
Since here we're specifically in the gsam/train.py, we can remove this config variable and if statement, and always execute the GSAM branch.
ALPHA = config.get("alpha", 0.05)
ADAPTIVE_PERTURBATION = config.get("adaptive_perturbation", False)
MINIMIZE_FP = config.get("minimize_fp", True)
Each of these is actually only used exactly once below, so in our code style, we would not assign them to any variable, but just inline them where they are used, see comment below.
    a - g_clean_projection_norm * b, g_clean, g_robust_normalized)

# Get GSAM gradient.
g_gsam = jax.tree_multimap( lambda a, b: a - b * alpha,
There's an awkward space that slipped in front of lambda, please remove.
    a - g_robust_projection_norm * b, g_robust, g_clean_normalized)

# Get GSAM gradient.
g_gsam = jax.tree_multimap( lambda a, b: a + b * alpha,
Same awkward space here.
# Per-worker perturbation.
if adaptive_perturbation:
  param_sam = jax.tree_multimap(lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
        base_opt.target, g_clean)
Misaligned line continuation
        base_opt.target, g_clean)
else:
  param_sam = jax.tree_multimap(lambda a, b: a + sam_rho * b / (g_clean_length + eps),
        base_opt.target, g_clean)
Same here
if lr_max == lr_min:
  sam_rho = rho_max
else:
  sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)
This makes me wonder (sorry, I haven't read the GSAM paper): do you really want to linearly interpolate rho, or would you ideally want to apply the same scheduling function as the learning rate, e.g. cosine?
Sorry for the confusion. I want to apply the same scheduler, but with a different scale / upper and lower bound.
In the paper I only used a linear lr schedule for the experiments; in theory (in the proofs part of the paper), both schedules are assumed to be inverse sqrt.
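In other words, rho is an affine remapping of the current learning rate, so it automatically traces whatever shape the lr schedule has (linear, cosine, inverse sqrt, ...). A tiny self-contained sketch of the interpolation above; the function name and the numeric values are placeholders, not the paper's settings:

```python
def gsam_rho(lr, lr_min, lr_max, rho_min, rho_max):
  """Map the current lr in [lr_min, lr_max] to a rho in [rho_min, rho_max]."""
  if lr_max == lr_min:  # constant lr schedule -> constant rho
    return rho_max
  return rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)

# With a cosine (or any other) lr schedule, rho follows the same curve,
# rescaled to the [rho_min, rho_max] range.
print(gsam_rho(lr=5e-4, lr_min=0.0, lr_max=1e-3, rho_min=0.1, rho_max=0.6))
```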
Also, once you're done, could you squash all the commits into just a single one?
Implement the GSAM algorithm proposed in "Surrogate Gap Minimization Improves Sharpness-Aware Training" (ICLR 2022), which is an improvement over SAM (Sharpness-Aware Minimization).
When config.rho_max == config.rho_min and config.alpha = 0.0, the GSAM algorithm reduces to SAM.
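For reference, a simplified, self-contained sketch of the GSAM gradient described in the review snippets above (single device, no adaptive perturbation, the minimize_fp=True branch). The helper names gsam_step and global_norm are made up for this illustration, not functions from this PR. With alpha = 0 the result is just the gradient at the perturbed weights, and with a constant rho that is exactly SAM:

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
  """L2 norm over all leaves of a pytree of arrays."""
  return jnp.sqrt(sum(jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(tree)))


def gsam_step(loss_fn, params, rho, alpha, eps=1e-12):
  # Clean gradient at the current weights.
  g_clean = jax.grad(loss_fn)(params)
  g_clean_norm = global_norm(g_clean)

  # SAM ascent step: perturb the weights along the normalized clean gradient.
  params_sam = jax.tree_util.tree_map(
      lambda p, g: p + rho * g / (g_clean_norm + eps), params, g_clean)

  # "Robust" gradient at the perturbed weights.
  g_robust = jax.grad(loss_fn)(params_sam)
  g_robust_norm = global_norm(g_robust)

  # Component of g_clean orthogonal to g_robust.
  dot = sum(jnp.vdot(c, r) for c, r in zip(jax.tree_util.tree_leaves(g_clean),
                                           jax.tree_util.tree_leaves(g_robust)))
  proj = dot / (g_robust_norm**2 + eps)
  g_clean_orth = jax.tree_util.tree_map(lambda c, r: c - proj * r, g_clean, g_robust)

  # GSAM gradient: descend on the perturbed loss, subtract the scaled
  # orthogonal component to shrink the surrogate gap.
  return jax.tree_util.tree_map(lambda r, o: r - alpha * o, g_robust, g_clean_orth)


# Toy usage: with alpha=0 this returns the gradient at the perturbed weights,
# i.e. the plain SAM gradient.
params = {"w": jnp.array([1.0, -2.0])}
loss_fn = lambda p: jnp.sum(p["w"] ** 2)
print(gsam_step(loss_fn, params, rho=0.05, alpha=0.0))
```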