Add Wav2Vec2 Adapter Weights to Flax #15521

Closed
wants to merge 34 commits into from
Conversation

sanchit-gandhi
Contributor

What does this PR do?

Fixes #15476

  • Adds an adapter to the Flax Wav2Vec2 model that reduces the time dimension of the extracted feature vectors further than the standard Wav2Vec2 model does. Each encoder output hidden state therefore covers a time window closer to that of a subword token than to that of a single character.
  • The shape and values of the Flax output logits match those of the PyTorch model.
  • The Flax model uses all PyTorch model weights, including those of the adapter. With this change, the script in Add Adapter Weighs to Flax #15476 yields matching results (within a 4e-2 threshold).
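
To make the reduction in the time dimension concrete, here is a minimal sketch of the length arithmetic. It is not the PR's code. It assumes the commonly used Wav2Vec2 feature-encoder kernels and strides, and an adapter of three 1D convolutions with kernel size 3, stride 2 and padding 1; all function names and default values are illustrative assumptions, so check them against the actual model config.

```python
def conv_output_length(length: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Output length of a 1D convolution along the time axis (dilation 1)."""
    return (length + 2 * padding - kernel_size) // stride + 1


def feature_encoder_output_length(
    num_samples: int,
    kernels=(10, 3, 3, 3, 3, 2, 2),  # assumed feature-encoder kernel sizes
    strides=(5, 2, 2, 2, 2, 2, 2),   # assumed feature-encoder strides
) -> int:
    """Number of frames produced by the convolutional feature encoder."""
    length = num_samples
    for kernel_size, stride in zip(kernels, strides):
        length = conv_output_length(length, kernel_size, stride)
    return length


def adapter_output_length(
    num_frames: int,
    num_layers: int = 3,   # assumed number of adapter layers
    kernel_size: int = 3,  # assumed adapter kernel size
    stride: int = 2,       # assumed adapter stride
    padding: int = 1,      # assumed adapter padding
) -> int:
    """Number of frames left after the strided adapter convolutions."""
    length = num_frames
    for _ in range(num_layers):
        length = conv_output_length(length, kernel_size, stride, padding)
    return length


# One second of 16 kHz audio under the assumptions above.
frames = feature_encoder_output_length(16_000)
adapted_frames = adapter_output_length(frames)
print(frames, adapted_frames)
```

Under these assumptions each adapter layer roughly halves the number of frames, so every remaining hidden state summarizes a longer stretch of audio.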

HuggingFaceDocBuilder commented Feb 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@sanchit-gandhi sanchit-gandhi changed the title Flax wav2vec2 Add Wav2Vec2 Adapter Weights to Flax Feb 4, 2022
stas00 and others added 3 commits February 4, 2022 11:15
* Standardize instance segmentation models outputs

* Rename output

* Update src/transformers/modeling_outputs.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add legacy argument to the config and model forward

* Update src/transformers/models/beit/modeling_beit.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Copy fix in Segformer

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* [deepspeed docs] DeepSpeed ZeRO Inference

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* tweak

* deal with black

* extra cleanup, better comments

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Contributor

@patrickvonplaten patrickvonplaten left a comment

Cool PR! Looks more or less good to me. Left some final comments and then I think we can merge :-)

patrickvonplaten and others added 9 commits February 7, 2022 15:35
* [torch_int_div] Correct true division in generation

* up

* up
* First draft

* Add conversion script

* Improve conversion script

* Improve docs and implement tests

* Define model output class

* Fix tests

* Fix more tests

* Add model to README

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply more suggestions from code review

* Apply suggestions from code review

* Rename dims to hidden_sizes

* Fix equivalence test

* Rename gamma to gamma_parameter

* Clean up conversion script

* Add ConvNextFeatureExtractor

* Add corresponding tests

* Implement feature extractor correctly

* Make implementation cleaner

* Add ConvNextStem class

* Improve design

* Update design to also include encoder

* Fix gamma parameter

* Use sample docstrings

* Finish conversion, add center cropping

* Replace nielsr by facebook, make feature extractor tests smaller

* Fix integration test

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Unused import

* Make `has_length()` torch-independent to use in callbacks

* Update src/transformers/trainer_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Single-epoch run

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Infinite dataset

* Trainer fix + distributed benchmark

* Benchmark fix

* unused import

* interleaved splits

* interleaved splits

* has_length util

* Move to research projects

* Leftover Sized checks

* Bump min version

* Unused import

* Revert trainer changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Wav2Vec2 models must either throw or deal with add_apater

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add pre-add_adapter backwards compatibility

* Add pre-add_adapter backwards compatibility

* Fix issue in tests/test_modeling_wav2vec2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* add cross attn to outputs

* add cross attn to outputs for TFLED

* add undo padding

* remove unused import

* fix style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
@sanchit-gandhi sanchit-gandhi deleted the flax-wav2vec2 branch February 7, 2022 16:49
@sanchit-gandhi sanchit-gandhi restored the flax-wav2vec2 branch February 7, 2022 16:50
@sanchit-gandhi sanchit-gandhi reopened this Feb 7, 2022
sanchit-gandhi and others added 17 commits February 7, 2022 18:07
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix outputs

* fix for CTC

* fix doc

* make style

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* 📝 add config section

* 📝 finish first draft

* 📝 add feature extractor and processor

* 🖍 apply feedback from review

* 📝 minor edits

* last review
* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Add the non recoding of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors).

* Comments and making tracing work for gpt-j and xlnet

* Refactore things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone autowrap_function feature usage for later

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
* electra is added to onnx supported model

* add google/electra-base-generator for test onnx module

Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
* use_cache = False for PT models if labels is passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
@patrickvonplaten
Contributor

I think the commit history is messed up here. The best approach is usually to open a new PR and extract just your changes from this one.

Successfully merging this pull request may close these issues.

Add Adapter Weighs to Flax