
[RFC] Scan & Gradient checkpointing in Flax #17399

Open
patrickvonplaten opened this issue May 24, 2022 · 5 comments · May be fixed by #18341
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments

@patrickvonplaten
Contributor

Feature request

We should add scan and remat (gradient checkpointing) to the most important Flax/JAX models (BERT, GPT2, OPT, T5, BART, Wav2Vec2).

Motivation

Scan allows for much faster compilation and memory savings, and remat is the JAX equivalent of gradient_checkpointing in PyTorch.
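
For background, this is roughly what the two transforms look like in raw Flax - a minimal, illustrative sketch (plain flax.linen usage, not the Transformers API proposed below):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    # scan-compatible signature: (carry, xs) in, (carry, ys) out
    @nn.compact
    def __call__(self, hidden_states, _):
        hidden_states = nn.gelu(nn.Dense(hidden_states.shape[-1])(hidden_states))
        return hidden_states, None

class ScannedStack(nn.Module):
    num_layers: int = 4

    @nn.compact
    def __call__(self, hidden_states):
        # nn.remat recomputes activations in the backward pass (gradient
        # checkpointing); nn.scan traces a single layer and loops it,
        # stacking the per-layer params along a leading axis.
        ScanBlock = nn.scan(
            nn.remat(Block),
            variable_axes={"params": 0},
            split_rngs={"params": True},
            length=self.num_layers,
        )
        hidden_states, _ = ScanBlock()(hidden_states, None)
        return hidden_states

params = ScannedStack().init(jax.random.PRNGKey(0), jnp.ones((2, 16)))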

@sanchit-gandhi already uses both features in the Flax Seq2Seq Speech project - see: https://github.com/sanchit-gandhi/seq2seq-speech - so it'd be quite trivial to get them working.

Implementation details:

Given that both scan and remat are not related to the model architecture, they should IMO not be in the model's config (we made this mistake in PyTorch and don't want to repeat it here).

I would advocate for the following API:

model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased")
model.scan()  # or model.scan_enable()
model.unscan()  # or model.scan_disable()

and

model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased")
model.gradient_checkpoint_enable()
model.gradient_checkpoint_disable()
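
Put together, a training/eval loop under these proposed (hypothetical, not yet merged) methods would read:

model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased")
model.scan_enable()
model.gradient_checkpoint_enable()
# ... train: faster compilation, larger batch sizes ...
model.scan_disable()
model.gradient_checkpoint_disable()
# ... evaluate ...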

As can be seen here: https://github.com/sanchit-gandhi/seq2seq-speech/blob/b28d0c25c8fad0f9ffa6707f91f7aba320d44a4b/models/modeling_flax_wav2vec2.py#L504, we'll need to re-initialize the flax.linen.Module inside the model. However, this should be fine, since it just means that we do

self.module = self.module_class(config=config, dtype=dtype, use_scan=True, **kwargs)
self._is_scan_enabled = True

similar to this line:

module = self.module_class(config=config, dtype=dtype, **kwargs)
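
Concretely, the toggle could look something like the following sketch (assuming the inner module class takes a use_scan flag, as in the snippet above; not a final implementation):

class FlaxPreTrainedModel:  # simplified stand-in for the real base class
    def scan_enable(self):
        # Re-instantiate the inner flax.linen.Module with scan turned on.
        # Flax modules hold no parameters, so self.params is untouched
        # (modulo the scanned vs. unscanned weight layout).
        self.module = self.module_class(config=self.config, dtype=self.dtype, use_scan=True)
        self._is_scan_enabled = True

    def scan_disable(self):
        self.module = self.module_class(config=self.config, dtype=self.dtype, use_scan=False)
        self._is_scan_enabled = False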

We can figure out over the course of the PR how much logic can reside in modeling_flax_utils.py and how much has to go into the specific models, e.g. modeling_flax_wav2vec2.py.

The same API / logic could be used for the gradient_checkpointing.

Your contribution

Happy to give this implementation a shot with @sanchit-gandhi and @patil-suraj.

Would also love to hear feedback from @borisdayma and @marcvanzee on the API.

@patrickvonplaten patrickvonplaten changed the title RFC [RFC] Scan & Gradient checkpointing in Flax May 24, 2022
@borisdayma
Contributor

I'm not sure you would need both versions within the same script (scanned and unscanned, or with and without checkpointing, which affects only training anyway).

Then maybe you could just add it directly as an arg to model.from_pretrained(..., scan=False, gradient_checkpointing=False).

You would just have to use some naming conventions on your params to see if you need to scan/unscan when loading a checkpoint.
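
For illustration, a conversion between the two layouts could look something like this sketch (assumes plain dicts with per-layer keys like layers_0, layers_1, ... vs. one stacked layers entry; not code from this thread):

import jax.numpy as jnp
from jax import tree_util

def stack_layers(params, num_layers, prefix="layers_"):
    # unscanned -> scanned: stack per-layer subtrees along a new leading axis
    layers = [params.pop(f"{prefix}{i}") for i in range(num_layers)]
    params["layers"] = tree_util.tree_map(lambda *xs: jnp.stack(xs), *layers)
    return params

def unstack_layers(params, prefix="layers_"):
    # scanned -> unscanned: slice the leading layer axis back into subtrees
    stacked = params.pop("layers")
    num_layers = tree_util.tree_leaves(stacked)[0].shape[0]
    for i in range(num_layers):
        params[f"{prefix}{i}"] = tree_util.tree_map(lambda x: x[i], stacked)
    return params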

@sanchit-gandhi
Contributor

Suppose you have a training script: wouldn't it be useful to be able to use scan and remat during training for faster compile times and larger batch sizes, and then switch to unscanned weights and no remat during evaluation for faster inference?

@borisdayma
Contributor

I'm not sure it would be worth it:

  • Most of the time evaluation is relatively fast
  • You would have to reformat your parameters each time between eval and train, potentially leading to memory fragmentation

@huggingface huggingface deleted a comment from github-actions bot Jun 27, 2022
@patrickvonplaten patrickvonplaten added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Jun 27, 2022
@KMFODA
Contributor

KMFODA commented Jun 28, 2022

Hey @patrickvonplaten, I'm keen to get gradient checkpointing working in JAX for long-t5. If this is not on the cards to be added soon, I'm happy to work on a PR for it if that works with you all.

@sanchit-gandhi
Contributor

Hey @KMFODA! There's a PR that is close to being merged: #17843. I'll let you know once it's complete, and you can then copy the logic across to Flax T5 in a new PR, if that sounds good to you!

@KMFODA KMFODA mentioned this issue Jul 2, 2022
@sanchit-gandhi sanchit-gandhi linked a pull request Jul 28, 2022 that will close this issue