
Transformer building blocks tutorial #3075

Merged

merged 25 commits into pytorch:main on Nov 14, 2024

Conversation

Contributor

@mikaylagawarecki mikaylagawarecki commented Oct 4, 2024

Description

This adds the tutorial for transformer building blocks, following the outline discussed in nn/optim triage on Friday (9/27/24): https://docs.google.com/document/d/1TMrd0bDiM9-lcFHi079edkMRP1Ux5MTxt4lI1diiAKI/edit

This tutorial also links to a repo https://github.com/mikaylagawarecki/temp which:

  • has examples of implementing the rest of the nn.Transformer-related layers in PyTorch in an NJT-friendly manner (basically, no more *_padding_mask); a minimal sketch of this idea follows below
  • notes some cases that we don't intend to demonstrate (e.g. see here)
  • removes the fast-path logic from MHA/TEL/TE
  • sanity-checks that for MHA/TEL/TDL, across kwargs, new_layer + NJT + compile gives correctness and perf gains over nn.layer + dense + mask + compile, as we expect :) (TE, TD and T are just higher-level wrappers, so we didn't test those)
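
As a quick illustration of the "no more *_padding_mask" point, here is a minimal sketch (not part of the tutorial diff): variable-length sequences live in a jagged NestedTensor (NJT), so attention never sees padding. Shapes and names below are illustrative assumptions, and running it likely needs torch 2.6 / a recent nightly.

import torch
import torch.nn.functional as F

E, n_heads, D = 64, 4, 16          # embed dim = n_heads * head_dim (illustrative)
seq_lens = [3, 7, 5]               # ragged batch, no padding anywhere

# One (seq_len, E) tensor per sequence -> jagged NJT of shape (B, j, E)
x = torch.nested.nested_tensor(
    [torch.randn(s, E) for s in seq_lens], layout=torch.jagged
)

# Reshape to (B, n_heads, j, D) as SDPA expects; q/k/v share the ragged structure
q = k = v = x.unflatten(-1, [n_heads, D]).transpose(1, 2)

# No key_padding_mask needed: the padded positions simply don't exist
# (depending on your build, this may require CUDA and/or a recent nightly)
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # (B, n_heads, j, D), still jagged along the sequence dim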

To run this tutorial correctly, we likely need torch 2.6.

In the future, we can add the following:

  • KV caching: NJT index_put_ + support in torch.compile for mutation of non-contiguous subclass instances (KV caching section); see "Add support for index_put_ in NT" (pytorch#135722)
  • Grouped Query Attention + NJT (not sure if there is a plan for this yet)

Checklist

  • The issue that is being fixed is referenced in the description (see above "Fixes #ISSUE_NUMBER")
  • Only one issue is addressed in this pull request
  • Labels from the issue that this PR is fixing are added to this pull request
  • No unnecessary issues are included in this pull request


pytorch-bot bot commented Oct 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3075

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit f666842 with merge base 24c42d2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@jbschlosser jbschlosser left a comment


(Discussed offline) A more thorough review is coming when the rendered docs are available. Sadly, we require nightly for the NJT stuff to work.

# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.

# We demonstrate examples of implementing the rest of the nn layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
Contributor


I assume the name / hosting for this repo will change before publishing the tutorial?

Contributor Author


Name: yes. Hosting: perhaps we can make that a follow-up -- I don't think we want to host this under pytorch/ yet, as that lends a lot of "official"-ness to this repo, which I don't want it to have just yet.

Contributor

@jbschlosser jbschlosser left a comment


Looking really good! Love how it reads, and I think you did a good job introducing the primitives and covering the background.

I left a bunch of silly editorial nits but nothing too major.

I guess we still need some flex + NJT demonstration once #136792 lands?

# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
# This is because the softmax operation would divide by zero.
#
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
Contributor


Nit: I'm not sure it's within scope of a tutorial to mention specific PRs, but I think it is valuable to say that rolling a custom MHA doesn't run into the same NaN issues as the old nn.MHA, because we're not employing a fused kernel with this problem (i.e. the fastpath case, which I think still exhibits the NaN behavior even after @drisspg's fix).

If you wanted, you could mention that NJT's ability to model raggedness appropriately makes it possible to distinguish when there is an empty sequence

Contributor Author


Hm, I think Driss' PR description is pretty great, so I kind of wanted to link to it. If you feel strongly that tutorials should not link to PRs, I'm happy to remove this though.

Also, let me know whether the rewording in this section sounds reasonable.
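
For anyone following this thread, here is a minimal sketch of the NaN behavior being discussed (illustrative, not taken from the tutorial): with dense tensors plus a padding mask, a fully padded sequence leaves softmax with nothing to normalize over, whereas an NJT simply has zero rows for that sequence, so the situation never arises.

import torch
import torch.nn.functional as F

B, S, E = 2, 4, 8
q = k = v = torch.randn(B, 1, S, E)          # (batch, heads, seq, dim)

# Sequence 1 is entirely padding, so every key position for it is masked out.
key_padding = torch.tensor([[False, False, False, False],
                            [True,  True,  True,  True]])
attn_mask = ~key_padding[:, None, None, :]   # True means "take part in attention"

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[1].isnan().any())  # tensor(True): softmax over an all-masked row divides by zero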

@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review October 31, 2024 22:24
Contributor

@jbschlosser jbschlosser left a comment


Looking really nice, thanks!

value = (
    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
Contributor


nice example here! I think it's also worth showing a block_mask example since it's a little different (there's a new create_nested_block_mask() helper).

Here's one from testing: https://github.com/pytorch/pytorch/blob/b09eb6ed6a22476746d8b7d5f6e464e34f89747a/test/test_nestedtensor.py#L7043-L7051

lmk if you need help with this
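
For reference, roughly what that could look like, adapted from the linked test (a sketch assuming the nightly create_nested_block_mask / flex_attention APIs; exact signatures may differ across nightlies):

import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention

n_heads, D = 4, 16
seq_lens = [5, 9, 7]

# Jagged (B, j, n_heads*D) input; q/k/v must share the same ragged structure
base = torch.nested.nested_tensor(
    [torch.randn(s, n_heads * D) for s in seq_lens], layout=torch.jagged
)
query = key = value = base.unflatten(-1, [n_heads, D]).transpose(1, 2)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# create_nested_block_mask builds a BlockMask that is aware of the jagged lengths
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)
out_causal = flex_attention(query, key, value, block_mask=block_mask)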

Contributor

@jbschlosser jbschlosser left a comment


Thanks for the hard work on this; I think it looks great!

Just got one more suggestion on the Flex + NJT component, but otherwise LGTM :)

Contributor

@svekars svekars left a comment


Looks great, @mikaylagawarecki! Just a few editorial suggestions - let me know if you have any questions.

Contributor

@svekars svekars left a comment


This tutorial looks good to me - should we wait for the 2.6 RC and test against it?

@mikaylagawarecki
Contributor Author

Could we merge first, and I'll update as necessary to make it runnable when the 2.6 RC is available? (I have run it locally and verified that it runs.)

@svekars svekars merged commit 1fcb66e into pytorch:main Nov 14, 2024
20 checks passed
@svekars
Contributor

svekars commented Nov 14, 2024

Merging with the plan to remove the top note from the tutorial and add it back to the build.
