Commit

Address most comments
mikaylagawarecki committed Oct 25, 2024
1 parent 8c7ec76 commit 09642b4
Showing 3 changed files with 49 additions and 32 deletions.
5 changes: 4 additions & 1 deletion .jenkins/metadata.json
@@ -33,7 +33,7 @@
},
"recipes_source/torch_export_aoti_python.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
},
"advanced_source/pendulum.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu",
"_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
@@ -58,6 +58,9 @@
"intermediate_source/scaled_dot_product_attention_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
"intermediate_source/transformer_building_blocks.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
"recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": {
"needs": "linux.g5.4xlarge.nvidia.gpu"
},
2 changes: 1 addition & 1 deletion index.rst
@@ -667,7 +667,7 @@ Welcome to PyTorch Tutorials

.. customcarditem::
:header: [Title TBD] Unbundling nn.Transformer modules for gains and profits
:card_description: This tutorial goes over recommended best practices for implementing Transformers.
:card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch.
:image: _static/img/thumbnails/cropped/pytorch-logo.png
:link: intermediate/transformer_building_blocks.html
:tags: Transformer
74 changes: 44 additions & 30 deletions intermediate_source/transformer_building_blocks.py
@@ -1,6 +1,6 @@
"""
[Title TBD] Unbundling nn.Transformer modules for gains and profits
===================================================================
Dismantling the ``nn.Transformer`` modules for gains and profits
======================================================================
**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_
The ``torch.nn`` module currently provides various ``Transformer``-related layers.
@@ -11,7 +11,7 @@
were made to try to make these layers more flexible.
While historically these layers were intended to provide out-of-the-box, performant
solutions. We make the observations that
solutions, we make the observations that
1. People want to add slight customizations to their transformer layers
2. Writing these layers and customizations is not hard
@@ -21,7 +21,7 @@
own performant transformer layers following our recommended best practices.
The technologies used will be the following
1. Nested Tensors with the ``torch.jagged`` layout
1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
2. ``scaled_dot_product_attention``
3. ``torch.compile()``
4. ``FlexAttention``
@@ -31,8 +31,13 @@
If you are looking for an out-of-the-box implementation of a popular transformer
architecture, note that there are many open-source libraries that provide them,
with some examples being HuggingFace transformers and torchtune. Please head
there instead!
with some examples being:
* `HuggingFace transformers <https://github.com/huggingface/transformers>`_
* `xformers <https://github.com/facebookresearch/xformers>`_
* `torchtune <https://github.com/pytorch/torchtune>`_
Please head there instead!
If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://flexattention.com/blog/>`_ that
@@ -50,52 +55,56 @@
* `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
Nested tensors generalize the shape of regular dense tensors, allowing for
representation of ragged-sized data. In the context of transformers,
we can think of nested tensors as a tool for representing variable sequence
lengths. They eliminate the need for the bug-prone practices of explicit
representation of ragged-sized data with the same tensor UX. In the context of
transformers, we can think of nested tensors as a tool for representing variable
sequence lengths. They eliminate the need for the bug-prone practices of explicit
padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
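
A minimal sketch of constructing such a nested tensor (the shapes below are
made up purely for illustration):

.. code-block:: python

    import torch

    # Three sequences of different lengths, each with embedding dimension 8
    sequences = [torch.randn(seq_len, 8) for seq_len in (3, 5, 2)]

    # Pack them into a single nested tensor with the jagged (NJT) layout;
    # no padding and no key_padding_mask are ever materialized
    nt = torch.nested.nested_tensor(sequences, layout=torch.jagged)
    print(nt.is_nested)  # True; the sequence dimension is ragged
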
* `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
``scaled_dot_product_attention`` is a primitive for
$\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
implementations of the operator or a fallback implementation. It works out of
the box in eager mode and also integrates seamlessly with compile.
As of 2.6, it will also offer grouped query attention natively.
the box in eager mode (i.e. the default mode of using PyTorch where operations
are executed on the fly as they are encountered) and also integrates seamlessly
with ``torch.compile()``. As of 2.6, it will also offer grouped query attention
natively.
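
For reference, a bare-bones call looks like the following (dimensions chosen
arbitrarily; a fused kernel is selected automatically when one is available
for the given inputs and device):

.. code-block:: python

    import torch
    import torch.nn.functional as F

    batch, n_heads, seq_len, head_dim = 2, 4, 128, 64
    query = torch.randn(batch, n_heads, seq_len, head_dim)
    key = torch.randn(batch, n_heads, seq_len, head_dim)
    value = torch.randn(batch, n_heads, seq_len, head_dim)

    # Causal attention over (batch, heads, seq_len, head_dim) inputs
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
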
* `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
``torch.compile()`` is a compiler introduced in version 2.0 that is able to
fuse together sequences of ops. Nested tensors with the ``torch.jagged`` layout
capture a graph of PyTorch code and perform various optimizations on it, such as
fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
and ``scaled_dot_product_attention`` work seamlessly with compile. In the
context of transformers, the value add of using compile with nested tensor
and sdpa is that compile can remove framework overhead ones sees in eager mode
and SDPA is that compile can remove the framework overhead one sees in eager mode
and fuse sequences of ops in transformers together (e.g. projection and
activation).
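
As a small illustration of the kind of fusion referred to above (a toy model,
not taken from this tutorial):

.. code-block:: python

    import torch
    import torch.nn as nn

    # Projection followed by activation -- a pattern compile can fuse
    mlp = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
    compiled_mlp = torch.compile(mlp)

    x = torch.randn(32, 64)
    out = compiled_mlp(x)  # first call compiles; later calls reuse the compiled graph
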
* `FlexAttention <https://pytorch.org/blog/flexattention/>`_
``FlexAttention`` is a primitive that allows users to modify attention scores
prior to the softmax operation. It generalizes the additive ``B`` term above
for `scaled_dot_product_attention` into allowing you to do any op. It requires
compile to achieve good performance.
for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
requires compile to achieve good performance.
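
A minimal sketch of a score modification (the relative-position bias here is
made up purely for illustration):

.. code-block:: python

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def relative_bias(score, b, h, q_idx, kv_idx):
        # Adjust each (query, key) attention score before the softmax
        return score + 0.1 * (q_idx - kv_idx)

    q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))

    # Eager fallback shown here; wrap flex_attention in torch.compile()
    # to get the fused kernel and good performance
    out = flex_attention(q, k, v, score_mod=relative_bias)
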
The above building blocks are "All You Need" (as of October 2024)
==================================================================
The main premise in this section is that most transformers these days are
The main premise in this section is that most transformer variations are
GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
Blocks and Feed Forward networks. If we were to try to classify the differences
in this space we might land on something like
in this space, we might land on something like:
1. Layer type (activation functions e.g. SwiGLU, normalization functions
e.g. RMSNorm etc., positional encodings e.g. Sinusoidal, Rotary etc.)
1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
3. Modifications to attention score (ALiBi, Relative Positional Bias etc.)
3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)
In a pre-compiler world, one might write their custom transformer and observe
that it works but is slow. Then, one might write a custom fused kernel for
series of ops. In a compiler world, one can do the former, compile and profit.
the specific series of ops. In a compiler world, one can do the former, compile
and profit.
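
As one concrete instance of the "layer type" bucket above, here is a generic
``SwiGLU`` feed-forward block (a common variation on the standard MLP; this
sketch is not tied to any code in this tutorial):

.. code-block:: python

    import torch.nn as nn
    import torch.nn.functional as F

    class SwiGLUFFN(nn.Module):
        def __init__(self, dim, hidden_dim):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

        def forward(self, x):
            # silu(w1(x)) gates w3(x), then project back down to dim
            return self.w2(F.silu(self.w1(x)) * self.w3(x))
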
"""

@@ -125,7 +134,8 @@
# intermediate activations will use less memory.
#
# * Performance
# Since unnecessary computation on padding is skipped, performance improves.
# Since padding is not materialized and unnecessary computation on padding is
# skipped, performance and memory usage improve.
#
# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
@@ -403,10 +413,11 @@ def benchmark(func, *args, **kwargs):
##################################################################################
# GPT-style layer
# ---------------
# A basic GPT-style transformer layer consistst of a causal self-attention layer
# A basic GPT-style transformer layer consists of a causal self-attention layer
# followed by a feed-forward network (FFN) with skip connections. Implementing
# this is fairly straightforward using the ``MultiheadAttention`` layer above and
# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
# gives equivalent results to an ``nn.TransformerEncoderLayer`` with
# ``is_causal=True``.
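#
# As a rough outline (module names here are hypothetical; the tutorial's actual
# definitions are omitted from this diff), one common pre-norm arrangement looks
# like the following, reusing the ``MultiheadAttention`` layer built above:
#
# .. code-block:: python
#
#     class GPTLayer(nn.Module):
#         def __init__(self, attn, d_model, d_ff):
#             super().__init__()
#             self.attn = attn  # e.g. the MultiheadAttention defined earlier
#             self.norm1 = nn.LayerNorm(d_model)
#             self.norm2 = nn.LayerNorm(d_model)
#             self.ffn = nn.Sequential(
#                 nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
#             )
#
#         def forward(self, x):
#             # causal self-attention with a skip connection
#             h = self.norm1(x)
#             x = x + self.attn(h, h, h, is_causal=True)
#             # feed-forward network with a skip connection
#             return x + self.ffn(self.norm2(x))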

# We demonstrate examples of implementing the rest of the nn layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
@@ -418,7 +429,7 @@ def benchmark(func, *args, **kwargs):
# So far, we have demonstrated how to implement a performant ``MultiheadAttention``
# layer that follows the traditional ``nn.MultiheadAttention``. Going back to our
# classification of modifications to the transformer architecture, recall that we
# classified the modifications into layer type, layer ordering and modifications
# classified the modifications into layer type, layer ordering, and modifications
# to the attention score. We trust that changing layer type and layer ordering
# (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
#
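# For instance, swapping ``nn.LayerNorm`` for ``nn.RMSNorm`` (available in recent
# PyTorch releases) is a drop-in change in the layer definition -- a sketch:
#
# .. code-block:: python
#
#     import torch.nn as nn
#
#     d_model = 512
#     norm = nn.LayerNorm(d_model)  # before
#     norm = nn.RMSNorm(d_model)    # after: same call sites, different normalization
#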
@@ -570,20 +581,23 @@ def forward(self, x):
print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
out = new_mha_layer(query, key, value, is_causal=False)

# TODO: anything else I can add here?

################################################################################
# Fully masked rows no longer cause NaNs
# --------------------------------------
#
# There has been a long standing issue with ``nn.MultiheadAttention`` and
# ``scaled_dot_product_attention`` where if a row was fully masked, the output
# ``scaled_dot_product_attention`` where if a row was fully masked out, the output
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
# This is because the softmax operation would divide by zero.
# This is because the softmax over an empty set is undefined.
#
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
# this is no longer the case. Instead, fully masked rows will be set to zero. More
# motivation can be found in the PR description.
# this is no longer the case. Instead, fully masked rows in ``scaled_dot_product_attention``
# are now set to zero in the output. For cases where ``nn.MHA`` does not employ the
# "fast-path", this will also apply.
#
# Using a custom MHA layer with NJTs is strongly recommended over the
# existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness
# appropriately makes it possible to distinguish when there is an empty sequence.
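#
# A quick way to check this on your own build (whether you see zeros or NaNs for
# the first row depends on whether your version includes the change above):
#
# .. code-block:: python
#
#     import torch
#     import torch.nn.functional as F
#
#     q, k, v = (torch.randn(1, 1, 2, 8) for _ in range(3))
#
#     # Boolean mask: the first query row attends to *no* keys at all
#     mask = torch.tensor([[False, False], [True, True]]).reshape(1, 1, 2, 2)
#
#     out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
#     print(out[0, 0, 0])  # zeros on recent versions, NaNs on older ones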


################################################################################
