Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 82 additions & 91 deletions intermediate_source/scaled_dot_product_attention_tutorial.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,74 @@
"""
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
==========================================================================================
(Beta) Scaled Dot Product Attention (SDPA)로 고성능 트랜스포머(Transformers) 구현하기
=================================================================================


**Author:** `Driss Guessous <https://github.com/drisspg>`_
**저자:** `Driss Guessous <https://github.com/drisspg>`_
**번역:** `이강희 <https://github.com/khleexv>`_
"""

######################################################################
# Summary
# ~~~~~~~~
# 요약
# ~~~~
#
# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# For detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
# 이 튜토리얼에서, 트랜스포머(Transformer) 아키텍처 구현에 도움이 되는 새로운
# ``torch.nn.functional`` 모듈의 함수를 소개합니다. 이 함수의 이름은 ``torch.nn.functional.scaled_dot_product_attention``
# 입니다. 함수에 대한 자세한 설명은 `PyTorch 문서 <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__
# 를 참고하세요. 이 함수는 이미 ``torch.nn.MultiheadAttention`` 과 ``torch.nn.TransformerEncoderLayer``
# 에서 사용되고 있습니다.
#
# Overview
# ~~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions, a fused implementation can provide
# large performance benefits over a naive implementation.
# 개요
# ~~~~
# 고수준에서, 이 PyTorch 함수는 쿼리(query), 키(key), 값(value) 사이의
# scaled dot product attention (SDPA)을 계산합니다.
# 이 함수의 정의는 `Attention is all you need <https://arxiv.org/abs/1706.03762>`__
# 논문에서 찾을 수 있습니다. 이 함수는 기존 함수를 사용하여 PyTorch로 작성할 수 있지만,
# 퓨즈드(fused) 구현은 단순한 구현보다 큰 성능 이점을 제공할 수 있습니다.
#
# Fused implementations
# 퓨즈드 구현
# ~~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
# 이 함수는 CUDA tensor 입력을 다음 중 하나의 구현을 사용합니다.
#
# 구현:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#
# .. note::
#
# This tutorial requires PyTorch 2.0.0 or later.
# 이 튜토리얼은 PyTorch 버전 2.0.0 이상이 필요합니다.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
# 사용 예시:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)


######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through
# measuring performance.
# 명시적 Dispatcher 제어
# ~~~~~~~~~~~~~~~~~~~~
#
# 이 함수는 암시적으로 세 가지 구현 중 하나를 사용합니다. 하지만 컨텍스트 매니저를
# 사용하면 명시적으로 어떤 구현을 사용할 지 제어할 수 있습니다. 컨텍스트 매니저를 통해
# 특정 구현을 명시적으로 비활성화 할 수 있습니다. 특정 입력에 대한 가장 빠른 구현을 찾고자
# 한다면, 컨텍스트 매니저로 모든 구현의 성능을 측정해볼 수 있습니다.

# Lets define a helpful benchmarking function:
# 벤치마크 함수를 정의합니다
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
return t0.blocked_autorange().mean * 1e6

# Lets define the hyper-parameters of our input
# 입력의 하이퍼파라미터를 정의합니다
batch_size = 32
max_sequence_len = 1024
num_heads = 32
Expand All @@ -85,7 +82,7 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Lets explore the speed of each of the 3 implementations
# 세 가지 구현의 속도를 측정합니다
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
Expand Down Expand Up @@ -114,24 +111,22 @@ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):


######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
# 하드웨어 의존성
# ~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
# - If you don’t have a GPU and are running on CPU then the context manager
# will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports
# flash attention or memory efficient might have failed.
# 위 셀을 어떤 머신에서 실행했는지와 사용 가능한 하드웨어에 따라 결과가 다를 수 있습니다.
# - GPU가 없고 CPU에서 실행 중이라면 컨텍스트 매니저는 효과가 없고 세 가지 실행 모두
# 유사한 시간을 반환할 것입니다.
# - 그래픽 카드가 지원하는 컴퓨팅 능력에 따라 flash attention 또는
# memory efficient 구현이 동작하지 않을 수 있습니다.


######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by
# `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
# 아래는 multi-head causal self attention 블록의 구현 예시입니다.
# `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ 저장소를 참고했습니다.
#

class CausalSelfAttention(nn.Module):
Expand Down Expand Up @@ -187,12 +182,13 @@ def forward(self, x):


#####################################################################
# ``NestedTensor`` and Dense tensor support
# -----------------------------------------
# ``NestedTensor`` Dense tensor 지원
# ------------------------------------
#
# SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about ``NestedTensors`` see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and `NestedTensors Tutorial <https://tutorials.pytorch.kr/prototype/nestedtensor.html>`__.
# SDPA는 ``NestedTensor`` 와 Dense tensor 입력을 모두 지원합니다.
# ``NestedTensors`` 는 입력이 가변 길이 시퀀스로 구성된 배치인 경우에
# 배치 내 시퀀스의 최대 길이에 맞춰 각 시퀀스를 패딩할 필요가 없습니다. ``NestedTensors`` 에 대한 자세한 내용은
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ 와 `NestedTensors 튜토리얼 <https://tutorials.pytorch.kr/prototype/nestedtensor.html>`__ 을 참고하세요.
#

