
Test newly uploaded Flan-T5 weights #2074

Merged
joecummings merged 7 commits from test-flan-weights into pytorch:main on Feb 27, 2023

Conversation

@joecummings (Contributor):

This PR adds tests for the Flan-T5 weights and confirms that build_hf_checkpoint_from_path works w/ Flan.

@joecummings (Contributor, Author):

This should be cherry-picked into the release as it's just more test coverage.

@joecummings requested review from Nayef211 and rshraga and removed the review request for Nayef211 on February 24, 2023 at 22:32
@@ -122,7 +122,7 @@ The library currently consist of following pre-trained models:
* `DistilRoBERTa <https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/README.md>`_
* XLM-RoBERTa: `Base and Large Architure <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`_
* T5: `Small, Base, Large, 3B, and 11B Architecture <https://github.com/google-research/text-to-text-transfer-transformer>`_
* Flan-T5: `Small, Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_
@joecummings (Author) commented on the diff:

We don't actually support Flan-T5 Small, because its embedding dimension is not divisible by its number of attention heads.
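
For context, a minimal sketch of the constraint (not from this PR), using Flan-T5 Small's published configuration of d_model=512 with 6 attention heads:

# Flan-T5 Small uses d_model=512 with num_heads=6; standard multi-head
# attention requires the embedding dim to split evenly across heads,
# so this configuration is rejected.
d_model, num_heads = 512, 6
if d_model % num_heads != 0:
    raise ValueError(
        f"embed_dim ({d_model}) is not divisible by num_heads ({num_heads})"
    )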

@@ -55,7 +55,6 @@ jobs:
python3 -m pip --quiet install sentencepiece
python3 -m pip --quiet install tqdm
python3 -m pip --quiet install expecttest
-python3 -m pip --quiet install transformers
@joecummings (Author) commented on the diff:

No longer need transformers install for integration tests.

-decoder_attention_mask=self.decoder_padding_mask,
-output_hidden_states=True,
-output_attentions=True,
+model = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=is_encoder_only)
@joecummings (Author) commented on the diff:

The tests above already check the outputs against the HF models, so here we just need to confirm that we can load the weights from an HF checkpoint file and that the model runs.
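
A hedged sketch of such a load-and-run smoke test: build_model_from_huggingface_ckpt is the bundle method exercised in this PR, while the checkpoint path, forward-call signature, and dummy token values are illustrative assumptions.

import torch
from torchtext.models import T5Bundle

# Placeholder path to a locally downloaded Flan-T5 checkpoint (assumption).
model_path = "/path/to/flan_t5_base/"

# The call under test: build a torchtext T5 model from HF weights.
model = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=False)
model.eval()

# Dummy token ids (32100 is T5's vocab size). We only check that a forward
# pass executes; output correctness is covered by the comparison tests above.
tokens = torch.randint(0, 32100, (1, 8))
with torch.no_grad():
    output = model(encoder_tokens=tokens, decoder_tokens=tokens)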

@@ -238,12 +239,12 @@ def build_model_from_huggingface_ckpt(

for i in range(config.num_decoder_layers):
if config.is_gated_act:
t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
@joecummings (Author) commented on the diff:

This was a bug: the loop iterates over decoder layers, but the weights were being written into the encoder's state-dict keys.
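
A hedged reconstruction of the fix: the variable names come from the function shown in the diff, and the HF-side key is assumed from HF's usual T5 naming for gated activations (feed-forward sublayer of decoder block i).

# Inside the loop over decoder layers, the gated-act weight must land in a
# decoder key; the buggy line wrote it into encoder.layers.{i} instead.
for i in range(config.num_decoder_layers):
    if config.is_gated_act:
        t5_model_state_dict[f"decoder.layers.{i}.linear1_0.weight"] = hf_weights[
            f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"  # assumed HF key
        ]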

@joecummings (Author):

I need to confirm that the weights I generated before are correct.

@Nayef211 (Contributor):

> This should be cherry-picked into the release as it's just more test coverage.

Could you create a release tracker issue similar to #1766?

@Nayef211 (Contributor) left a review:

Overall LGTM

our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
), f"Mismatched hidden states for decoder layer {i}"

def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
@Nayef211 (Contributor) commented on the diff:

Why do we get rid of all of these tests?

@joecummings (Author):

We do the actual testing of the weights in the tests above. Here we just want to make sure the model can be loaded and run.
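
For context, a hedged sketch of the layer-by-layer comparison the retained tests perform; the argument names mirror the fragment quoted above, and the tolerance is an assumption.

import torch

def assert_decoder_hidden_states_match(our_output, hf_output, num_layers, atol=1e-4):
    # our_output is torchtext's output dict and hf_output an HF ModelOutput,
    # matching the test fragment above; atol is an assumed tolerance.
    for i in range(num_layers):
        assert torch.allclose(
            our_output["decoder_hidden_states"][i],
            hf_output.decoder_hidden_states[i],
            atol=atol,
        ), f"Mismatched hidden states for decoder layer {i}"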

@joecummings mentioned this pull request on Feb 27, 2023 (8 tasks)
@joecummings merged commit a1dc61b into pytorch:main on Feb 27, 2023
@joecummings deleted the test-flan-weights branch on February 27, 2023 at 19:18
joecummings added a commit that referenced this pull request Feb 27, 2023
* Add tests for loading Flan-T5 weights from HF checkpoints

* Add expected outputs and update tests for Flan

* Add newline at end of file

* pin transformers version for testing

* Simplify test for HF loading

* Fix linting

* Fix integration tests w/ proper download path
joecummings added a commit that referenced this pull request Feb 28, 2023
* Add tests for loading Flan-T5 weights from HF checkpoints

* Add expected outputs and update tests for Flan

* Add newline at end of file

* pin transformers version for testing

* Simplify test for HF loading

* Fix linting

* Fix integration tests w/ proper download path