implement gsam in jax #8
Conversation
Thanks for addressing the comments. There are a few new ones regarding the config and my understanding of the schedule.
if lr_max == lr_min:
  sam_rho = rho_max
else:
  sam_rho = rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)
From #4:
Lucas:
This makes me wonder (sorry I haven't read the GSAM paper), do you really want to linearly interpolate rho, or would you ideally want to apply the same scheduling function as the learning-rate, e.g. cosine for example?
Juntang:
Sorry for the confusion. I want to apply the same scheduler but with a different scale / upper and lower bound.
In the paper I only used a linear lr scheduler for experiments, and in theory (the proofs part of the paper) the two schedules are both assumed to be inverse sqrt.
Ah, this is really unfortunate; there should be a much cleaner way to implement this, e.g. using a squashed version of the sched_fns from the trainer!
But if you don't want to change the code to do this, then you should put an
assert config.schedule.decay_type == "linear", "GSAM only implemented for linear lr schedule"
into train.py
and add a little comment here in the code that goes something like
# Ideally, we'd use the same schedule as the lr here, just stretched to a different min/max.
# However, here we hard-code the linear scheduler only for convenience.
Hi, sorry I did not explain this clearly. Suppose the learning rate is lr(t) for step t, and there's an effective rho(t) for each step t. The code restricts rho(t) to be linear w.r.t. lr(t); however, rho(t) is not linear w.r.t. t. If we change lr(t) to some non-linear schedule such as cosine, the code here will generate a rho(t) that is also cosine-shaped, except that lr_max != rho_max and lr_min != rho_min.
I tried to use a separate sched_fn for rho(t), but it seems some schedules such as cosine do not have the option to specify a non-zero min value rho_min.
I wonder if you have any suggestions for a neater version using a sched_fn with a configurable min value, or should we keep the schedule code here?
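For reference, a small sketch that just restates the quoted interpolation as a standalone function (names are illustrative): it shows that rho(t) inherits the shape of whatever lr(t) schedule is used, rescaled from [lr_min, lr_max] to [rho_min, rho_max].

def rho_schedule(lr, lr_min, lr_max, rho_min, rho_max):
  # rho follows the shape of the lr schedule, rescaled to [rho_min, rho_max].
  if lr_max == lr_min:  # constant lr -> use the maximal rho
    return rho_max
  return rho_min + (rho_max - rho_min) * (lr - lr_min) / (lr_max - lr_min)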
config.wd = 0.3  # default is 0.0001; paper used 0.3, effective wd=0.3*lr
config.schedule = dict(
    warmup_steps=10_000,
    decay_type='linear',
Maybe append a short inline comment # only linear supported
# config.optax = dict(beta2_cap=0.95)

config.lr = 0.003
config.wd = 0.3  # default is 0.0001; paper used 0.3, effective wd=0.3*lr
If I understand you correctly, this is actually not correct anymore. We changed the code to always use "decoupled" values now. So you should specify here the effective wd you want, which is independent of the lr value (e.g. I think you want 0.001 here, as in 0.3 * 0.003 ≈ 0.001?).
Thanks for pointing it out. Since the old version of the code uses lr * wd as the effective wd, and lr changes with a schedule, the effective wd also has a schedule. Switching to the new configuration, is an effective wd schedule available? I'm concerned that if the effective wd schedule is disabled, using the same hyper-params might not reproduce the results.
Regarding running experiments, I could give it a try at some point, but it's definitely impossible to do so this week. I would run exactly the config you provide, and you need to tell me exactly which number in which table of the paper it's supposed to reproduce.
Thanks a lot! If the effective wd schedule is not figured out, I might need to find some way to either implement the old-version weight decay schedule or tune the hyper-params with the new setting. I wonder if you could point Ting to the docs on how to run this repository internally; I'll submit the code externally, so we could re-run some experiments to reproduce?
Hey, sorry, I got distracted by something urgent to finish; I will get back to this in one of the next two weeks and am optimistic we can get it to work well :) Edit: however, you did not yet tell me which exact number from the paper the config should be reproducing?
Thanks for the response. Sorry about the missing number; it's supposed to reproduce the 76.8 for ViT-B/32 in Table 1 of https://openreview.net/pdf?id=edONMAnhLu- . I'm not fully sure about the new wdecay and lr scheduler. In the old version, the lr scheduler is a single function (here the lr scheduler func seems to be chained with a bunch of other schedulers); in the old version, wdecay is multiplied by lr, so wdecay actually follows a schedule rather than being constant. Is the new wdecay set to a constant?
Hi again, I am now giving it a try, and there were a few more issues remaining. I have written them up as comments, as well as given instructions on how to fix them. I am now able to actually run the trainer and the config, and will train it overnight and see if it already reproduces the result or not.
I'll try a couple of weight decay values to see which is the right one, but FYI, the weight decay still follows the schedule of the lr in the new code (linear decay in this case); it's just that the base lr is not multiplied into it.
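To restate the relationship described above as a small sketch (sched(t) stands for the lr schedule multiplier; the numbers are the ones from this thread, and the function names are made up):

base_lr = 0.003

def effective_wd_old(t, sched, wd=0.3):      # old code: wd coupled to the base lr
  return wd * base_lr * sched(t)

def effective_wd_new(t, sched, wd=0.0009):   # new code: decoupled, but still follows the lr shape
  return wd * sched(t)

# With the new wd set to 0.3 * 0.003 = 0.0009, the two give the same schedule.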
def get_config(arg=None):
  """Config for training."""
  arg = bvcc.parse_arg(arg, variant='B/32', runlocal=False, aug='')
aug='' is not used anymore and should be removed.
This configuration makes use of the "arg" to get_config to select which model
to run, so a few examples are given below:

Run training of a B/16 model:
All of these example commands need to be updated to this config file.
rho_min=0.1,
alpha=0.6,
adaptive_perturbation=False,
minimize_fp=True,
Here we need to add two more parameters:
lr_max=config.get_ref('lr'),
lr_min=config.schedule.get_ref('linear_end'),
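For what it's worth, a sketch of how the resulting config.gsam block inside get_config could look with these two references added (the rho_max value shown is a placeholder assumption, not from this thread; the other fields are the ones quoted above):

config.gsam = dict(
    rho_max=0.6,  # placeholder value, not from this thread
    rho_min=0.1,
    alpha=0.6,
    lr_max=config.get_ref('lr'),
    lr_min=config.schedule.get_ref('linear_end'),
    adaptive_perturbation=False,
    minimize_fp=True,
)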
opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

@partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1))
You need to add , static_broadcasted_argnums=(5,)) here or this will not work: step is a scalar, so we need to tell pmap that, or it expects it to be replicated. So the final line should look like:
@partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1),
         static_broadcasted_argnums=(5,))
def update_fn...
Wait, no, that's not what we should do, or it will recompile a new function every step 😅 Instead, we should indeed replicate the step we're passing, for example by passing flax.jax_utils.replicate(step) at the call site.
However, this creates a synchronization point, blocks prefetching, and creates a transfer at each step. Instead, we should really use the step number that is already replicated inside the optimizer. I'll find out how exactly tomorrow.
Sorry, I lost track of this. What we need to do is to not pass any step at all to the function, but instead get the step like this, around line 208:
step = bv_optax.get_count(opt)
learning_rate = sched_fns[0](step) * config.lr
However, it turns out there's a minor issue with get_count
so that it can't be called inside a compiled function. I have a fix for it, but let's not roll too much into this PR, you could leave this as it is currently, and I'll fix it myself after the PR is merged.
Get the GSAM gradient (https://openreview.net/pdf?id=edONMAnhLu-) of the loss function.
Args:
  loss_fn: the loss function.
  base_opt: the base optimizer.
[1/2] This (base_opt.target used below) does not work anymore with optax. Although it looks like you really use base_opt only for getting to the params, so you can replace the argument by an actual params argument, and then use that everywhere you currently use base_opt.target in this function.
    logits=logits, labels=labels)

learning_rate = sched_fns[0](step)
l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels,
[2/2] and then here you would pass params=params instead of base_opt=opt.
Oh, and you have a bunch of small issues like wrong indentation, trailing spaces, etc. It would be helpful if you could run pylint with this config over it, then I don't need to fix these later on.
And another minor nitpick: could you rename the config from
# Per-worker perturbation.
if adaptive_perturbation:
  param_sam = jax.tree_multimap(lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
jax.tree_multimap does not exist anymore. It's now just jax.tree_map.
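A tiny self-contained sketch of the rename (the pytree arguments and shapes here are made up, since the quoted line is truncated above; the real code applies this to the model params and the clean gradient):

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
g_clean = jax.tree_map(jnp.ones_like, params)
sam_rho, g_clean_length, eps = 0.1, 1.0, 1e-12

# jax.tree_map replaces the removed jax.tree_multimap; it already accepts
# multiple pytrees with matching structure.
param_sam = jax.tree_map(
    lambda a, b: a + jnp.abs(a) * sam_rho * b / (g_clean_length + eps),
    params, g_clean)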
Thanks a lot for the experiments; it seems the config is not correct. I'll discuss it with Ting and see if we can directly compare the config file with the one we used for experiments.
So far, no luck: none of (sigmoid->softmax, head-bias init, ) made it any better. Then, I also tried the following things:
So, I tried all the ideas I had regarding configuration, and at this point wonder if maybe there's a bug in the implementation. Could you please try on your side? Note that you don't need TPU access to run big_vision; it works great on GPUs too, and we did update the README with instructions about that. Let me know when you figure out a setting/code change such that the loss no longer explodes in the first hundreds of steps, and I can then try longer runs for you again. (I'll also ping Ting my runs internally.)
I forgot to mention, but I also tried a run with the Adam 1st momentum not in bfloat16 but in regular float32, and it makes no difference. Note this bfloat16 really just affects the 1st momentum buffer, nothing else.
Thanks a lot for the feedback and experiments. I'll dig into it with Ting and will post the working version here. Sorry for all the trouble with this PR.
No worries, I will be happy and thankful to have up-to-date GSAM and SAM in the codebase!
I also tried to run this with alpha=0, and it looks slightly better at the start, but still explodes after 1-2k steps.
I just noticed in one of your changes a few days ago, you did find a bug:
This looks very promising! So I patched it in and tried another run on top of the last one I mentioned here. It looks a lot better! It doesn't explode, and reaches 75.2/81.8/61.0 validation/real/v2 accuracy after 90 epochs. This is not yet the expected 76.8/82.7/63.0 we're trying to reproduce, but it's getting much closer 🥳 However, the missing 1.6% is still significant, so we should find it before merging this. I carefully compared configs (already before, but once again) and didn't find a new discrepancy.
@lucasb-eyer Thanks so much for running experiments! I'm also running an experiment on ViT-S/32, but it takes much longer on my GPU machine; I will also post results here after it finishes. The results for SAM are copied from https://arxiv.org/abs/2106.01548 Table 2. The gap of 1.6% might come from:
In previous updates, I made a few changes that potentially make a difference, including the following:
(I'm not sure if 4 is necessary, just following my old code after meeting with Ting.) For 1, it's my fault that I did not realize it. For 2 and 3, it's also caused by my mistake with the lr schedule. To reproduce the paper results, the absolute learning rate is a linear decay. I have merged the changes above in the latest PR; let me know if you have time to take a look. I'm also reproducing ViT-S/32 results on my machine; it's a bit slow, but I will post them here once I get results. Thanks again for your help with this!
No need to blame yourself alone; I also should have noticed ALL of these during review and testing, but didn't :) Happy you found them now! Let me start some runs right away, for 300ep, and report back later today. I actually ran all experiments on 8x8, but am curious why TPU topology would influence the results?
I have good news. Running for 300ep largely closes the remaining gap. Here are my results:
setting | wd | val | real | v2 |
---|---|---|---|---|
your paper | 0.0009 | 76.8 | 82.7 | 63.0 |
gsam | 0.0009 | 77.18 | 82.77 | 63.24 |
gsam | 0.001 | 77.35 | 83.04 | 64.03 |
gsam (a=0) | 0.0009 | 76.02 | 81.56 | 62.31 |
sam (a=0, rho=0.15) | 0.0009 | 75.56 | 81.12 | 60.97 |
sam for vit/mixer paper | 0.0009 | 73.6 | 80.3 | 60.0 |
I am relatively sure wd=0.0009 is what you ran, but back then it was expressed differently in our configs, and the number you used was prettier. So I also ran 0.001 which is very close and a pretty number too =)
I only left a few more small comments about the code to address, and after that we can merge!
Note: we have further refactored the code a little bit since, but it is fine for you to submit the code as-is, and I will cleanup/update and test once more on my side afterwards, you've done more than enough already!
rho_min=0.1,
alpha=0.6,
adaptive_perturbation=False,
minimize_fp=True,
Those two (adaptive_perturbation and minimize_fp) are set to their default values. From the doc-comment and paper, it does not seem like something a regular user would tune (contrary to rho and alpha), so let's remove them from the config?
perturbation is element-wise multiplied by abs(p).
minimize_fp: if True, min(f_p, h), original GSAM;
if False, min(f, h), where f is the clean loss.
f_p is the perturbed loss, h is the surrogate gap.
The doc comments of adaptive_perturbation and minimize_fp both explain what they do in very technical terms, but it would be good to have a short high-level recommendation at the end as to when or why one would want to change them.
For example (the example is clearly wrong, because I don't understand them, but just to show the spirit of what I'm looking for):
adaptive_perturbation: if False, same perturbation as SAM,
treat all parameters as a single vector,
perturbation norm is calculated as the norm of the whole vector;
if True, for each parameter tensor p,
perturbation is element-wise multiplied by abs(p).
Try setting this to False when you use least-squares loss instead of KL-based ones.
minimize_fp: if True, min(f_p, h), original GSAM;
if False, min(f, h), where f is the clean loss.
f_p is the perturbed loss, h is the surrogate gap.
You probably want to leave this at its default unless you know what you're doing.
return getattr(u, config.get("loss", "sigmoid_xent"))(
    logits=logits, labels=labels)

learning_rate = sched_fns[0](step) * config["lr"]
Since this is a ConfigDict, it can be the slightly nicer config.lr.
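(For reference, attribute and item access on an ml_collections ConfigDict are equivalent, so this is purely cosmetic:)

import ml_collections

config = ml_collections.ConfigDict({"lr": 0.003})
assert config.lr == config["lr"]  # same value either way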
learning_rate = sched_fns[0](step) * config["lr"]
l, grads = gsam_gradient(loss_fn=loss_fn, params=params, inputs=images,
                         targets=labels, lr=learning_rate, **config["gsam"])
Same here, slightly simpler **config.gsam
Cool, I'm really excited to see the updated results; they outperform the numbers in the paper! One minor thing is that GSAM reduces to SAM requires
For the TPU number, it's because GSAM / SAM performs per-worker perturbation based on the per-worker gradient in
Thanks for your patience overall!
I'll merge it now, and will update the trainer according to the latest refactors early next week, such that it actually works :)
I also just realized that we should add a pointer to this from the README. I'll do so early next week too.
Thanks so much for your help with the debugging and the PR! Regarding the
For per-worker perturbation, the model soup paper seems to contradict the original SAM paper (https://arxiv.org/pdf/2010.01412.pdf, section 4.1). It defines
I'm not quite sure about model soup implementations. In my implementation (and SAM), the process is:
I'm not quite sure about model soup, but I suspect that if it draws an opposite conclusion from the SAM paper, it might come from a different implementation. For example, if it switches the order of 3 and 4 and first performs the per-worker parameter update with the per-worker
If we want to perform synced perturbation, we can add
param_sam is the same for all workers
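In case it helps, a hypothetical sketch of that synced-perturbation variant, assuming it runs inside the trainer's pmap(axis_name="batch") and reusing the g_clean / sam_rho / eps names from the snippets quoted earlier (this is not the code in the PR):

import jax
import jax.numpy as jnp

def synced_perturbation(params, g_clean, sam_rho, eps=1e-12):
  # Average the clean gradient across workers first, so every worker computes
  # the same perturbation.
  g_sync = jax.lax.pmean(g_clean, axis_name="batch")
  g_norm = jnp.sqrt(sum(jnp.vdot(g, g)
                        for g in jax.tree_util.tree_leaves(g_sync)))
  # param_sam is now identical on all workers.
  return jax.tree_map(lambda p, g: p + sam_rho * g / (g_norm + eps),
                      params, g_sync)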
Hi, @lucasb-eyer thanks for your review and comments. I reformatted the files and squashed the commits into a new PR (sorry, I messed up the old PR and could not squash commits there). This PR includes:
- GSAM hyper-parameters in config.gsam, calling gsam with l, grads = gsam_gradient(loss_fn=loss_fn, base_opt=opt, inputs=images, targets=labels, lr=learning_rate, **config["gsam"]).
- In big_vision/configs/proj/gsam/vit_1k_gsam_no_aug.py, the network used in the GSAM paper uses pool_type='gap' and rep_size=False, which is different from the default config (see the sketch after this comment).
Regarding reproducing the experiments, I wonder if it's possible for you to run the script (with 8x8 TPU cores to exactly match the paper)? I'm sorry I don't have access to TPU resources since I'm not affiliated with Google now, so I can't run experiments, though the checkpoints and the old-version code that I used were kept on the server. Thanks so much for your code review and help!
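For context, the model settings mentioned above would look roughly like this in a big_vision ViT config (a sketch only; the exact surrounding structure of vit_1k_gsam_no_aug.py is an assumption):

import ml_collections

config = ml_collections.ConfigDict()
config.model_name = 'vit'
# GSAM paper's ViT: global average pooling, no extra pre-logits ("representation") layer.
config.model = dict(variant='B/32', pool_type='gap', rep_size=False)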