diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json
index aa479828d02..6e82d054b4e 100644
--- a/.jenkins/metadata.json
+++ b/.jenkins/metadata.json
@@ -33,7 +33,7 @@
     },
     "recipes_source/torch_export_aoti_python.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
-    },
+    },
     "advanced_source/pendulum.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu",
         "_comment": "need to be here for the compiling_optimizer_lr_scheduler.py to run."
@@ -58,6 +58,9 @@
     "intermediate_source/scaled_dot_product_attention_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
+    "intermediate_source/transformer_building_blocks.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
diff --git a/index.rst b/index.rst
index 668e9cc5924..cfe9c2ead8b 100644
--- a/index.rst
+++ b/index.rst
@@ -667,7 +667,7 @@ Welcome to PyTorch Tutorials

 .. customcarditem::
    :header: [Title TBD] Unbundling nn.Transformer modules for gains and profits
-   :card_description: This tutorial goes over recommended best practices for implementing Transformers.
+   :card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch.
    :image: _static/img/thumbnails/cropped/pytorch-logo.png
    :link: intermediate/transformer_building_blocks.html
    :tags: Transformer
diff --git a/intermediate_source/transformer_building_blocks.py b/intermediate_source/transformer_building_blocks.py
index 586423b14e6..e736b8cef30 100644
--- a/intermediate_source/transformer_building_blocks.py
+++ b/intermediate_source/transformer_building_blocks.py
@@ -1,6 +1,6 @@
 """
-[Title TBD] Unbundling nn.Transformer modules for gains and profits
-===================================================================
+Dismantling the ``nn.Transformer`` modules for gains and profits
+================================================================
 **Author:** `Mikayla Gawarecki `_

 The ``torch.nn`` module currently provides various ``Transformer``-related layers.
@@ -11,7 +11,7 @@
 were made to try to make these layers more flexible.

 While historically these layers intended to provide out-of-the-box, performant
-solutions. We make the observations that
+solutions, we make the following observations:

 1. People want to add slight customizations to their transformer layers
 2. Writing these layers and customizations is not hard
@@ -21,7 +21,7 @@
 own performant transformer layers following our recommended best practices.
 The technologies used will be the following

-1. Nested Tensors with the ``torch.jagged`` layout
+1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
 2. ``scaled_dot_product_attention``
 3. ``torch.compile()``
 4. ``FlexAttention``
@@ -31,8 +31,13 @@

 If you are looking for an out-of-the-box implementation of a popular transformer
 architecture, note that there are many open-source libraries that provide them,
-with some examples being HuggingFace transformers and torchtune. Please head
-there instead!
+with some examples being:
+
+* `HuggingFace transformers `_
+* `xformers `_
+* `torchtune `_
+
+Please head there instead!

 If you are only interested in performant attention score modifications, please
 head to the `FlexAttention blog `_ that
@@ -50,9 +55,9 @@
 * `torch.nested `_

 Nested tensors generalize the shape of regular dense tensors, allowing for
-representation of ragged-sized data. In the context of transformers,
-we can think of nested tensors as a tool for representing variable sequence
-lengths. They eliminate the need for the bug-prone practices of explicit
+representation of ragged-sized data with the same tensor UX. In the context of
+transformers, we can think of nested tensors as a tool for representing variable
+sequence lengths. They eliminate the need for the bug-prone practices of explicit
 padding and masking (think ``key_padding_mask`` in ``nn.MultiHeadAttention``).
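+
+As a minimal, purely illustrative sketch (the shapes below are made up), a
+jagged nested tensor can be built directly from a list of variable-length
+tensors::
+
+    import torch
+
+    # two "sequences" of different lengths, same embedding dimension
+    seqs = [torch.randn(3, 8), torch.randn(5, 8)]
+    njt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
+    print(njt.is_nested, njt.layout)  # True torch.jagged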

 *`scaled_dot_product_attention `_
@@ -60,16 +65,19 @@
 ``scaled_dot_product_attention`` is a primitive for
 $\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V$ that dispatches into either fused
 implementations of the operator or a fallback implementation. It works out of
-the box in eager mode and also integrates seamlessly with compile.
-As of 2.6, it will also offer grouped query attention natively.
+the box in eager mode (i.e. the default mode of using PyTorch where operations
+are executed on the fly as they are encountered) and also integrates seamlessly
+with ``torch.compile()``. As of PyTorch 2.6, it will also offer grouped query
+attention natively.

 * `torch.compile() `_

 ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
-fuse together sequences of ops. Nested tensors with the ``torch.jagged`` layout
+capture a graph of PyTorch code and perform various optimizations on it, such as
+fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
 and ``scaled_dot_product_attention`` work seamlessly with compile. In the
 context of transformers, the value add of using compile with nested tensor
-and sdpa is that compile can remove framework overhead ones sees in eager mode
+and SDPA is that compile can remove the framework overhead one sees in eager mode
 and fuse sequences of ops in transformers together (e.g. projection and
 activation).

@@ -77,25 +85,26 @@

 ``FlexAttention`` is a primitive that allows users to modify attention scores
 prior to the softmax operation. It generalizes the additive ``B`` term above
-for `scaled_dot_product_attention` into allowing you to do any op. It requires
-compile to achieve good performance.
+for ``scaled_dot_product_attention``, allowing for arbitrary calculation. It
+requires compile to achieve good performance.


 The above building blocks are "All You Need" (as of October 2024)
 ==================================================================

-The main premise in this section is that most transformers these days are
+The main premise in this section is that most transformer variations are
 GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
 Blocks and Feed Forward networks. If we were to try to classify the differences
-in this space we might land on something like
+in this space, we might land on something like:

-1. Layer type (activation functions e.g. SwiGLU, normalization functions
-   e.g. RMSNorm etc., positional encodings e.g. Sinusoidal, Rotary etc.)
+1. Layer type (activation functions e.g. ``SwiGLU``, normalization functions
+   e.g. ``RMSNorm`` etc., positional encodings e.g. Sinusoidal, Rotary etc.)
 2. Layer ordering (where to apply norms, where to apply positional encoding etc.)
-3. Modifications to attention score (ALiBi, Relative Positional Bias etc.)
+3. Modifications to attention score (``ALiBi``, Relative Positional Bias etc.)

 In a pre-compiler world, one might write their custom transformer and observe
 that it works but is slow. Then, one might write a custom fused kernel for
-series of ops. In a compiler world, one can do the former, compile and profit.
+the specific series of ops. In a compiler world, one can do the former, compile
+and profit.
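+
+As a purely illustrative sketch (the shapes and names below are invented for
+this example), the "write it, then compile it" workflow looks like::
+
+    import torch
+
+    def naive_block(query, key, value):
+        # dispatches to a fused attention kernel where one is available
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=True
+        )
+
+    compiled_block = torch.compile(naive_block)
+    q = k = v = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
+    out = compiled_block(q, k, v)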

 """
@@ -125,7 +134,8 @@
 # intermediate activations will use less memory.
 #
 # * Performance
-# Since unnecessary computation on padding is skipped, performance improves.
+# Since padding is not materialized and unnecessary computation on padding is
+# skipped, performance and memory usage improve.
 #
 # We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
 # `Nested Tensor tutorial `_
@@ -403,10 +413,11 @@ def benchmark(func, *args, **kwargs):

 ################################################################################
 # GPT-style layer
 # ---------------
-# A basic GPT-style transformer layer consistst of a causal self-attention layer
+# A basic GPT-style transformer layer consists of a causal self-attention layer
 # followed by a feed-forward network (FFN) with skip connections. Implementing
 # this is fairly straightforward using the ``MultiheadAttention`` layer above and
-# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
+# gives equivalent results to an ``nn.TransformerEncoderLayer`` with
+# ``is_causal=True``.
 # We demonstrate examples of implementing the rest of the nn layers
 # `here `_ but omit that from this
@@ -418,7 +429,7 @@ def benchmark(func, *args, **kwargs):
 # So far, we have demonstrated how to implement a performant ``MultiheadAttention``
 # layer that follows the traditional ``nn.MultiheadAttention``. Going back to our
 # classification of modifications to the transformer architecture, recall that we
-# classified the modifications into layer type, layer ordering and modifications
+# classified the modifications into layer type, layer ordering, and modifications
 # to the attention score. We trust that changing layer type and layer ordering
 # (e.g. swapping LayerNorm for RMSNorm) is fairly straightforward.
 #
@@ -570,20 +581,23 @@ def forward(self, x):
 print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
 out = new_mha_layer(query, key, value, is_causal=False)
-# TODO: anything else I can add here?

 ################################################################################
 # Fully masked rows no longer cause NaNs
 # --------------------------------------
 #
 # There has been a long standing issue with ``nn.MultiheadAttention`` and
-# ``scaled_dot_product_attention`` where if a row was fully masked, the output
+# ``scaled_dot_product_attention`` where if a row was fully masked out, the output
 # of the attention layer would be NaN. See `issue `_.
-# This is because the softmax operation would divide by zero.
+# This is because the softmax over an empty set is undefined.
 #
 # Thanks to `this PR `_
-# this is no longer the case. Instead, fully masked rows will be set to zero. More
-# motivation can be found in the PR description.
+# this is no longer the case. Instead, fully masked rows in ``scaled_dot_product_attention``
+# now produce an output of zero. For cases where ``nn.MHA`` does not employ the "fast-path", this also applies.
+#
+# Using a custom MHA layer with NJTs is strongly recommended over the
+# existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness
+# appropriately makes it possible to distinguish when there is an empty sequence.

 ################################################################################
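+
+# As a quick, illustrative check (the tensor shapes and variable names below are
+# arbitrary), a boolean ``attn_mask`` with a fully masked-out query row no longer
+# poisons the output:
+
+demo_q = torch.rand(1, 1, 2, 4)
+demo_k = demo_v = torch.rand(1, 1, 2, 4)
+# In a boolean mask, True means "take part in attention"; the first query row
+# here attends to nothing at all.
+demo_mask = torch.tensor([[False, False], [True, True]])
+demo_out = torch.nn.functional.scaled_dot_product_attention(
+    demo_q, demo_k, demo_v, attn_mask=demo_mask
+)
+print(demo_out[0, 0, 0])  # zeros rather than NaN on builds that include the fix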