Transformer building blocks tutorial #3075
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3075
Note: Links to docs will display an error until the docs builds have been completed. ❗ There is 1 currently active SEV; if your PR is affected, please view it below. ✅ No failures as of commit f666842 with merge base 24c42d2. This comment was automatically generated by Dr. CI and updates every 15 minutes.
(discussed offline) more thorough review coming when the rendered docs are available. Sadly, we require nightly for the NJT stuff to work.
# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
#
# We demonstrate examples of implementing the rest of the nn layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
I assume the name / hosting for this repo will change before publishing the tutorial?
Name, yes; hosting, perhaps we can make that a follow-up -- I don't think we want to host this under pytorch/ yet, as that lends a lot of "official"-ness to this repo which I don't want it to have just yet.
Looking really good! Love how it reads, and I think you did a good job introducing the primitives and covering the background.
I left a bunch of silly editorial nits but nothing too major.
I guess we still need some flex + NJT demonstration once #136792 lands?
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
# This is because the softmax operation would divide by zero.
#
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
nit: I'm not sure it's within the scope of a tutorial to mention specific PRs, but I think it is valuable to say that rolling a custom MHA doesn't run into the same NaN issues as the old nn.MHA, because we're not employing a fused kernel with this problem (i.e. the fastpath case, which I think still exhibits the NaN behavior even after @drisspg's fix).
If you wanted, you could also mention that NJT's ability to model raggedness appropriately makes it possible to distinguish when there is an empty sequence.
Hm, I think Driss' PR description is pretty great, so I kind of wanted to link to it. If you feel strongly that tutorials should not link to PRs, I'm happy to remove this though.
Also lmk whether the rewording in this section sounds reasonable.
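(Illustration only, not part of the PR diff: a minimal sketch of the empty-sequence behavior under discussion. The shapes are made up, and constructing a jagged tensor with a zero-length component is assumed to need a recent nightly.)

import torch
import torch.nn.functional as F

# With a padded batch plus a padding mask, an "empty" sequence is a row where
# every attention score is masked to -inf; softmax then normalizes by a zero
# sum and produces NaN.
scores = torch.full((1, 4), float("-inf"))
print(F.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]])

# A jagged (NJT) layout instead models the empty sequence as a component of
# length zero, so there is nothing to normalize and no NaN to propagate.
nt = torch.nested.nested_tensor(
    [torch.randn(0, 8), torch.randn(3, 8)], layout=torch.jagged
)
print([t.shape for t in nt.unbind()])  # [torch.Size([0, 8]), torch.Size([3, 8])]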
Looking really nice, thanks!
value = (
    value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
nice example here! I think it's also worth showing a block_mask example since it's a little different (there's a new create_nested_block_mask() helper).
Here's one from testing: https://github.com/pytorch/pytorch/blob/b09eb6ed6a22476746d8b7d5f6e464e34f89747a/test/test_nestedtensor.py#L7043-L7051
lmk if you need help with this
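(A rough sketch of what that could look like, adapted from the linked test. The n_heads/head_dim values and sequence lengths are made up, the create_nested_block_mask call follows the linked test, and a recent nightly (and possibly CUDA/torch.compile) is assumed.)

import torch
from torch.nn.attention.flex_attention import (
    create_nested_block_mask,
    flex_attention,
)

# Build a jagged (NJT) batch laid out as (B, n_heads, seq*, head_dim),
# mirroring the query/key/value reshaping shown in the diff above.
n_heads, head_dim = 2, 8
sentences = [torch.randn(L, n_heads * head_dim) for L in (3, 5, 2)]
nt = torch.nested.nested_tensor(sentences, layout=torch.jagged)
query = key = value = nt.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)

# mask_mod is written as if over a single sequence; create_nested_block_mask
# turns it into a block mask that respects each component's length.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_nested_block_mask(causal_mask, 1, 1, query)
out_njt = flex_attention(query, key, value, block_mask=block_mask)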
Thanks for the hard work on this; I think it looks great!
Just got one more suggestion on the Flex + NJT component, but otherwise LGTM :)
Looks great, @mikaylagawarecki! Just a few editorial suggestions - let me know if you have any questions.
This tutorial looks good to me - should we wait for the 2.6 RC and test against it?
Could we merge first, and I'll update as necessary to make it runnable when the 2.6 RC is available? (I have run it locally and verified that it runs.)
Merging with the plan to remove the top note from the tutorial and add the tutorial back to the build.
Description
This adds the tutorial for transformer building blocks following the outline discussed in nn/optim triage on Friday (9/27/24) here https://docs.google.com/document/d/1TMrd0bDiM9-lcFHi079edkMRP1Ux5MTxt4lI1diiAKI/edit
This tutorial also links to a repo, https://github.com/mikaylagawarecki/temp, which implements the nn.Transformer-related layers in PyTorch in an NJT-friendly manner (basically no more *_padding_mask arguments). To run this tutorial with correctness, we likely need torch 2.6.
In the future we can add the following: index_put_ + support in torch.compile for mutation of non-contiguous subclass instances (KV caching section); see "Add support for index_put_ in NT", pytorch#135722.