import random
Expand Down Expand Up @@ -236,7 +232,7 @@ def generate_rand_batch(
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
# 현재 퓨즈드 구현은 ``NestedTensor`` 로 학습하는 것을 지원하지 않습니다.
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
Expand All @@ -248,15 +244,14 @@ def generate_rand_batch(


######################################################################
# Using SDPA with ``torch.compile``
# =================================
# ``torch.compile`` 과 함께 SDPA 사용하기
# =====================================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
# PyTorch 2.0 릴리즈와 함께 ``torch.compile()`` 라는 새로운 기능이 추가되었는데,
# 이는 eager mode보다 상당한 성능 향상을 제공할 수 있습니다.
# Scaled dot product attention은 ``torch.compile()`` 로 완전히 구성할 수 있습니다.
# 이를 확인하기 위해 ``torch.compile()`` 을 통해 ``CausalSelfAttention`` 모듈을 컴파일하고
# 결과적으로 얻어지는 성능 향상을 알아봅시다.
#

batch_size = 32
Expand All @@ -276,12 +271,11 @@ def generate_rand_batch(

######################################################################
#
# The exact execution time is dependent on machine, however the results for mine:
# The non compiled module runs in 166.616 microseconds
# The compiled module runs in 166.726 microseconds
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
# 정확한 실행 시간은 환경에 따라 다르지만, 다음은 저자의 결과입니다.
# 컴파일 되지 않은 모듈은 실행에 166.616ms 가 소요되었습니다.
# 컴파일 된 모듈은 실행에 166.726ms 가 소요되었습니다.
# 이는 우리의 예상과는 다릅니다. 좀 더 자세히 알아봅시다.
# PyTorch는 코드의 성능 특성을 점검할 수 있는 놀라운 내장(built-in) 프로파일러를 제공합니다.
#

from torch.profiler import profile, record_function, ProfilerActivity
Expand All @@ -302,7 +296,7 @@ def generate_rand_batch(
compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
# 더 많은 정보를 얻기 위해 추적(trace)를 내보내고 ``chrome://tracing``을 사용하여 결과를 확인해보세요.
# ::
#
# prof.export_chrome_trace("compiled_causal_attention_trace.json").
Expand All @@ -311,33 +305,30 @@ def generate_rand_batch(


######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this here is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausaulSelfAttention``
# is, then the overhead of PyTorch can be hidden.
# 이전 코드 조각(snippet)은 컴파일 된 모듈과 컴파일되지 않은 모듈 모두에 대해
# 가장 많은 GPU 실행 시간을 차지한 상위 10개의 PyTorch 함수에 대한 보고서를 생성합니다.
# 분석 결과, 두 모듈 모두 GPU에서 소요된 시간의 대부분이
# 동일한 함수들에 집중되어 있음을 보여줍니다.
# PyTorch가 프레임워크 오버헤드를 제거하는 데 매우 탁월한 ``torch.compile`` 를
# 제공하기 때문입니다. ``CausalSelfAttention`` 같은 경우처럼 크고, 효율적인 CUDA 커널을
# 사용하는 모델에서 PyTorch 오버헤드는 작아질 것입니다.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from: ``6090.49ms`` to
# ``3273.17ms``! This was done on commit: ``ae3a8d5`` of NanoGPT training on
# the Shakespeare dataset.
# 사실, 모듈은 보통 ``CausalSelfAttention`` 블럭 하나만으로 구성되지 않습니다.
# `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ 저장소에서 실험한 경우,
# 모듈을 컴파일 하는 것은 학습의 각 단계별 소요 시간을 ``6090.49ms`` 에서 ``3273.17ms`` 로
# 줄일 수 있었습니다. 이 실험은 NanoGPT 저장소의 ``ae3a8d5`` 커밋에서 Shakespeare
# 데이터셋을 사용하여 진행되었습니다.
#


######################################################################
# Conclusion
# ==========
# 결론
# ====
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert a certain
# implementation is used on GPU. As well, we built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
# compilable. In the process we have shown how to the profiling tools can
# be used to explore the performance characteristics of a user defined
# module.
# 이 튜토리얼에서, ``torch.nn.functional.scaled_dot_product_attention`` 의 기본적인
# 사용법을 살펴봤습니다. ``sdp_kernel`` 컨텍스트 매니저로 GPU가 특정 구현을
# 사용하도록 할 수 있다는 것을 보았습니다. 또한, 간단한 ``NestedTensor`` 에서 작동하고
# 컴파일 가능한 ``CausalSelfAttention`` 모듈을 만들었습니다.
# 이 과정에서 프로파일링 도구를 사용하여 유저가 정의한 모듈의 성능 특성을 어떻게
# 확인할 수 있는지도 살펴봤습니다.
#