Commit 42bab0a

Committed Mar 20, 2025
chore: add some more details to the jupyter notebook
Signed-off-by: Maryam Tahhan <mtahhan@redhat.com>
1 parent 9ba3b09 · commit 42bab0a

File tree

1 file changed: +366 -321 lines

examples/flash_attention_demo.ipynb

+366 -321
@@ -1,322 +1,367 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9012182b",
"metadata": {},
"source": [
"## Summary\n",
"\n",
"This notebook benchmarks PyTorch's scaled_dot_product_attention (SDPA) against vLLM's Triton-based flash attention kernel.\n",
"\n",
"Key highlights:\n",
"- Environment Setup: GPU checks and Triton installation.\n",
"- Baseline Performance: Measure PyTorch SDPA runtimes for various sequence lengths.\n",
"- vLLM Triton Kernel Benchmark: Compare initial vLLM kernel performance vs. PyTorch.\n",
"  - Triton Autotuning & Caching:\n",
"    - The first run triggers autotuning (testing multiple configurations), making it slower.\n",
"    - The best configuration is cached for future runs.\n",
"    - Subsequent runs reuse the cached kernel and run significantly faster without re-tuning.\n",
"- Visualization: Clear plots show performance improvements before and after autotuning.\n",
"- Speedup Summary: A table and plots demonstrate consistent 2-4x speedups compared to PyTorch after caching.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29482448-d69e-4412-82b8-e6b8243699fe",
"metadata": {},
"outputs": [],
"source": [
"!python triton-gpu-check.py"
]
},
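
The cell above shells out to triton-gpu-check.py, a helper script that is not part of this diff. As a rough illustration only (assuming nothing about the real script beyond the torch and triton imports used later in the notebook), such a check might look like:

import torch
import triton

# Hypothetical stand-in for triton-gpu-check.py: confirm a CUDA device and a
# working Triton installation are visible before running the benchmarks.
if torch.cuda.is_available():
    print("CUDA device:", torch.cuda.get_device_name(0))
    print("Compute capability:", torch.cuda.get_device_capability(0))
else:
    print("No CUDA device visible; the benchmarks below will not run.")
print("Triton version:", triton.__version__)
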
{
"cell_type": "code",
"execution_count": null,
"id": "591e3d51-8a52-45b5-b6a8-99acd0b1180c",
"metadata": {},
"outputs": [],
"source": [
"!cd triton && pip install ./python && cd -"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2278ff2f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import triton\n",
"import triton.language as tl\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"\n",
"print(\"Torch version:\", torch.__version__)\n",
"print(\"Triton version:\", triton.__version__)"
]
},
{
"cell_type": "markdown",
"id": "75bc884d",
"metadata": {},
"source": [
"## Flash Attention Benchmark (PyTorch SDPA vs vLLM Kernel)\n",
"This notebook benchmarks the PyTorch `scaled_dot_product_attention` against the vLLM Triton-based flash attention kernel."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vllm-import",
"metadata": {},
"outputs": [],
"source": [
"# Assuming flash_attention.py is present in the same directory or on an accessible path\n",
"from flash_attention import triton_attention as vllm_flash_attention\n",
"from flash_attention import benchmark_flash_attention as vllm_benchmark\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4958854-eb3a-44d8-a52e-c4f77cefce94",
"metadata": {},
"outputs": [],
"source": [
"!ls /workspace/.triton/cache"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc0bbe33",
"metadata": {},
"outputs": [],
"source": [
"def run_pytorch_sdpa(q, k, v):\n",
"    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ac54ef2-2877-49bb-aca6-c7ca2e6d55d2",
"metadata": {},
"outputs": [],
"source": [
"!ls /workspace/.triton/cache"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "vllm-kernel-wrapper",
"metadata": {},
"outputs": [],
"source": [
"def run_vllm_flash_attention(q, k, v, seqlen):\n",
"    q_flat = q.permute(0, 2, 1, 3).reshape(-1, q.shape[1], q.shape[3])\n",
"    k_flat = k.permute(0, 2, 1, 3).reshape(-1, k.shape[1], k.shape[3])\n",
"    v_flat = v.permute(0, 2, 1, 3).reshape(-1, v.shape[1], v.shape[3])\n",
"    cu_seqlens_q = torch.arange(0, q.shape[0] + 1, dtype=torch.int32, device=q.device) * seqlen\n",
"    cu_seqlens_k = torch.arange(0, q.shape[0] + 1, dtype=torch.int32, device=q.device) * seqlen\n",
"    o, _ = vllm_flash_attention(q_flat, k_flat, v_flat, None, cu_seqlens_q, cu_seqlens_k, seqlen, seqlen, False, 1.0, None)\n",
"    return o.view(q.shape[0], seqlen, q.shape[1], q.shape[3]).permute(0, 2, 1, 3)"
]
},
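
run_vllm_flash_attention above converts dense (batch, nheads, seqlen, head_dim) tensors into the varlen layout the Triton kernel expects: tokens flattened to (batch * seqlen, nheads, head_dim), plus cu_seqlens offsets marking where each sequence starts. A shape-only sketch with made-up sizes (not part of the notebook) shows the round trip:

import torch

# Illustrative sizes only.
batch, nheads, seqlen, head_dim = 2, 4, 8, 16
q = torch.randn(batch, nheads, seqlen, head_dim)

# (batch, nheads, seqlen, head_dim) -> (batch * seqlen, nheads, head_dim)
q_flat = q.permute(0, 2, 1, 3).reshape(-1, q.shape[1], q.shape[3])
assert q_flat.shape == (batch * seqlen, nheads, head_dim)

# cu_seqlens marks where each sequence begins on the flattened token axis:
# one entry per batch element plus the total token count.
cu_seqlens = torch.arange(0, batch + 1, dtype=torch.int32) * seqlen
print(cu_seqlens)  # e.g. tensor([0, 8, 16], dtype=torch.int32)

# The view/permute applied to the kernel output restores the original layout.
q_back = q_flat.view(batch, seqlen, nheads, head_dim).permute(0, 2, 1, 3)
assert torch.equal(q_back, q)
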
{
"cell_type": "code",
"execution_count": null,
"id": "b911bed7",
"metadata": {},
"outputs": [],
"source": [
"def benchmark_flash_attention(batch, nheads, head_dim, seqlen):\n",
"    q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n",
"    k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n",
"    v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n",
"\n",
"    torch.cuda.synchronize()\n",
"    start = time.time()\n",
"    out_torch = run_pytorch_sdpa(q, k, v)\n",
"    torch.cuda.synchronize()\n",
"    pytorch_time = time.time() - start\n",
"\n",
"    torch.cuda.synchronize()\n",
"    start = time.time()\n",
"    out_vllm = run_vllm_flash_attention(q, k, v, seqlen)\n",
"    torch.cuda.synchronize()\n",
"    vllm_time = time.time() - start\n",
"\n",
"    diff_vllm = torch.max(torch.abs(out_torch - out_vllm)).item()\n",
"    return pytorch_time, vllm_time, diff_vllm"
]
},
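
benchmark_flash_attention times a single call with time.time() between torch.cuda.synchronize() calls, so the first vLLM measurement also absorbs Triton compilation and autotuning; the later cells revisit exactly that effect. For steady-state numbers, a common alternative (a sketch assuming warm-up iterations are acceptable, not something this notebook uses) is CUDA-event timing:

import torch

def time_gpu(fn, *args, warmup=3, iters=10):
    # Warm up first so one-off compile/autotune costs are excluded.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters / 1000.0  # seconds per call

# Usage with the notebook's tensors, e.g.:
# t = time_gpu(run_pytorch_sdpa, q, k, v)
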
{
"cell_type": "code",
"execution_count": null,
"id": "2d1ce123",
"metadata": {},
"outputs": [],
"source": [
"seqlens = [128, 256, 512, 1024]\n",
"batch, nheads, head_dim = 32, 8, 64\n",
"pytorch_times, vllm_times, vllm_diffs = [], [], []\n",
"\n",
"for seqlen in seqlens:\n",
"    t_pt, t_vllm, d_vllm = benchmark_flash_attention(batch, nheads, head_dim, seqlen)\n",
"    pytorch_times.append(t_pt)\n",
"    vllm_times.append(t_vllm)\n",
"    vllm_diffs.append(d_vllm)\n",
"    print(f\"Seqlen={seqlen}: PyTorch CUDA={t_pt:.4f}s, vLLM CUDA={t_vllm:.4f}s, Diff(vLLM)={d_vllm:.2e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b99804e7-4693-445e-a473-e2c243f77f70",
"metadata": {},
"outputs": [],
"source": [
"!ls /workspace/.triton/cache"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b8fe26d",
"metadata": {},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.plot(seqlens, pytorch_times, label=\"PyTorch SDPA (CUDA)\")\n",
"plt.plot(seqlens, vllm_times, label=\"vLLM Flash Attention (CUDA)\")\n",
"plt.xlabel(\"Sequence Length\")\n",
"plt.ylabel(\"Time (s)\")\n",
"plt.title(\"Flash Attention Performance: PyTorch vs vLLM on CUDA\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "212e60a2",
"metadata": {},
"source": [
"## What is Triton Autotuning?\n",
"Triton allows kernels to be **autotuned**, meaning it will try multiple kernel configurations (block sizes, warp counts, pipeline stages) to find the optimal setup for your specific GPU hardware and workload shape.\n",
"\n",
"This autotuning process can significantly improve performance by selecting the configuration that uses the GPU most efficiently.\n",
"\n",
"**How does it work?** \n",
"- Triton runs benchmarks internally with different configurations. \n",
"- It measures which configurations are fastest. \n",
"- The result is cached, so future runs use the best-found setup.\n",
"\n",
"**Why do we re-run tuning?** \n",
"- Hardware setups or driver versions may change. \n",
"- Workload shapes (sequence lengths, batch sizes) might differ from defaults. \n",
"- We want to confirm we’re using the best configuration for *this exact benchmark*.\n",
"\n",
"Below, we trigger this autotuning pass.\n"
]
},
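
For reference, this is roughly what autotuning looks like at the Triton API level. The sketch below uses a toy vector-add kernel with assumed block-size candidates, not the vLLM attention kernel; the first call for a given n_elements benchmarks every config, and later calls reuse the cached winner.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 512}, num_warps=8),
    ],
    key=["n_elements"],  # re-tune whenever this value changes
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    # BLOCK_SIZE is chosen by the autotuner, so it is not passed explicitly.
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n)
    return out
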
{
"cell_type": "markdown",
"id": "51ef7113",
"metadata": {},
"source": [
"## Note on Triton Autotuning and Caching\n",
"\n",
"- On the **first run**, when a specific combination of GPU hardware, batch size, sequence length, and head dimensions is encountered for the first time, **Triton triggers autotuning**. \n",
"  - This process tries multiple kernel configurations in the background and picks the fastest one.\n",
"  - As a result, the **first run may be significantly slower** due to this tuning process.\n",
"\n",
"- Once the best-performing configuration is found, it is **stored in Triton's cache** (in this environment, `/workspace/.triton/cache`).\n",
"\n",
"- On **subsequent runs** with the same input shape and environment:\n",
"  - Triton **loads the tuned configuration from cache** and skips tuning.\n",
"  - This leads to **consistently fast kernel launches and execution** without re-tuning overhead.\n",
"\n",
"- If you clear the cache, the next run will re-trigger autotuning.\n",
"\n",
"> In short: \n",
"> - First run = autotuning + execution (slow but smart) \n",
"> - All future runs = cached config + execution (fast and efficient)\n"
]
},
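
The cache directory mentioned above can be inspected or cleared directly; a small sketch using the /workspace/.triton/cache path from this notebook (Triton's cache location can also be redirected with the TRITON_CACHE_DIR environment variable):

import os
import shutil

# Cache path used by this notebook; adjust for your environment.
cache_dir = "/workspace/.triton/cache"

# Each subdirectory holds one compiled (and, where applicable, autotuned) kernel variant.
if os.path.isdir(cache_dir):
    print(len(os.listdir(cache_dir)), "cached kernel entries")

# Deleting the cache forces recompilation and autotuning on the next run:
# shutil.rmtree(cache_dir, ignore_errors=True)
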
{
"cell_type": "code",
"execution_count": null,
"id": "f972971b-fef9-4814-926a-da87432b47bb",
"metadata": {},
"outputs": [],
"source": [
"# Trigger autotuning (reuses the cached config if present, otherwise searches)\n",
"vllm_benchmark.run(show_plots=False, print_data=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ec96d79-50fe-4814-a668-27c95ceaa04f",
"metadata": {},
"outputs": [],
"source": [
"vllm_tuned_times = []\n",
"\n",
"for seqlen in seqlens:\n",
"    q = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n",
"    k = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n",
"    v = torch.randn(batch, nheads, seqlen, head_dim, device='cuda')\n",
"\n",
"    torch.cuda.synchronize()\n",
"    start = time.time()\n",
"    out_vllm_tuned = run_vllm_flash_attention(q, k, v, seqlen)\n",
"    torch.cuda.synchronize()\n",
"    tuned_time = time.time() - start\n",
"    vllm_tuned_times.append(tuned_time)\n",
"    print(f\"Seqlen={seqlen}: Tuned vLLM CUDA={tuned_time:.4f}s\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a7e9593-48a3-42f0-a94d-6a45a4388354",
"metadata": {},
"outputs": [],
"source": [
"print(f\"{'SeqLen':>8} | {'PyTorch Time (s)':>18} | {'vLLM Tuned Time (s)':>20} | {'Speedup (PyTorch/vLLM)':>24}\")\n",
"print(\"-\" * 75)\n",
"for seqlen, pt_time, tuned_time in zip(seqlens, pytorch_times, vllm_tuned_times):\n",
"    speedup = pt_time / tuned_time\n",
"    print(f\"{seqlen:8} | {pt_time:18.6f} | {tuned_time:20.6f} | {speedup:24.2f}x\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "627a2f59-96ae-4994-854e-b1c9259529bd",
"metadata": {},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.plot(seqlens, pytorch_times, label=\"PyTorch SDPA (CUDA)\")\n",
"plt.plot(seqlens, vllm_times, label=\"vLLM (Original)\")\n",
"plt.plot(seqlens, vllm_tuned_times, label=\"vLLM (Autotuned)\")\n",
"plt.xlabel(\"Sequence Length\")\n",
"plt.ylabel(\"Time (s)\")\n",
"plt.title(\"Flash Attention Benchmark: PyTorch vs vLLM (Before & After Autotune)\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd225668-944f-4058-a20e-973188c7442e",
"metadata": {},
"outputs": [],
"source": [
"plt.figure()\n",
"plt.plot(seqlens, pytorch_times, label=\"PyTorch SDPA (CUDA)\")\n",
"plt.plot(seqlens, vllm_tuned_times, label=\"vLLM (Autotuned)\")\n",
"plt.xlabel(\"Sequence Length\")\n",
"plt.ylabel(\"Time (s)\")\n",
"plt.title(\"Flash Attention Benchmark: PyTorch vs vLLM (After Autotune)\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "487d9793-b35d-4fa6-ab05-7843d5fe96b5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
