"""
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
==========================================================================================


**Author:** `Driss Guessous <https://github.com/drisspg>`_
**Translator:** `이강희 <https://github.com/khleexv>`_

"""

######################################################################
# Summary
# ~~~~~~~~
#
# In this tutorial, we want to highlight a new ``torch.nn.functional`` function
# that can be helpful for implementing transformer architectures. The
# function is named ``torch.nn.functional.scaled_dot_product_attention``.
# For a detailed description of the function, see the `PyTorch documentation <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention>`__.
# This function has already been incorporated into ``torch.nn.MultiheadAttention`` and ``torch.nn.TransformerEncoderLayer``.
#
# Overview
# ~~~~~~~~~
# At a high level, this PyTorch function calculates the
# scaled dot product attention (SDPA) between query, key, and value according to
# the definition found in the paper `Attention is all you
# need <https://arxiv.org/abs/1706.03762>`__. While this function can
# be written in PyTorch using existing functions (a naive version is sketched
# after the usage example below), a fused implementation can provide
# large performance benefits over a naive implementation.
#
# Fused implementations
# ~~~~~~~~~~~~~~~~~~~~~~
#
# For CUDA tensor inputs, the function will dispatch into one of the following
# implementations:
#
# * `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness <https://arxiv.org/abs/2205.14135>`__
# * `Memory-Efficient Attention <https://github.com/facebookresearch/xformers>`__
# * A PyTorch implementation defined in C++
#
# .. note::
#
#    This tutorial requires PyTorch 2.0.0 or later.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
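
# As noted in the overview, SDPA can also be written with existing PyTorch ops.
# The function below is a naive reference sketch (no masking, dropout, or the
# numerical tricks of the fused kernels), included only to illustrate the math:
import math

def naive_scaled_dot_product_attention(query, key, value):
    # scores have shape (..., L, S); scale by 1/sqrt(E) as in "Attention Is All You Need"
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

# Sanity check against the fused function on the small example above
print(torch.allclose(naive_scaled_dot_product_attention(query, key, value),
                     F.scaled_dot_product_attention(query, key, value), atol=1e-5))
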
######################################################################
# Explicit Dispatcher Control
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# While the function will implicitly dispatch to one of the three
# implementations, the user can also explicitly control the dispatch via
# the use of a context manager. This context manager allows users to
# explicitly disable certain implementations. If a user wants to ensure
# the function is indeed using the fastest implementation for their
# specific inputs, the context manager can be used to sweep through
# measuring performance.

# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
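
# The rest of the input setup is omitted in this excerpt. A plausible completion
# (the embedding size and dtype below are illustrative assumptions, not verified
# against the original) creates the query, key, and value tensors that the
# benchmark below expects:
embed_dimension = 32
dtype = torch.float16
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
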
print (f"The default implementation runs in { benchmark_torch_function_in_microseconds (F .scaled_dot_product_attention , query , key , value ):.3f} microseconds" )
87
84
88
# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
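# The mapper and the per-backend benchmark blocks are omitted in this excerpt.
# A minimal sketch of what such a sweep might look like, assuming the
# ``enable_*`` flags of the PyTorch 2.0 ``sdp_kernel`` context manager:
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {"enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

for backend, flags in backend_map.items():
    with sdp_kernel(**flags):
        try:
            timing = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
            print(f"{backend} runs in {timing:.3f} microseconds")
        except RuntimeError:
            print(f"{backend} is not supported for these inputs on this hardware.")
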
######################################################################
# Hardware dependence
# ~~~~~~~~~~~~~~~~~~~
#
# Depending on what machine you ran the above cell on and what hardware is
# available, your results might be different.
#
# - If you don't have a GPU and are running on CPU then the context manager
#   will have no effect and all three runs should return similar timings.
# - Depending on what compute capability your graphics card supports,
#   flash attention or memory-efficient attention might have failed.
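#
# If you want to check up front what your GPU supports, one option (an
# illustrative sketch, not part of the original code) is to query its compute
# capability:

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # FlashAttention generally targets recent architectures (for example Ampere),
    # so older cards may fall back to the other implementations.
    print(f"GPU compute capability: sm_{major}{minor}")
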
######################################################################
# Causal Self Attention
# ~~~~~~~~~~~~~~~~~~~~~
#
# Below is an example implementation of a multi-headed causal self
# attention block inspired by the
# `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository.
#
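# Before looking at the full module, note the core idea in isolation: causal
# masking is handled by the ``is_causal=True`` flag of SDPA, so no explicit mask
# tensor needs to be built. (A minimal sketch, not the NanoGPT code; the shapes
# in the comment are illustrative.)

def causal_sdpa(q, k, v):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
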
class CausalSelfAttention(nn.Module):
#####################################################################
# ``NestedTensor`` and Dense tensor support
# -----------------------------------------
#
# SDPA supports both ``NestedTensor`` and Dense tensor inputs. ``NestedTensors`` handle the case where the input is a batch of variable length sequences
# without needing to pad each sequence to the maximum length in the batch. For more information about ``NestedTensors`` see
# `torch.nested <https://pytorch.org/docs/stable/nested.html>`__ and the `NestedTensors Tutorial <https://tutorials.pytorch.kr/prototype/nestedtensor.html>`__.
#
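# As a small, self-contained illustration (separate from the ``generate_rand_batch``
# helper used below), a nested tensor can hold sequences of different lengths
# without any padding:

seq_a = torch.randn(5, 64)   # sequence of length 5
seq_b = torch.randn(9, 64)   # sequence of length 9
nt_example = torch.nested.nested_tensor([seq_a, seq_b])
print(nt_example.is_nested)  # True -- no padding to the longer length was needed
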
import random
random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
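    # The body of this block is omitted in this excerpt. A plausible completion
    # (illustrative only, not verified output) benchmarks the model on both batches:
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
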
######################################################################
# Using SDPA with ``torch.compile``
# =================================
#
# With the release of PyTorch 2.0, a new feature called
# ``torch.compile()`` has been introduced, which can provide
# significant performance improvements over eager mode.
# Scaled dot product attention is fully composable with ``torch.compile()``.
# To demonstrate this, let's compile the ``CausalSelfAttention`` module using
# ``torch.compile()`` and observe the resulting performance improvements.
#
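# The smallest way to see this composition (a sketch, separate from the module
# benchmark that follows) is to compile the attention function itself and call
# it on the tensors defined earlier:

compiled_sdpa = torch.compile(F.scaled_dot_product_attention)
out = compiled_sdpa(query, key, value)
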
batch_size = 32
######################################################################
#
# The exact execution time is dependent on machine, however the results for mine:
# The non-compiled module runs in 166.616 microseconds
# The compiled module runs in 166.726 microseconds
# That is not what we were expecting. Let's dig a little deeper.
# PyTorch comes with an amazing built-in profiler that you can use to
# inspect the performance characteristics of your code.
#
from torch.profiler import profile, record_function, ProfilerActivity
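
# The full profiling setup is omitted in this excerpt. A condensed sketch of how
# such a report can be produced (it assumes ``model`` and ``x`` from the omitted
# setup above; the label passed to ``record_function`` is arbitrary):

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Non-Compiled Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# The compiled module is measured the same way; the tail of that block is shown below.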
compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
# ::
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json")
######################################################################
# The previous code snippet generates a report of the top 10 PyTorch functions
# that consumed the most GPU execution time, for both the compiled and non-compiled module.
# The analysis reveals that the majority of time spent on the GPU is concentrated
# on the same set of functions for both modules.
# The reason for this is that ``torch.compile`` is very good at removing the
# framework overhead associated with PyTorch. If your model is launching
# large, efficient CUDA kernels, which in this case ``CausalSelfAttention``
# is, then the overhead of PyTorch can be hidden.
#
# In reality, your module does not normally consist of a singular
# ``CausalSelfAttention`` block. When experimenting with the `Andrej Karpathy NanoGPT <https://github.com/karpathy/nanoGPT>`__ repository, compiling
# the module took the time per train step from ``6090.49ms`` to
# ``3273.17ms``! This was done on commit ``ae3a8d5`` of NanoGPT training on
# the Shakespeare dataset.
#

######################################################################
# Conclusion
# ==========
#
# In this tutorial, we have demonstrated the basic usage of
# ``torch.nn.functional.scaled_dot_product_attention``. We have shown how
# the ``sdp_kernel`` context manager can be used to assert a certain
# implementation is used on GPU. As well, we built a simple
# ``CausalSelfAttention`` module that works with ``NestedTensor`` and is torch
# compilable. In the process we have shown how the profiling tools can
# be used to explore the performance characteristics of a user defined
# module.
#