
Support BatchNorm in Hubert pos_conv_emb as in fairseq #34389

Merged

Conversation

Contributor

@gallilmaimon gallilmaimon commented Oct 24, 2024

What does this PR do?

This PR adds support for BatchNorm instead of weight norm in HubertModel, as in facebookresearch/fairseq@4db2649

The conversion file was also adapted to support conversion from fairseq to HF, and was used to convert the widely used Hubert-base-25hz introduced in https://arxiv.org/abs/2305.13009
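Once converted, enabling the new behaviour in transformers is a single config flag. A minimal sketch (conv_pos_batch_norm is the option introduced by this PR):

from transformers import HubertConfig, HubertModel

# conv_pos_batch_norm=True uses BatchNorm1d instead of weight norm in the
# positional convolutional embedding, matching the fairseq variant
config = HubertConfig(conv_pos_batch_norm=True)
model = HubertModel(config)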

Fixes #34229

We have already uploaded the converted weights to the Hub at https://huggingface.co/slprl/mhubert-base-25hz, which makes it possible to verify that the conversion worked correctly (compared against the original implementation in textlesslib), as follows:

# Asserting that results are identical to the textless original
import torch
import torchaudio
from transformers import HubertModel
from textless.data.speech_encoder import SpeechEncoder

model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans', vocab_size=500, deduplicate=False, need_f0=False)
hf_model = HubertModel.from_pretrained('slprl/mhubert-base-25hz')

wav = torchaudio.load(<WAV_PATH>)[0]

# the textless model's dense features correspond to hidden layer 11 of the HF model
assert torch.allclose(model(wav)['dense'], hf_model(wav, output_hidden_states=True).hidden_states[11])

@ylacombe - I would love your review; specifically, there are several open questions I was wondering about:

  1. This means that HubertPositionalConvEmbedding is no longer a copy of transformers.models.wav2vec2.modeling_wav2vec2. I addressed this by removing the comment; would you prefer I also change wav2vec2?
  2. The conversion script convert_hubert_original_pytorch_checkpoint_to_pytorch.py didn't work (before this change) for the regular hubert-base-ls960h model because of layer-norm naming changes, as discussed in "(False?) warning about weight_g/weight_v missing on WeightNorm on PyTorch" #26796. I didn't fix this because it felt out of scope.
  3. I would love some guidance or help with deepspeed, because I wasn't sure whether any changes were needed to support this.
  4. I also got an error when running make fixup related to a file I haven't changed (src/transformers/models/glm/modeling_glm.py), and I didn't manage to understand why. It also happened when running make fixup on a clean branch with no changes at all, so I would appreciate any help.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ylacombe
@eustlb

Contributor

@ylacombe ylacombe left a comment


Hi @gallilmaimon, thanks for quickly opening this PR!

The integration test you did looks good. Let's make sure to add it to the integration tests in test_modeling_hubert.py!

Once that's done, you can also push an empty commit to trigger the slow-tests CI run: git commit --allow-empty -m "[run-slow] hubert"

To address your questions:

  1. I think it's okay to remove the Copied from statement here.
  2. This is a correct observation. Sorry that I missed your comment about it! Would you like to open another quick PR to correct this?
  3. I think you did the deepspeed integration correctly: it's only applied when using weight_norm.
  4. You might want to rebase your branch on the main transformers branch. If that doesn't work, you can share some logs here so that I can help you!

Let me know if you have further questions !

Comment on lines 274 to 296
if config.conv_pos_batch_norm:
    batch_norm = nn.BatchNorm1d(config.hidden_size)
    self.conv = nn.Sequential(batch_norm, self.conv)
else:
    weight_norm = nn.utils.weight_norm
    if hasattr(nn.utils.parametrizations, "weight_norm"):
        weight_norm = nn.utils.parametrizations.weight_norm

    if is_deepspeed_zero3_enabled():
        import deepspeed

        with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
            self.conv = weight_norm(self.conv, name="weight", dim=2)
        if hasattr(self.conv, "parametrizations"):
            weight_g = self.conv.parametrizations.weight.original0
            weight_v = self.conv.parametrizations.weight.original1
        else:
            weight_g = self.conv.weight_g
            weight_v = self.conv.weight_v
        deepspeed.zero.register_external_parameter(self, weight_v)
        deepspeed.zero.register_external_parameter(self, weight_g)
    else:
        self.conv = weight_norm(self.conv, name="weight", dim=2)
Contributor


I think we'd rather add a self.batch_norm = None if not config.conv_pos_batch_norm else nn.BatchNorm1d(config.hidden_size) that we'd use in the forward pass, rather than using nn.Sequential here
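A minimal sketch of that explicit version (assuming the config and attribute names discussed in this PR; the padding and activation layers are omitted for brevity):

import torch.nn as nn

class HubertPositionalConvEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )
        # explicit attribute instead of wrapping self.conv in nn.Sequential
        self.batch_norm = None if not config.conv_pos_batch_norm else nn.BatchNorm1d(config.hidden_size)

    def forward(self, hidden_states):
        # (batch, time, hidden) -> (batch, hidden, time) for Conv1d/BatchNorm1d
        hidden_states = hidden_states.transpose(1, 2)
        if self.batch_norm is not None:
            hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.conv(hidden_states)
        return hidden_states.transpose(1, 2)

This keeps the batch-norm branch visible in both __init__ and forward, rather than hiding it inside a container module.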

Contributor Author


I felt that the current method was more similar to the weight norm approach (and also to fairseq), but I can change it to your suggestion and update the conversion script as well.

Contributor


In transformers, we prefer to make everything explicit!

@@ -94,6 +94,8 @@ class HubertConfig(PretrainedConfig):
            embeddings layer.
        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
            Number of groups of 1D convolutional positional embeddings layer.
        conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
Contributor


Out of curiosity, why do we specify (for bf16 models)?

Contributor Author


To be honest, I just copied this from the fairseq definition: https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/models/hubert/hubert.py#L197

I can remove this if you prefer.

Contributor


Let's remove it then

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gallilmaimon
Contributor Author

The integration test you did looks good. Let's make sure to add it to the integration tests in test_modeling_hubert.py!

Okay, do you prefer that I do it with the textlesslib dependency in the test itself, or save the output and then just compare the output of HubertModel to the saved results from textlesslib?

  2. This is a correct observation. Sorry that I missed your comment about it! Would you like to open another quick PR to correct this?

Okay, I will open a new one about this separately.

  4. You might want to rebase your branch on the main transformers branch. If that doesn't work, you can share some logs here so that I can help you!

I think I did rebase, but I will try again and let you know.

@ylacombe
Contributor

Okay, do you prefer that I do it with the textlesslib dependency in the test itself, or save the output and then just compare the output of HubertModel to the saved results from textlesslib?

You can take a look at the test modeling file for inspiration; we usually compare a few stats computed from the expected outputs, as well as a small slice extracted from the expected outputs.
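For instance, the new test could follow roughly this shape (a sketch only: the expected values below are placeholders that would be recorded from the textlesslib reference, and the audio loading is simplified):

import torch

def test_inference_hubert_25hz(self):
    model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
    input_values = torch.randn(1, 16000, device=torch_device)  # placeholder; real tests load a fixed sample

    with torch.no_grad():
        # layer 11 matches the dense features exposed by textlesslib
        outputs = model(input_values, output_hidden_states=True).hidden_states[11]

    # compare a small slice and a summary stat against recorded reference values
    expected_outputs_first = torch.zeros(1, 4, 4, device=torch_device)  # placeholder values
    expected_output_sum = 0.0  # placeholder value
    self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
    self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)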

Cyrilvallez and others added 23 commits October 26, 2024 11:47
Add conversion integration test, and make batchnorm explicit variable
@gallilmaimon
Contributor Author

@ylacombe Hey, I think I addressed all of your comments. Let me know if anything else is needed :)

@gallilmaimon
Contributor Author

@ylacombe Hey again, just a gentle reminder about this, as I would be happy to get it integrated as soon as possible. Thanks again for your time and feedback!

@gallilmaimon gallilmaimon requested a review from ylacombe November 3, 2024 13:50
@avishaiElmakies
Contributor

@ylacombe
Sorry to bother you about this, but I would really love for this change to be added.

@gallilmaimon
Contributor Author

@ylacombe - Hey again, just wondering if you have had a chance to go over this so we can integrate this addition. There are several projects I know of that would build on this fix. Thanks again!

Contributor

@ylacombe ylacombe left a comment


Hey @gallilmaimon , really sorry for the late review! Thanks for integrating my comments, it looks good to me now!

Also, thanks for adding the integration tests!

@ylacombe
Contributor

Let's push the empty commit again: git commit --allow-empty -m "[run-slow] hubert". I don't think it has run yet.

@ylacombe
Contributor

cc @ArthurZucker , could you review when you have time?

@gallilmaimon, thanks again for the work on this PR! Excited to try the new checkpoint on downstream tasks. Have you been able to run some benchmarks against other Hubert checkpoints?

@gallilmaimon
Contributor Author

@gallilmaimon, thanks again for the work on this PR! Excited to try the new checkpoint on downstream tasks. Have you been able to run some benchmarks against other Hubert checkpoints?

@ylacombe - My main usage is discretising the representations and using them to train SpeechLMs, and the results there seem as good as expected (similar to the TWIST paper, and notably better than the 50 Hz variant). I will hopefully open-source all of this very soon, once it is ready!

But it might also be interesting to try it for other downstream usages :)

@gallilmaimon
Contributor Author

Hey @ArthurZucker, any chance you have had an opportunity to review this PR? Would really love to integrate this :) Thanks!

@ylacombe
Contributor

ylacombe commented Dec 5, 2024

Requesting @Rocketknight1's review, because @ArthurZucker has limited availability for a few days.

Collaborator

@ArthurZucker ArthurZucker left a comment


It goes a little bit against our philosophy, as usually this would need a new model (because we introduce a new code path)!
We can stray a little bit here, or we could go about this using modular, but that might be overkill!

@@ -943,3 +943,40 @@ def test_inference_distilhubert(self):
        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)

    def test_inference_hubert_25hz(self):
        model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
Collaborator


It would be nice to open a PR to the original repo and use the PR branch revision in the meantime!
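For reference, a Hub PR branch can be loaded by passing its ref as the revision (the PR number below is a placeholder):

from transformers import HubertModel

# refs/pr/<N> points at pull request N on the Hub repository
model = HubertModel.from_pretrained("slprl/mhubert-base-25hz", revision="refs/pr/1")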

@ArthurZucker
Collaborator

My only request is to use an official checkpoint path for the test; otherwise, good job, and sorry for being late on this review!

@gallilmaimon
Contributor Author

My only request is to use an official checkpoint path for the test; otherwise, good job, and sorry for being late on this review!

Hey @ArthurZucker, thanks for the review! I am not sure I understand what you mean by "official checkpoint", as this model was only released as part of fairseq and not HF; that is why we performed the conversion (as part of our academic lab, SLPRL) for community use. We validated that the results are identical, as shown by the test. We of course give full reference and credit in the model card. I am happy to put the weights anywhere needed and merge the PR!

@ylacombe
Contributor

ylacombe commented Dec 10, 2024

Hey @ArthurZucker, the model hasn't been officially released and can only be found by digging deep into the fairseq repositories. Since the model card is quite clean, gives full reference, and is hosted in the organisation of an academic research lab, I believe it should be OK to keep it like this, WDYT?

@gallilmaimon, could you add the license (MIT, I think?) to the model card metadata in the meantime? I've opened a PR to do this, if it's indeed MIT-licensed.

@gallilmaimon
Contributor Author

Hey @ylacombe, I approved your PR, as this is in fact MIT-licensed; I also have a link to the original license in the fairseq GitHub!

@ylacombe
Contributor

I've asked some of the people on the research paper author lists if it would be possible to transfer the checkpoint to the Meta organization. In the meantime, let's merge, thanks for your excellent work!

@ylacombe ylacombe merged commit 6acb4e4 into huggingface:main Dec 10, 2024
20 checks passed
Development

Successfully merging this pull request may close these issues.

Add support for HuBERT batch norm instead of weight norm in pos_conv_emb