diff --git a/.jenkins/validate_tutorials_built.py b/.jenkins/validate_tutorials_built.py index c3bf4c5534..1c582ab4fc 100644 --- a/.jenkins/validate_tutorials_built.py +++ b/.jenkins/validate_tutorials_built.py @@ -10,7 +10,7 @@ NOT_RUN = [ "beginner_source/basics/intro", # no code - "beginner_source/introyt/introyt_index", # no code + "beginner_source/introyt/introyt_index", # no code "beginner_source/onnx/intro_onnx", "beginner_source/profiler", "beginner_source/saving_loading_models", @@ -25,9 +25,8 @@ "intermediate_source/mnist_train_nas", # used by ax_multiobjective_nas_tutorial.py "intermediate_source/fx_conv_bn_fuser", "intermediate_source/_torch_export_nightly_tutorial", # does not work on release - "intermediate_source/transformer_building_blocks", # does not work on release "advanced_source/super_resolution_with_onnxruntime", - "advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker + "advanced_source/usb_semisup_learn", # fails with CUDA OOM error, should try on a different worker "prototype_source/fx_graph_mode_ptq_dynamic", "prototype_source/vmap_recipe", "prototype_source/torchscript_freezing", @@ -50,10 +49,11 @@ "recipes_source/recipes/Captum_Recipe", "intermediate_source/flask_rest_api_tutorial", "intermediate_source/text_to_speech_with_torchaudio", - "intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release. - "intermediate_source/torch_export_tutorial" # reenable after 2940 is fixed. + "intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release. + "intermediate_source/torch_export_tutorial", # reenable after 2940 is fixed. ] + def tutorial_source_dirs() -> List[Path]: return [ p.relative_to(REPO_ROOT).with_name(p.stem[:-7]) diff --git a/intermediate_source/transformer_building_blocks.py b/intermediate_source/transformer_building_blocks.py index 932be472e8..7d2c67356e 100644 --- a/intermediate_source/transformer_building_blocks.py +++ b/intermediate_source/transformer_building_blocks.py @@ -142,7 +142,7 @@ # that arise due to different sequence lengths within a batch. Since there is # no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice # the outputs appropriately to account for query sequence lengths. ``NestedTensor`` -# cleanly removes the need for this sort of error-prone padding masks. +# cleanly removes the need for this sort of error-prone padding masks. # # * **Memory** # Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]`` @@ -163,6 +163,7 @@ import torch.nn as nn import torch.nn.functional as F + class MultiHeadAttention(nn.Module): """ Computes multi-head attention. Supports nested or padded tensors. @@ -177,6 +178,7 @@ class MultiHeadAttention(nn.Module): dropout (float, optional): Dropout probability. Default: 0.0 bias (bool, optional): Whether to add bias to input projection. 
Default: True """ + def __init__( self, E_q: int, @@ -195,23 +197,25 @@ def __init__( self.dropout = dropout self._qkv_same_embed_dim = E_q == E_k and E_q == E_v if self._qkv_same_embed_dim: - self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs) + self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs) else: - self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) - self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs) - self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs) + self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) + self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs) + self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs) E_out = E_q self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs) assert E_total % nheads == 0, "Embedding dim is not divisible by nheads" self.E_head = E_total // nheads self.bias = bias - def forward(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask=None, - is_causal=False) -> torch.Tensor: + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask=None, + is_causal=False, + ) -> torch.Tensor: """ Forward pass; runs the following process: 1. Apply input projection @@ -235,12 +239,20 @@ def forward(self, result = self.packed_proj(query) query, key, value = torch.chunk(result, 3, dim=-1) else: - q_weight, k_weight, v_weight = torch.chunk(self.packed_proj.weight, 3, dim=0) + q_weight, k_weight, v_weight = torch.chunk( + self.packed_proj.weight, 3, dim=0 + ) if self.bias: - q_bias, k_bias, v_bias = torch.chunk(self.packed_proj.bias, 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk( + self.packed_proj.bias, 3, dim=0 + ) else: q_bias, k_bias, v_bias = None, None, None - query, key, value = F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias) + query, key, value = ( + F.linear(query, q_weight, q_bias), + F.linear(key, k_weight, k_bias), + F.linear(value, v_weight, v_bias), + ) else: query = self.q_proj(query) @@ -259,7 +271,8 @@ def forward(self, # Step 3. Run SDPA # (N, nheads, L_t, E_head) attn_output = F.scaled_dot_product_attention( - query, key, value, dropout_p=self.dropout, is_causal=is_causal) + query, key, value, dropout_p=self.dropout, is_causal=is_causal + ) # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total) attn_output = attn_output.transpose(1, 2).flatten(-2) @@ -280,6 +293,7 @@ def forward(self, import numpy as np + def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor: # generate fake corpus by unigram Zipf distribution # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858 @@ -292,6 +306,7 @@ def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor: word = np.random.zipf(alpha) return torch.tensor(sentence_lengths) + # Generate a batch of semi-realistic data using Zipf distribution for sentence lengths # in the form of nested tensors with the jagged layout. def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False): @@ -302,30 +317,41 @@ def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=Fal # dimension and works with torch.compile. The batch items each have shape (B, S*, D) # where B = batch size, S* = ragged sequence length, and D = embedding dimension. 
if query_seq_len_1: - query = torch.nested.nested_tensor([ - torch.randn(1, E_q, dtype=dtype, device=device) - for l in sentence_lengths - ], layout=torch.jagged) + query = torch.nested.nested_tensor( + [torch.randn(1, E_q, dtype=dtype, device=device) for l in sentence_lengths], + layout=torch.jagged, + ) else: - query = torch.nested.nested_tensor([ - torch.randn(l.item(), E_q, dtype=dtype, device=device) - for l in sentence_lengths - ], layout=torch.jagged) - - key = torch.nested.nested_tensor([ - torch.randn(s.item(), E_k, dtype=dtype, device=device) - for s in sentence_lengths - ], layout=torch.jagged) - - value = torch.nested.nested_tensor([ - torch.randn(s.item(), E_v, dtype=dtype, device=device) - for s in sentence_lengths - ], layout=torch.jagged) + query = torch.nested.nested_tensor( + [ + torch.randn(l.item(), E_q, dtype=dtype, device=device) + for l in sentence_lengths + ], + layout=torch.jagged, + ) + + key = torch.nested.nested_tensor( + [ + torch.randn(s.item(), E_k, dtype=dtype, device=device) + for s in sentence_lengths + ], + layout=torch.jagged, + ) + + value = torch.nested.nested_tensor( + [ + torch.randn(s.item(), E_v, dtype=dtype, device=device) + for s in sentence_lengths + ], + layout=torch.jagged, + ) return query, key, value, sentence_lengths -import timeit + import math +import timeit + def benchmark(func, *args, **kwargs): torch.cuda.synchronize() @@ -336,6 +362,7 @@ def benchmark(func, *args, **kwargs): end = timeit.default_timer() return output, (end - begin), torch.cuda.max_memory_allocated() + ############################################################################## # We will now demonstrate the performance improvements of using nested tensors # in the ``MultiheadAttention`` layer + compile for self attention. We compare this against @@ -347,71 +374,94 @@ def benchmark(func, *args, **kwargs): nheads = 8 dropout = 0.0 bias = True -device='cuda' +device = "cuda" torch.manual_seed(6) query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device) S = sentence_lengths.max().item() -print(f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}") +print( + f"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}" +) padded_query, padded_key, padded_value = ( t.to_padded_tensor(0.0) for t in (query, key, value) ) torch.manual_seed(6) -mha_layer = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device='cuda') +mha_layer = MultiHeadAttention( + E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device="cuda" +) torch.manual_seed(6) -vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda') +vanilla_mha_layer = nn.MultiheadAttention( + E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device="cuda" +) # ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :( -mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach()) -mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach()) +mha_layer.out_proj.weight = nn.Parameter( + vanilla_mha_layer.out_proj.weight.clone().detach() +) +mha_layer.packed_proj.weight = nn.Parameter( + vanilla_mha_layer.in_proj_weight.clone().detach() +) mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach()) -mha_layer.packed_proj.bias = nn.Parameter(vanilla_mha_layer.in_proj_bias.clone().detach()) 
+mha_layer.packed_proj.bias = nn.Parameter( + vanilla_mha_layer.in_proj_bias.clone().detach() +) new_mha_layer = torch.compile(mha_layer) # warmup compile nested_result_warmup = new_mha_layer(query, query, query, is_causal=True) # benchmark -nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, query, query, is_causal=True) +nested_result, nested_time, nested_peak_memory = benchmark( + new_mha_layer, query, query, query, is_causal=True +) padded_nested_result = nested_result.to_padded_tensor(0.0) # For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask`` # Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal`` src_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0] -attn_mask = torch.empty((N, S, S), device=device).fill_(float('-inf')) +attn_mask = torch.empty((N, S, S), device=device).fill_(float("-inf")) for i, s in enumerate(sentence_lengths): attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s) -attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N*nheads, S, S) +attn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S) vanilla_mha_layer = torch.compile(vanilla_mha_layer) # warmup compile -warmup_vanilla_result = vanilla_mha_layer(padded_query, - padded_query, - padded_query, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - need_weights=False, - is_causal=True) +warmup_vanilla_result = vanilla_mha_layer( + padded_query, + padded_query, + padded_query, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + need_weights=False, + is_causal=True, +) # benchmark -(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer, - padded_query, - padded_query, - padded_query, - key_padding_mask=src_key_padding_mask, - need_weights=False, - attn_mask=attn_mask, - is_causal=True) +(padded_result, _), padded_time, padded_peak_memory = benchmark( + vanilla_mha_layer, + padded_query, + padded_query, + padded_query, + key_padding_mask=src_key_padding_mask, + need_weights=False, + attn_mask=attn_mask, + is_causal=True, +) print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB") print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB") -print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item()) +print( + "Max difference between vanilla and nested result", + (padded_result - padded_nested_result).abs().max().item(), +) print(f"Nested speedup: {(padded_time/nested_time):.2f}") -print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB") +print( + f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB" +) ###################################################################################### # For reference, here are some sample outputs on A100: -# +# # .. 
code:: # # padded_time=0.03454, padded_peak_memory=4.14 GB @@ -426,24 +476,54 @@ def benchmark(func, *args, **kwargs): # padding-specific step: remove output projection bias from padded entries for fair comparison padded_result[i, entry_length:, :] = 0.0 -_, padded_bw_time, padded_bw_peak_mem = benchmark(lambda : padded_result.sum().backward()) -_, nested_bw_time, nested_bw_peak_mem = benchmark(lambda : padded_nested_result.sum().backward()) +_, padded_bw_time, padded_bw_peak_mem = benchmark( + lambda: padded_result.sum().backward() +) +_, nested_bw_time, nested_bw_peak_mem = benchmark( + lambda: padded_nested_result.sum().backward() +) print(f"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB") print(f"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB") print(f"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}") -print(f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB") +print( + f"Nested backward peak memory reduction {((padded_bw_peak_mem - nested_bw_peak_mem)/1e9):.2f} GB" +) -print("Difference in out_proj.weight.grad", (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad).abs().max().item()) -print("Difference in packed_proj.weight.grad", (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad).abs().max().item()) -print("Difference in out_proj.bias.grad", (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad).abs().max().item()) -print("Difference in packed_proj.bias.grad", (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad).abs().max().item()) +print( + "Difference in out_proj.weight.grad", + (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad) + .abs() + .max() + .item(), +) +print( + "Difference in packed_proj.weight.grad", + (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad) + .abs() + .max() + .item(), +) +print( + "Difference in out_proj.bias.grad", + (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad) + .abs() + .max() + .item(), +) +print( + "Difference in packed_proj.bias.grad", + (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad) + .abs() + .max() + .item(), +) ################################################################################## # Sample outputs on A100: # # .. code:: -# +# # padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB # nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB # Nested backward speedup: 144.13 @@ -477,10 +557,10 @@ def benchmark(func, *args, **kwargs): # classified the modifications into layer type, layer ordering, and modifications # to the attention score. We trust that changing layer type and layer ordering # (such as swapping ``LayerNorm`` for ``RMSNorm``) is fairly straightforward. 
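As a quick illustration of how small that layer-type swap is, here is a minimal ``RMSNorm`` sketch. This is illustrative only and assumes the ``torch``/``nn`` imports from earlier in the tutorial; recent PyTorch releases also ship ``torch.nn.RMSNorm``, which can be used directly instead of a hand-rolled module.

class RMSNorm(nn.Module):
    # Root-mean-square normalization: scale by 1/RMS(x) along the last dim,
    # then apply a learned per-feature gain. No mean subtraction, no bias.
    def __init__(self, dim, eps=1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight

# Swapping it into a transformer block is then a one-line change, e.g.
# ``self.norm = RMSNorm(E_q)`` in place of ``self.norm = nn.LayerNorm(E_q)``.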
-# +# # In this section, we will discuss various functionalities using the # aforementioned building blocks, including the following: -# +# # * Cross Attention # * Fully masked rows no longer cause NaNs # * Modifying attention score: ALiBi with FlexAttention and NJT @@ -501,8 +581,12 @@ def benchmark(func, *args, **kwargs): query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device) _, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device) -print(f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}") -print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}") +print( + f"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}" +) +print( + f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}" +) out = new_mha_layer(query, key, value, is_causal=False) ######################################################################################## @@ -519,29 +603,40 @@ def benchmark(func, *args, **kwargs): # warmup compile warmup_nested_result = new_mha_layer(query, key, value, is_causal=False) -warmup_vanilla_result = vanilla_mha_layer(padded_query, - padded_key, - padded_value, - key_padding_mask=key_padding_mask, - need_weights=False, - is_causal=False) - -nested_result, nested_time, nested_peak_memory = benchmark(new_mha_layer, query, key, value, is_causal=False) -(padded_result, _), padded_time, padded_peak_memory = benchmark(vanilla_mha_layer, - padded_query, - padded_key, - padded_value, - key_padding_mask=key_padding_mask, - need_weights=False, - is_causal=False) +warmup_vanilla_result = vanilla_mha_layer( + padded_query, + padded_key, + padded_value, + key_padding_mask=key_padding_mask, + need_weights=False, + is_causal=False, +) + +nested_result, nested_time, nested_peak_memory = benchmark( + new_mha_layer, query, key, value, is_causal=False +) +(padded_result, _), padded_time, padded_peak_memory = benchmark( + vanilla_mha_layer, + padded_query, + padded_key, + padded_value, + key_padding_mask=key_padding_mask, + need_weights=False, + is_causal=False, +) padded_nested_result = nested_result.to_padded_tensor(0.0) for i, entry_length in enumerate(q_len): # padding-specific step: remove output projection bias from padded entries for fair comparison padded_result[i, entry_length:, :] = 0.0 -print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item()) +print( + "Max difference between vanilla and nested result", + (padded_result - padded_nested_result).abs().max().item(), +) print(f"Nested speedup: {(padded_time/nested_time):.2f}") -print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB") +print( + f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB" +) ################################################################################## # Sample outputs on A100: @@ -556,15 +651,16 @@ def benchmark(func, *args, **kwargs): ################################################################################ # Fully masked rows no longer cause NaNs # -------------------------------------- -# +# # There has been a long standing issue with ``nn.MultiheadAttention`` and # ``scaled_dot_product_attention`` where if a row was fully masked out, the output # of the attention layer would be NaN. See `issue `_. # This is because the softmax over an empty set is undefined. 
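Concretely, a fully masked row reduces to a softmax over nothing but ``-inf`` scores, and a naive computation of that softmax yields NaN. A two-line repro (illustrative only, not part of the benchmark code above):

fully_masked_scores = torch.full((1, 4), float("-inf"))
print(torch.softmax(fully_masked_scores, dim=-1))  # prints tensor([[nan, nan, nan, nan]])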
-# +# # Thanks to `this PR `_ -# this is no longer the case. Instead, fully masked rows in ``scaled_dot_product_attention``. -# For cases where ``nn.MHA`` does not employ the "fast-path", this will also apply. +# this is no longer the case. Instead, the output corresponding to fully masked rows +# in ``scaled_dot_product_attention`` will be 0. For cases where ``nn.MHA`` does +# not employ the "fast-path", this will also apply. # # Using a custom MHA layer with NJTs is strongly recommended over the # existing "fast-path" in ``nn.MultiheadAttention`` as NJT's ability to model raggedness @@ -583,6 +679,7 @@ def benchmark(func, *args, **kwargs): from torch.nn.attention.flex_attention import flex_attention + def generate_alibi_bias(H: int): """Returns an alibi bias score_mod given the number of heads H Args: @@ -590,22 +687,21 @@ def generate_alibi_bias(H: int): Returns: alibi_bias: alibi bias score_mod """ + def alibi_mod(score, b, h, q_idx, kv_idx): scale = torch.exp2(-((h + 1) * 8.0 / H)) bias = (q_idx - kv_idx) * scale return score + bias + return alibi_mod + query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device) n_heads, D = 8, E_q // 8 alibi_score_mod = generate_alibi_bias(n_heads) -query = ( - query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() -) +query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() -value = ( - value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() -) +value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod) ############################################################################### @@ -619,69 +715,72 @@ def alibi_mod(score, b, h, q_idx, kv_idx): from torch.nn.attention.flex_attention import create_nested_block_mask + def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx + query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device) block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True) -query = ( - query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() -) +query = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() -value = ( - value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() -) +value = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_() out_flex = flex_attention(query, key, value, block_mask=block_mask) ############################################################################### # Packed Projection # ----------------- -# +# # Packed projection is a technique that makes use of the fact that when the input # for projection (matrix multiplications) are the same (self-attention), we can pack the projection # weights and biases into single tensors. It is especially useful when the individual # projections are memory bound rather than compute bound. There are # two examples that we will demonstrate here: -# +# # * Input projection for MultiheadAttention # * SwiGLU activation in feed-forward network of Transformer Layer -# +# # Input projection for MultiheadAttention # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # When doing self-attention, the ``query``, ``key``, and ``value`` -# are the same tensor. Each of these tensors is projected with a +# are the same tensor. 
Each of these tensors is projected with a # ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer, # which is what we do in the MultiheadAttention layer above. -# +# # Let us compare the performance of the packed projection against the usual method: + class InputProjection(nn.Module): def __init__(self, E_q, E_total, bias=False, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs) - def forward(self, x): + def forward(self, x): return self.q_proj(x), self.k_proj(x), self.v_proj(x) + class PackedInputProjection(nn.Module): def __init__(self, E_q, E_total, bias=False, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs) - def forward(self, query): + def forward(self, query): return torch.chunk(self.packed_proj(query), 3, dim=-1) + B, D, dtype = 256, 8192, torch.bfloat16 -torch.set_float32_matmul_precision('high') -in_proj = torch.compile(InputProjection(D, D, device='cuda', dtype=torch.bfloat16)) -packed_in_proj = torch.compile(PackedInputProjection(D, D, device='cuda', dtype=torch.bfloat16)) +torch.set_float32_matmul_precision("high") +in_proj = torch.compile(InputProjection(D, D, device="cuda", dtype=torch.bfloat16)) +packed_in_proj = torch.compile( + PackedInputProjection(D, D, device="cuda", dtype=torch.bfloat16) +) -q, _, _, sequence_lengths = gen_batch(B, D, D, D, device='cuda', dtype=torch.bfloat16) +q, _, _, sequence_lengths = gen_batch(B, D, D, D, device="cuda", dtype=torch.bfloat16) # warmup in_proj(q) @@ -691,7 +790,9 @@ def forward(self, query): (q_out, k_out, v_out), time, _ = benchmark(in_proj, q) (q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q) # On my A100 prints 1.05x speedup -print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x") +print( + f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x" +) ################################################## # SwiGLU feed forward network of Transformer Layer @@ -699,9 +800,18 @@ def forward(self, query): # Swish-Gated Linear Unit (SwiGLU) is a non-linear activation function that is increasingly popular in the feed-forward # network of the transformer layer (e.g. Llama). 
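In functional form, the computation implemented by the ``SwiGLUFFN`` module defined below is ``FFN(x) = W2(SiLU(W1 x) * W3 x)``, where ``SiLU(x) = x * sigmoid(x)``. A small sketch using three ``nn.Linear`` modules (the names ``w1``/``w2``/``w3`` mirror the module below; this helper is illustrative, not part of the tutorial's API):

def swiglu_ffn(x, w1, w2, w3):
    # gate branch: SiLU(w1(x)); value branch: w3(x); combine, then project back to dim
    return w2(F.silu(w1(x)) * w3(x))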
A feed-forward network with SwiGLU activation is defined as: + class SwiGLUFFN(nn.Module): - def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + dim, + hidden_dim, + multiple_of, + ffn_dim_multiplier=None, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() hidden_dim = int(2 * hidden_dim / 3) # custom dim factor multiplier @@ -712,16 +822,26 @@ def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs) self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs) self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs) - + def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) + ######################################################################## -# An alternative way of implementing this that uses packed projection is +# An alternative way of implementing this that uses packed projection is + class PackedSwiGLUFFN(nn.Module): - def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + dim, + hidden_dim, + multiple_of, + ffn_dim_multiplier=None, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() hidden_dim = int(2 * hidden_dim / 3) # custom dim factor multiplier @@ -731,19 +851,22 @@ def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs) self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs) - + def forward(self, x): x1, x3 = torch.chunk(self.w13(x), 2, dim=-1) return self.w2(F.silu(x1) * x3) + ################################################################################ # We can compare the performance of the two implementations as follows # Depending on your hardware, you might see different results. On an A100 I see # 1.12x speedup for D=128. D = 128 -swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device='cuda', dtype=torch.bfloat16)) -packed_swigluffn = torch.compile(PackedSwiGLUFFN(D, D * 4, 256, device='cuda', dtype=torch.bfloat16)) +swigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)) +packed_swigluffn = torch.compile( + PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16) +) q, _, _, sentence_lengths = gen_batch(D, D, D, D, device="cuda", dtype=torch.bfloat16) @@ -755,12 +878,14 @@ def forward(self, x): _, time, _ = benchmark(swigluffn, q) _, time_packed, _ = benchmark(packed_swigluffn, q) # On my A100 prints 1.08x speedup -print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x") +print( + f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x" +) ################################################################################ # Extended examples # ----------------- -# +# # We intend to update this tutorial to demonstrate more examples of how to use # the various performant building blocks such as KV-Caching, Grouped Query Attention # etc. 
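Until those examples land, here is a rough sketch of the KV-caching idea (entirely illustrative; the class name ``SimpleKVCache`` and the decode-step usage are hypothetical, not an API from this tutorial or from PyTorch). During incremental decoding, keys and values for already-processed tokens are cached, and each new token's key/value is appended so that attention at step ``t`` runs over the full cache instead of recomputing the prefix.

class SimpleKVCache:
    # Accumulates per-step keys/values along the sequence dimension.
    def __init__(self):
        self.k = None
        self.v = None

    def update(self, k_new, v_new):
        # k_new, v_new: (N, nheads, 1, E_head) for the newly decoded token
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

# Hypothetical decode step:
# k_cache, v_cache = cache.update(k_t, v_t)
# out_t = F.scaled_dot_product_attention(q_t, k_cache, v_cache)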
Further, there are several good examples of using various performant building blocks to @@ -774,7 +899,7 @@ def forward(self, x): ################################################################################ # Conclusion # ---------- -# +# # In this tutorial, we have introduced the low-level building blocks PyTorch # provides for writing transformer layers and demonstrated examples of how to compose # them. It is our hope that this tutorial has educated the reader on the ease with