1 | 1 | {
[Previous revision (old lines 2-322), shown as removed in this diff: identical to the version added below except that it did not yet contain the "Summary" markdown cell or the "Note on Triton Autotuning and Caching Example" markdown cell.]
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "9012182b", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "## Summary\n", |
| 9 | + "\n", |
| 10 | + "This notebook benchmarks PyTorch's scaled_dot_product_attention (SDPA) against a vLLMs Triton-based flash attention kernel.\n", |
| 11 | + "\n", |
| 12 | + "Key highlights:\n", |
| 13 | + "- Environment Setup: GPU checks and Triton installation.\n", |
| 14 | + "- Baseline Performance: Measure PyTorch SDPA runtimes for various sequence lengths.\n", |
| 15 | + "- vLLM Triton Kernel Benchmark: Compare initial vLLM kernel performance vs. PyTorch.\n", |
| 16 | + " - Triton Autotuning & Caching:\n", |
| 17 | + " - The first run triggers autotuning (testing multiple configurations), making it slower.\n", |
| 18 | + " - The best configuration is cached for future runs.\n", |
| 19 | + " - Subsequent runs reuse the cached kernel and run significantly faster without re-tuning.\n", |
| 20 | + "- Visualization: Clear plots show performance improvements before and after autotuning.\n", |
| 21 | + "- Speedup Summary: A table and plots demonstrate consistent 2-4x speedups compared to PyTorch after caching.\n" |
| 22 | + ] |
| 23 | + }, |
| 24 | + { |
| 25 | + "cell_type": "code", |
| 26 | + "execution_count": null, |
| 27 | + "id": "29482448-d69e-4412-82b8-e6b8243699fe", |
| 28 | + "metadata": {}, |
| 29 | + "outputs": [], |
| 30 | + "source": [ |
| 31 | + "!python triton-gpu-check.py" |
| 32 | + ] |
| 33 | + }, |
| 34 | + { |
| 35 | + "cell_type": "code", |
| 36 | + "execution_count": null, |
| 37 | + "id": "591e3d51-8a52-45b5-b6a8-99acd0b1180c", |
| 38 | + "metadata": {}, |
| 39 | + "outputs": [], |
| 40 | + "source": [ |
| 41 | + "!cd triton && pip install ./python && cd -" |
| 42 | + ] |
| 43 | + }, |
| 44 | + { |
| 45 | + "cell_type": "code", |
| 46 | + "execution_count": null, |
| 47 | + "id": "2278ff2f", |
| 48 | + "metadata": {}, |
| 49 | + "outputs": [], |
| 50 | + "source": [ |
| 51 | + "import torch\n", |
| 52 | + "import triton\n", |
| 53 | + "import triton.language as tl\n", |
| 54 | + "import matplotlib.pyplot as plt\n", |
| 55 | + "import time\n", |
| 56 | + "\n", |
| 57 | + "print(\"Torch version:\", torch.__version__)\n", |
| 58 | + "print(\"Triton version:\", triton.__version__)" |
| 59 | + ] |
| 60 | + }, |
| 61 | + { |
| 62 | + "cell_type": "markdown", |
| 63 | + "id": "75bc884d", |
| 64 | + "metadata": {}, |
| 65 | + "source": [ |
| 66 | + "## Flash Attention Benchmark (PyTorch SDPA vs vLLM Kernel)\n", |
| 67 | + "This notebook benchmarks the PyTorch `scaled_dot_product_attention` against the vLLM Triton-based flash attention kernel." |
| 68 | + ] |
| 69 | + }, |
| 70 | + { |
| 71 | + "cell_type": "code", |
| 72 | + "execution_count": null, |
| 73 | + "id": "vllm-import", |
| 74 | + "metadata": {}, |
| 75 | + "outputs": [], |
| 76 | + "source": [ |
| 77 | + "# Assuming vllm_flash_attention.py is present in the same directory or accessible path\n", |
| 78 | + "from flash_attention import triton_attention as vllm_flash_attention\n", |
| 79 | + "from flash_attention import benchmark_flash_attention as vllm_benchmark\n" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "cell_type": "code", |
| 84 | + "execution_count": null, |
| 85 | + "id": "e4958854-eb3a-44d8-a52e-c4f77cefce94", |
| 86 | + "metadata": {}, |
| 87 | + "outputs": [], |
| 88 | + "source": [ |
| 89 | + "!ls /workspace/.triton/cache" |
| 90 | + ] |
| 91 | + }, |
| 92 | + { |
| 93 | + "cell_type": "code", |
| 94 | + "execution_count": null, |
| 95 | + "id": "bc0bbe33", |
| 96 | + "metadata": {}, |
| 97 | + "outputs": [], |
| 98 | + "source": [ |
| 99 | + "def run_pytorch_sdpa(q, k, v):\n", |
| 100 | + " return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)" |
| 101 | + ] |
| 102 | + }, |
| 103 | + { |
| 104 | + "cell_type": "code", |
| 105 | + "execution_count": null, |
| 106 | + "id": "8ac54ef2-2877-49bb-aca6-c7ca2e6d55d2", |
| 107 | + "metadata": {}, |
| 108 | + "outputs": [], |
| 109 | + "source": [ |
| 110 | + "!ls /workspace/.triton/cache" |
| 111 | + ] |
| 112 | + }, |
| 113 | + { |
| 114 | + "cell_type": "code", |
| 115 | + "execution_count": null, |
| 116 | + "id": "vllm-kernel-wrapper", |
| 117 | + "metadata": {}, |
| 118 | + "outputs": [], |
| 119 | + "source": [ |
| 120 | + "def run_vllm_flash_attention(q, k, v, seqlen):\n", |
| 121 | + " q_flat = q.permute(0, 2, 1, 3).reshape(-1, q.shape[1], q.shape[3])\n", |
| 122 | + " k_flat = k.permute(0, 2, 1, 3).reshape(-1, k.shape[1], k.shape[3])\n", |
| 123 | + " v_flat = v.permute(0, 2, 1, 3).reshape(-1, v.shape[1], v.shape[3])\n", |
| 124 | + " cu_seqlens_q = torch.arange(0, q.shape[0] + 1, dtype=torch.int32, device=q.device) * seqlen\n", |
| 125 | + " cu_seqlens_k = torch.arange(0, q.shape[0] + 1, dtype=torch.int32, device=q.device) * seqlen\n", |
| 126 | + " o, _ = vllm_flash_attention(q_flat, k_flat, v_flat, None, cu_seqlens_q, cu_seqlens_k, seqlen, seqlen, False, 1.0, None)\n", |
| 127 | + " return o.view(q.shape[0], seqlen, q.shape[1], q.shape[3]).permute(0, 2, 1, 3)" |
| 128 | + ] |
| 129 | + }, |
| 130 | + { |
| 131 | + "cell_type": "code", |
| 132 | + "execution_count": null, |
| 133 | + "id": "b911bed7", |
| 134 | + "metadata": {}, |
| 135 | + "outputs": [], |
| 136 | + "source": [ |
| 137 | + "def benchmark_flash_attention(batch, nheads, head_dim, seqlen):\n", |
| 138 | + " q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n", |
| 139 | + " k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n", |
| 140 | + " v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n", |
| 141 | + "\n", |
| 142 | + " torch.cuda.synchronize()\n", |
| 143 | + " start = time.time()\n", |
| 144 | + " out_torch = run_pytorch_sdpa(q, k, v)\n", |
| 145 | + " torch.cuda.synchronize()\n", |
| 146 | + " pytorch_time = time.time() - start\n", |
| 147 | + "\n", |
| 148 | + " torch.cuda.synchronize()\n", |
| 149 | + " start = time.time()\n", |
| 150 | + " out_vllm = run_vllm_flash_attention(q, k, v, seqlen)\n", |
| 151 | + " torch.cuda.synchronize()\n", |
| 152 | + " vllm_time = time.time() - start\n", |
| 153 | + "\n", |
| 154 | + " diff_vllm = torch.max(torch.abs(out_torch - out_vllm)).item()\n", |
| 155 | + " return pytorch_time, vllm_time, diff_vllm" |
| 156 | + ] |
| 157 | + }, |
| 158 | + { |
| 159 | + "cell_type": "code", |
| 160 | + "execution_count": null, |
| 161 | + "id": "2d1ce123", |
| 162 | + "metadata": {}, |
| 163 | + "outputs": [], |
| 164 | + "source": [ |
| 165 | + "seqlens = [128, 256, 512, 1024]\n", |
| 166 | + "batch, nheads, head_dim = 32, 8, 64\n", |
| 167 | + "pytorch_times, vllm_times, vllm_diffs = [], [], []\n", |
| 168 | + "\n", |
| 169 | + "for seqlen in seqlens:\n", |
| 170 | + " t_pt, t_vllm, d_vllm = benchmark_flash_attention(batch, nheads, head_dim, seqlen)\n", |
| 171 | + " pytorch_times.append(t_pt)\n", |
| 172 | + " vllm_times.append(t_vllm)\n", |
| 173 | + " vllm_diffs.append(d_vllm)\n", |
| 174 | + " print(f\"Seqlen={seqlen}: PyTorch CUDA={t_pt:.4f}s, vLLM CUDA={t_vllm:.4f}s, Diff(vLLM)={d_vllm:.2e}\")" |
| 175 | + ] |
| 176 | + }, |
| 177 | + { |
| 178 | + "cell_type": "code", |
| 179 | + "execution_count": null, |
| 180 | + "id": "b99804e7-4693-445e-a473-e2c243f77f70", |
| 181 | + "metadata": {}, |
| 182 | + "outputs": [], |
| 183 | + "source": [ |
| 184 | + "!ls /workspace/.triton/cache" |
| 185 | + ] |
| 186 | + }, |
| 187 | + { |
| 188 | + "cell_type": "code", |
| 189 | + "execution_count": null, |
| 190 | + "id": "6b8fe26d", |
| 191 | + "metadata": {}, |
| 192 | + "outputs": [], |
| 193 | + "source": [ |
| 194 | + "plt.figure()\n", |
| 195 | + "plt.plot(seqlens, pytorch_times, label=\"PyTorch SDPA (CUDA)\")\n", |
| 196 | + "plt.plot(seqlens, vllm_times, label=\"vLLM Flash Attention (CUDA)\")\n", |
| 197 | + "plt.xlabel(\"Sequence Length\")\n", |
| 198 | + "plt.ylabel(\"Time (s)\")\n", |
| 199 | + "plt.title(\"Flash Attention Performance: PyTorch vs vLLM on CUDA\")\n", |
| 200 | + "plt.legend()\n", |
| 201 | + "plt.grid()\n", |
| 202 | + "plt.show()" |
| 203 | + ] |
| 204 | + }, |
| 205 | + { |
| 206 | + "cell_type": "markdown", |
| 207 | + "id": "212e60a2", |
| 208 | + "metadata": {}, |
| 209 | + "source": [ |
| 210 | + "## What is Triton Autotuning?\n", |
| 211 | + "Triton allows kernels to be **autotuned**, meaning it will try multiple kernel configurations (block sizes, warp counts, pipeline stages) to find the optimal setup for your specific GPU hardware and workload shape.\n", |
| 212 | + "\n", |
| 213 | + "This autotuning process significantly improves performance and ensures the kernel is utilizing the GPU most efficiently.\n", |
| 214 | + "\n", |
| 215 | + "**How does it work?** \n", |
| 216 | + "- Triton runs benchmarks internally with different configurations. \n", |
| 217 | + "- It measures which configurations are fastest. \n", |
| 218 | + "- The result is cached, so future runs use the best-found setup.\n", |
| 219 | + "\n", |
| 220 | + "**Why do we re-run tuning?** \n", |
| 221 | + "- Hardware setups or driver versions may change. \n", |
| 222 | + "- Workload shapes (sequence lengths, batch sizes) might differ from defaults. \n", |
| 223 | + "- We want to confirm we’re using the best configuration for *this exact benchmark*.\n", |
| 224 | + "\n", |
| 225 | + "In the next cell, we trigger this autotuning pass.\n" |
| 226 | + ] |
| 227 | + }, |
| 228 | + { |
| 229 | + "cell_type": "markdown", |
| 230 | + "id": "51ef7113", |
| 231 | + "metadata": {}, |
| 232 | + "source": [ |
| 233 | + "## Note on Triton Autotuning and Caching Example\n", |
| 234 | + "\n", |
| 235 | + "- On the **first run**, when a specific kernel configuration (based on GPU hardware, batch size, sequence length, and head dimensions) is encountered for the first time, **Triton triggers autotuning**. \n", |
| 236 | + " - This process tries multiple kernel configurations in the background and picks the fastest one.\n", |
| 237 | + " - As a result, the **first run may be significantly slower** due to this tuning process.\n", |
| 238 | + "\n", |
| 239 | + "- Once the best-performing configuration is found, it is **stored in Triton's cache** (typically in `/workspace/.triton/cache`).\n", |
| 240 | + "\n", |
| 241 | + "- On **subsequent runs** with the same input shape and environment:\n", |
| 242 | + " - Triton **loads the tuned configuration from cache** and skips tuning.\n", |
| 243 | + " - This leads to **consistently fast kernel launches and execution** without re-tuning overhead.\n", |
| 244 | + "\n", |
| 245 | + "- If you clear the cache, the next run will re-trigger autotuning.\n", |
| 246 | + "\n", |
| 247 | + "> In short: \n", |
| 248 | + "> - First run = autotuning + execution (slow but smart) \n", |
| 249 | + "> - All future runs = cached config + execution (fast and efficient)\n" |
| 250 | + ] |
| 251 | + }, |
| 252 | + { |
| 253 | + "cell_type": "code", |
| 254 | + "execution_count": null, |
| 255 | + "id": "f972971b-fef9-4814-926a-da87432b47bb", |
| 256 | + "metadata": {}, |
| 257 | + "outputs": [], |
| 258 | + "source": [ |
| 259 | + "# Trigger re-tuning (will reuse cached or search if needed)\n", |
| 260 | + "vllm_benchmark.run(show_plots=False, print_data=True)" |
| 261 | + ] |
| 262 | + }, |
| 263 | + { |
| 264 | + "cell_type": "code", |
| 265 | + "execution_count": null, |
| 266 | + "id": "1ec96d79-50fe-4814-a668-27c95ceaa04f", |
| 267 | + "metadata": {}, |
| 268 | + "outputs": [], |
| 269 | + "source": [ |
| 270 | + "vllm_tuned_times = []\n", |
| 271 | + "\n", |
| 272 | + "for seqlen in seqlens:\n", |
| 273 | + " q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n", |
| 274 | + " k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n", |
| 275 | + " v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n", |
| 276 | + "\n", |
| 277 | + " torch.cuda.synchronize()\n", |
| 278 | + " start = time.time()\n", |
| 279 | + " out_vllm_tuned = run_vllm_flash_attention(q, k, v, seqlen)\n", |
| 280 | + " torch.cuda.synchronize()\n", |
| 281 | + " tuned_time = time.time() - start\n", |
| 282 | + " vllm_tuned_times.append(tuned_time)\n", |
| 283 | + " print(f\"Seqlen={seqlen}: Tuned vLLM CUDA={tuned_time:.4f}s\")" |
| 284 | + ] |
| 285 | + }, |
| 286 | + { |
| 287 | + "cell_type": "code", |
| 288 | + "execution_count": null, |
| 289 | + "id": "0a7e9593-48a3-42f0-a94d-6a45a4388354", |
| 290 | + "metadata": {}, |
| 291 | + "outputs": [], |
| 292 | + "source": [ |
| 293 | + "print(f\"{'SeqLen':>8} | {'PyTorch Time (s)':>18} | {'vLLM Tuned Time (s)':>20} | {'Speedup (PyTorch/vLLM)':>24}\")\n", |
| 294 | + "print(\"-\" * 75)\n", |
| 295 | + "for seqlen, pt_time, tuned_time in zip(seqlens, pytorch_times, vllm_tuned_times):\n", |
| 296 | + " speedup = pt_time / tuned_time\n", |
| 297 | + " print(f\"{seqlen:8} | {pt_time:18.6f} | {tuned_time:20.6f} | {speedup:24.2f}x\")\n" |
| 298 | + ] |
| 299 | + }, |
| 300 | + { |
| 301 | + "cell_type": "code", |
| 302 | + "execution_count": null, |
| 303 | + "id": "627a2f59-96ae-4994-854e-b1c9259529bd", |
| 304 | + "metadata": {}, |
| 305 | + "outputs": [], |
| 306 | + "source": [ |
| 307 | + "plt.figure()\n", |
| 308 | + "plt.plot(seqlens, pytorch_times, label=\"PyTorch SDPA (CUDA)\")\n", |
| 309 | + "plt.plot(seqlens, vllm_times, label=\"vLLM (Original)\")\n", |
| 310 | + "plt.plot(seqlens, vllm_tuned_times, label=\"vLLM (Autotuned)\")\n", |
| 311 | + "plt.xlabel(\"Sequence Length\")\n", |
| 312 | + "plt.ylabel(\"Time (s)\")\n", |
| 313 | + "plt.title(\"Flash Attention Benchmark: PyTorch vs vLLM (Before & After Autotune)\")\n", |
| 314 | + "plt.legend()\n", |
| 315 | + "plt.grid()\n", |
| 316 | + "plt.show()" |
| 317 | + ] |
| 318 | + }, |
| 319 | + { |
| 320 | + "cell_type": "code", |
| 321 | + "execution_count": null, |
| 322 | + "id": "fd225668-944f-4058-a20e-973188c7442e", |
| 323 | + "metadata": {}, |
| 324 | + "outputs": [], |
| 325 | + "source": [ |
| 326 | + "plt.figure()\n", |
| 327 | + "plt.plot(seqlens, pytorch_times, label=\"PyTorch SDPA (CUDA)\")\n", |
| 328 | + "plt.plot(seqlens, vllm_tuned_times, label=\"vLLM (Autotuned)\")\n", |
| 329 | + "plt.xlabel(\"Sequence Length\")\n", |
| 330 | + "plt.ylabel(\"Time (s)\")\n", |
| 331 | + "plt.title(\"Flash Attention Benchmark: PyTorch vs vLLM (After Autotune)\")\n", |
| 332 | + "plt.legend()\n", |
| 333 | + "plt.grid()\n", |
| 334 | + "plt.show()" |
| 335 | + ] |
| 336 | + }, |
| 337 | + { |
| 338 | + "cell_type": "code", |
| 339 | + "execution_count": null, |
| 340 | + "id": "487d9793-b35d-4fa6-ab05-7843d5fe96b5", |
| 341 | + "metadata": {}, |
| 342 | + "outputs": [], |
| 343 | + "source": [] |
| 344 | + } |
| 345 | + ], |
| 346 | + "metadata": { |
| 347 | + "kernelspec": { |
| 348 | + "display_name": "Python 3 (ipykernel)", |
| 349 | + "language": "python", |
| 350 | + "name": "python3" |
| 351 | + }, |
| 352 | + "language_info": { |
| 353 | + "codemirror_mode": { |
| 354 | + "name": "ipython", |
| 355 | + "version": 3 |
| 356 | + }, |
| 357 | + "file_extension": ".py", |
| 358 | + "mimetype": "text/x-python", |
| 359 | + "name": "python", |
| 360 | + "nbconvert_exporter": "python", |
| 361 | + "pygments_lexer": "ipython3", |
| 362 | + "version": "3.12.5" |
| 363 | + } |
| 364 | + }, |
| 365 | + "nbformat": 4, |
| 366 | + "nbformat_minor": 5 |
| 367 | +} |
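
The autotuning and caching behaviour described in the notebook's markdown cells can be seen in isolation with a minimal, self-contained kernel. The sketch below is illustrative only and is not part of the notebook or the vLLM kernel: it relies on the public `triton.autotune`, `triton.Config`, and `triton.jit` APIs, and the element-wise `add_kernel` is a hypothetical example. On the first launch for a given `n_elements`, Triton times each listed config and remembers the fastest for later launches with the same key; the compiled kernels land in the on-disk cache directory the notebook inspects with `!ls /workspace/.triton/cache` when `TRITON_CACHE_DIR` points there.

```python
# Minimal sketch of Triton autotuning, independent of the vLLM kernel.
# On the first launch for a given n_elements, Triton benchmarks every Config
# below and keeps the fastest; later launches with the same key reuse it.
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],  # re-tune only when this value changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # BLOCK_SIZE is supplied by the autotuner, so it is not passed here.
    add_kernel[grid](x, y, out, n_elements)
    return out

x = torch.randn(1 << 20, device="cuda")
y = torch.randn(1 << 20, device="cuda")
print(torch.allclose(add(x, y), x + y))  # first call is slower: it includes tuning
```

Timing `add(x, y)` twice reproduces the pattern the note above describes: the first call pays for autotuning, every later call with the same shape reuses the cached choice.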
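
The notebook's timing cells wrap a single call in `time.time()`, so the first vLLM measurement also absorbs the one-off autotuning pass, which is exactly what the before/after comparison is meant to expose. Where steadier per-kernel numbers are also wanted, a common pattern is to warm up first and average over several iterations with CUDA events. The helper below is a sketch under that assumption (the name `time_cuda` and its defaults are invented here), not a change to the notebook's methodology.

```python
import torch

def time_cuda(fn, *args, warmup: int = 3, iters: int = 10) -> float:
    """Return the mean runtime of fn(*args) in seconds, excluding warm-up
    (and therefore excluding Triton's one-off autotuning pass)."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / iters  # elapsed_time is in milliseconds

# Hypothetical usage with the notebook's helpers:
# sdpa_s = time_cuda(run_pytorch_sdpa, q, k, v)
# vllm_s = time_cuda(run_vllm_flash_attention, q, k, v, seqlen)
```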