Commit 070068f
Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference
Summary:
Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in
torchao inference. In detail:
1. bring the existing 128x128 gemm triton kernel out of prototype and
   wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test that checks e2e numerical correctness via SQNR
   comparison against a high-precision baseline
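
As a rough illustration of the SQNR check in item 4, here is a minimal sketch in plain Python. It is not the test added in this commit: the function name `sqnr_db` and the list-based inputs are assumptions made for illustration, and the real test operates on tensors.

```python
import math


def sqnr_db(reference, candidate):
    """Signal-to-quantization-noise ratio in dB (hypothetical helper).

    `reference` is the high-precision output, `candidate` is the quantized
    output. Both are flat sequences of floats with the same length.
    """
    if len(reference) != len(candidate):
        raise ValueError("inputs must have the same length")
    # Signal power: sum of squares of the reference values.
    signal = sum(r * r for r in reference)
    # Noise power: sum of squared differences between the two outputs.
    noise = sum((r - c) ** 2 for r, c in zip(reference, candidate))
    if noise == 0.0:
        return math.inf  # identical outputs, no quantization noise
    if signal == 0.0:
        return -math.inf  # all-zero reference, any error dominates
    return 10.0 * math.log10(signal / noise)
```

A higher value means the quantized output is closer to the baseline. A test would typically assert that the value exceeds some threshold, and the threshold used by this commit's test is not shown here.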
For now I added a fallback which only requires triton and is numerically
correct but may not reach optimal performance. Performance optimization is
left for future PRs:
1. we should map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. we should enable an fbgemm_gpu_genai path, if available in the user env
3. we should map to a triton kernel for quantizing the weights, as
   `torch.compile` is currently known to be slow for 128x128 block
   quantization
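
To make the block granularity concrete, the sketch below computes per-block scales in plain Python. It reads the recipe name as 1x128 blocks for activations and 128x128 blocks for weights, which is an assumption based on the name. The helper `block_scales`, the `amax / 448` scale convention, and the epsilon clamp are likewise illustrative assumptions, not the torchao implementation.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def block_scales(matrix, block_rows, block_cols):
    """Return one scale per (block_rows x block_cols) tile (hypothetical helper).

    `matrix` is a list of equal-length rows. Each scale is the tile's
    absolute maximum divided by FP8_E4M3_MAX, so that dividing the tile by
    its scale maps it into the float8 range.
    """
    rows, cols = len(matrix), len(matrix[0])
    if rows % block_rows or cols % block_cols:
        raise ValueError("matrix shape must be divisible by the block shape")
    scales = []
    for r0 in range(0, rows, block_rows):
        row_scales = []
        for c0 in range(0, cols, block_cols):
            amax = max(
                abs(matrix[r][c])
                for r in range(r0, r0 + block_rows)
                for c in range(c0, c0 + block_cols)
            )
            # Clamp to a tiny epsilon so an all-zero tile never yields scale 0.
            row_scales.append(max(amax, 1e-12) / FP8_E4M3_MAX)
        scales.append(row_scales)
    return scales
```

Under that reading, an activation of shape (M, K) would use `block_scales(x, 1, 128)` and get M x K/128 scales, while a weight of shape (N, K) would use `block_scales(w, 128, 128)` and get N/128 x K/128 scales.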
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: db464e1
ghstack-comment-id: 3460951962
Pull-Request: #32571
Parent: 841d104
File tree
3 files changed, +110 -31 lines changed
- test/quantization/quantize_/workflows/float8
- torchao
  - float8
  - quantization/quantize_/workflows/float8
Lines changed: 24 additions & 6 deletions
Lines changed: 34 additions & 9 deletions