
[Flax] improve large model init and loading #16148

Merged
merged 35 commits into huggingface:main on Apr 19, 2022

Conversation

patil-suraj
Contributor

@patil-suraj patil-suraj commented Mar 14, 2022

What does this PR do?

As discussed in #15766, this PR adds the `_do_init` argument to handle large model loading in Flax. By default `_do_init=True` and the API stays the same.

  1. `__init__`:
  • When `_do_init=False` is passed to `__init__`, the params are not initialised and the user should manually call the `model.init_weights` method to do the initialisation.
  • When `_do_init` is `False`, accessing `model.params` is not allowed and the params must always be kept outside of the model.
  • This PR also adds the `params_shape_tree` property to `FlaxPreTrainedModel`, which is a PyTree with shape and dtype information for each param (see the sketch after the example below).

This is what the API looks like:

import jax
from transformers import BertConfig, FlaxBertModel

config = BertConfig()
model = FlaxBertModel(config, _do_init=False)

# accessing model.params will raise a ValueError
model.params 

# to init the params
params = model.init_weights(model.key, model.input_shape)

# setting model.params will also raise a ValueError
model.params = params

# model.init_weights can be used with `jit` and `pjit` to init the params on CPU or in a sharded way
params = jax.jit(model.init_weights, static_argnums=1, backend="cpu")(model.key, model.input_shape)
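
As a rough illustration of the `params_shape_tree` property mentioned above (a sketch only, not code from this PR; it assumes the leaves expose `.shape` and `.dtype`, e.g. as `jax.ShapeDtypeStruct` objects), the shape PyTree can be inspected without ever allocating the weights:

import math
import jax

# Sketch: count parameters from the shape/dtype PyTree without materialising them.
shape_tree = model.params_shape_tree
num_params = sum(math.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(shape_tree))
print(f"~{num_params / 1e6:.1f}M parameters")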
  2. `from_pretrained`:
  • When `_do_init=False` is passed to `from_pretrained`, the model won't be randomly initialised; only the weights will be loaded and returned along with the model instance.
  • When `_do_init=False`, the params will always be loaded on CPU.
  • If `_do_init=False` and some keys are missing, the user should call `init_weights` and pass it the loaded params. `init_weights` will then take care of adding the missing keys.
  • As described above, getting and setting `model.params` will raise an error.
model, params = FlaxBertModel.from_pretrained("...", _do_init=False)
# if keys are missing
params = model.init_weights(model.key, model.input_shape, params)
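
For completeness, here is a small usage sketch (not from the PR description; it assumes `model` is e.g. the `FlaxBertModel` loaded above): since the params now live outside the model, they are passed explicitly to the forward call via the `params` argument that Flax models accept.

import jax.numpy as jnp

# Sketch: with _do_init=False the params are kept outside the model,
# so they are passed explicitly at call time.
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, params=params)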

cc @borisdayma @sanchit-gandhi for info.

Fixes #15766

@patil-suraj patil-suraj changed the title Flax-do-init [WiP] [Flax] improve large model init and loading Mar 14, 2022
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj patil-suraj changed the title [WiP] [Flax] improve large model init and loading [Flax] improve large model init and loading Mar 16, 2022
Contributor

@borisdayma borisdayma left a comment

Looks great!

src/transformers/modeling_flax_utils.py (outdated)
"`params` cannot be set from model when the model is created with `_do_init=False`. "
"You store the params outside of the model."
)

if isinstance(params, FrozenDict):
params = unfreeze(params)
Contributor

@borisdayma borisdayma Mar 16, 2022

Wondering why we want to unfreeze the params here?
I personally always do model._params = freeze(model.params)
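
For readers following along, a minimal illustration of the freeze/unfreeze round-trip being discussed (using flax.core.frozen_dict; this is a sketch, not code from the PR):

from flax.core.frozen_dict import FrozenDict, freeze, unfreeze

# init_weights may return a FrozenDict or a plain dict, depending on the model
mutable_params = unfreeze(params) if isinstance(params, FrozenDict) else dict(params)
frozen_params = freeze(mutable_params)  # immutable view, as in `model._params = freeze(model.params)`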

@agemagician
Contributor

Hi @patil-suraj,

Thanks a lot for fixing this issue.
Could you please give us an estimate of when this pull request will be ready?

@patil-suraj
Contributor Author

Hey @agemagician !
Thanks. I just need to run and fix a few tests and then it should be good to merge by tomorrow.

@agemagician
Contributor

agemagician commented Mar 23, 2022

> Hey @agemagician ! Thanks. I just need to run and fix a few tests and then it should be good to merge by tomorrow.

Perfect, thanks a lot @patil-suraj for your effort.
It would be awesome if you could update one of the Flax language-modeling examples to show the changes that are needed after this merge:
https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling

Contributor

@patrickvonplaten patrickvonplaten left a comment

Nice! Just two small nits to give the test a better name

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@agemagician
Contributor

Hi @patil-suraj,

Any updates regarding this merge?

@patrickvonplaten
Contributor

@patil-suraj - think we can merge this no?

@patil-suraj
Contributor Author

Just need to fix the template tests and then we can merge it.

@patil-suraj patil-suraj merged commit d3bd9ac into huggingface:main Apr 19, 2022
@patil-suraj patil-suraj deleted the flax-do-init branch April 19, 2022 12:20
stancld added a commit to stancld/transformers that referenced this pull request Apr 19, 2022
patil-suraj pushed a commit to stancld/transformers that referenced this pull request May 26, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* begin do_init

* add params_shape_tree

* raise error if params are accessed when do_init is False

* don't allow do_init=False when keys are missing

* make shape tree a property

* assign self._params at the end

* add test for do_init

* add do_init arg to all flax models

* fix param setting

* disable do_init for composite models

* update test

* add do_init in FlaxBigBirdForMultipleChoice

* better names and errors

* improve test

* style

* add a warning when do_init=False

* remove extra if

* set params after _required_params

* add test for from_pretrained

* do_init => _do_init

* change warning to info

* fix typo

* add params in init_weights

* add params to gpt neo init

* add params to init_weights

* update do_init test

* Trigger CI

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update template

* trigger CI

* style

* style

* fix template

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
patrickvonplaten added a commit that referenced this pull request Jun 13, 2022
* Initial commit

* Make some fixes

* Make PT model full forward pass

* Drop TF & Flax implementation, fix copies etc

* Add Flax model and update some corresponding stuff

* Drop some TF things

* Update config and flax local attn

* Add encoder_attention_type to config

* .

* Update docs

* Do some cleansing

* Fix some issues -> make style; add some docs

* Fix position_bias + mask addition + Update tests

* Fix repo consistency

* Fix model consistency by removing flax operation over attn_mask

* [WIP] Add PT TGlobal LongT5

* .

* [WIP] Add flax tglobal model

* [WIP] Update flax model to use the right attention type in the encoder

* Fix flax tglobal model forward pass

* Make the use of global_relative_attention_bias

* Add test suites for TGlobal model

* Fix minor bugs, clean code

* Fix pt-flax equivalence though not convinced with correctness

* Fix LocalAttn implementation to match the original impl. + update READMEs

* Few updates

* Update: [Flax] improve large model init and loading #16148

* Add ckpt conversion script according to #16853 + handle torch device placement

* Minor updates to conversion script.

* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM

* gpu support + dtype fix

* Apply some suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* * Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies

* Remove caching logic for local & tglobal attention

* Apply another batch of suggestions from code review

* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx

* Fix converting script + revert config file change

* Revert "Remove caching logic for local & tglobal attention"

This reverts commit 2a61982.

* Stash caching logic in Flax model

* Make side relative bias used always

* Drop caching logic in PT model

* Return side bias as it was

* Drop all remaining model parallel logic

* Remove clamp statements

* Move test files to the proper place

* Update docs with new version of hf-doc-builder

* Fix test imports

* Make some minor improvements

* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray

* Fix TGlobal for ONNX conversion + update docs

* fix _make_global_fixed_block_ids and masked neg  value

* update flax model

* style and quality

* fix imports

* remove load_tf_weights_in_longt5 from init and fix copies

* add slow test for TGlobal model

* typo fix

* Drop obsolete is_parallelizable and one warning

* Update __init__ files to fix repo-consistency

* fix pipeline test

* Fix some device placements

* [wip]: Update tests -- need to generate summaries to update expected_summary

* Fix quality

* Update LongT5 model card

* Update (slow) summarization tests

* make style

* rename checkpoints

* finish

* fix flax tests

Co-authored-by: phungvanduy <pvduy23@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jun 16, 2022
Development

Successfully merging this pull request may close these issues.

[Discussion] Loading and initialising large models with Flax
5 participants