Commit 56bd8d1
PR tensorflow#24114: Triton/Nvidia: Fix fused fp8 <-> fp8 conversions
Imported from GitHub PR openxla/xla#24114
Converting FP8 <-> FP8 fails because the Triton compiler does not support a direct conversion between two FP8 types.
The proposed fix makes the conversion go through FP16.
Two questions:
1) Are there any better approaches to solving this?
2) I could not find a place to add unit tests for this; the code contains a comment saying:
```
// TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
// we can't test the code below without patching the feature.
```
Wondering if there is a place where I can add a test?
### Details
When converting FP8 types, the XLA compiler emits a `fp_to_fp` Triton instruction. If the source type is FP8, no rounding strategy is specified.
Concretely, this causes the following Triton to be emitted:
<details>
<summary>
<code>%24 = tt.fp_to_fp %20 : tensor<32x64xf8E5M2> -> tensor<32x64xf8E4M3FN></code>
</summary>
```
module {
tt.func @gemm_fusion_dot_320_impl(%arg0: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}) {
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf8E4M3FN>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf8E4M3FN>
%c90_i32 = arith.constant 90 : i32
%c32000_i64 = arith.constant 32000 : i64
%c64_i32 = arith.constant 64 : i32
%c90_i64 = arith.constant 90 : i64
%c768_i64 = arith.constant 768 : i64
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%c32_i32 = arith.constant 32 : i32
%c24_i32 = arith.constant 24 : i32
%c8_i32 = arith.constant 8 : i32
%c4000_i32 = arith.constant 4000 : i32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
%0 = tt.get_program_id x : i32
%1 = arith.divsi %0, %c4000_i32 : i32
%2 = arith.muli %1, %c8_i32 : i32
%3 = arith.subi %c24_i32, %2 : i32
%4 = arith.cmpi slt, %3, %c8_i32 : i32
%5 = arith.select %4, %3, %c8_i32 : i32
%6 = arith.remsi %0, %5 : i32
%7 = arith.addi %2, %6 : i32
%8 = arith.remsi %0, %c4000_i32 : i32
%9 = arith.divsi %8, %5 : i32
%10 = arith.muli %7, %c32_i32 : i32
%11 = tt.make_tensor_ptr %arg1, [%c768_i64, %c90_i64], [%c1_i64, %c768_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf8E5M2>>
%12 = tt.advance %11, [%10, %c0_i32] : <tensor<32x64xf8E5M2>>
%13 = arith.muli %9, %c64_i32 : i32
%14 = tt.make_tensor_ptr %arg0, [%c90_i64, %c32000_i64], [%c1_i64, %c90_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf8E4M3FN>>
%15 = tt.advance %14, [%c0_i32, %13] : <tensor<64x64xf8E4M3FN>>
%16:3 = scf.for %arg3 = %c0_i32 to %c90_i32 step %c64_i32 iter_args(%arg4 = %12, %arg5 = %15, %arg6 = %cst_1) -> (!tt.ptr<tensor<32x64xf8E5M2>>, !tt.ptr<tensor<64x64xf8E4M3FN>>, tensor<32x64xf32>) : i32 {
%20 = tt.load %arg4 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<32x64xf8E5M2>>
%21 = tt.advance %arg4, [%c0_i32, %c64_i32] : <tensor<32x64xf8E5M2>>
%22 = tt.load %arg5 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xf8E4M3FN>>
%23 = tt.advance %arg5, [%c64_i32, %c0_i32] : <tensor<64x64xf8E4M3FN>>
%24 = tt.fp_to_fp %20 : tensor<32x64xf8E5M2> -> tensor<32x64xf8E4M3FN>
%25 = arith.subi %c90_i32, %arg3 : i32
%26 = arith.cmpi slt, %25, %c64_i32 : i32
%27 = scf.if %26 -> (tensor<32x64xf8E4M3FN>) {
%30 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%31 = tt.expand_dims %30 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
%32 = tt.splat %25 : i32 -> tensor<1x64xi32>
%33 = arith.cmpi slt, %31, %32 : tensor<1x64xi32>
%34 = tt.broadcast %33 : tensor<1x64xi1> -> tensor<32x64xi1>
%35 = arith.select %34, %24, %cst_0 : tensor<32x64xi1>, tensor<32x64xf8E4M3FN>
scf.yield %35 : tensor<32x64xf8E4M3FN>
} else {
scf.yield %24 : tensor<32x64xf8E4M3FN>
}
%28 = scf.if %26 -> (tensor<64x64xf8E4M3FN>) {
%30 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
%32 = tt.splat %25 : i32 -> tensor<64x1xi32>
%33 = arith.cmpi slt, %31, %32 : tensor<64x1xi32>
%34 = tt.broadcast %33 : tensor<64x1xi1> -> tensor<64x64xi1>
%35 = arith.select %34, %22, %cst : tensor<64x64xi1>, tensor<64x64xf8E4M3FN>
scf.yield %35 : tensor<64x64xf8E4M3FN>
} else {
scf.yield %22 : tensor<64x64xf8E4M3FN>
}
%29 = tt.dot %27, %28, %arg6, inputPrecision = tf32 {maxNumImpreciseAcc = 2147483647 : i32} : tensor<32x64xf8E4M3FN> * tensor<64x64xf8E4M3FN> -> tensor<32x64xf32>
scf.yield %21, %23, %29 : !tt.ptr<tensor<32x64xf8E5M2>>, !tt.ptr<tensor<64x64xf8E4M3FN>>, tensor<32x64xf32>
}
%17 = tt.fp_to_fp %16#2, rounding = rtne : tensor<32x64xf32> -> tensor<32x64xf8E4M3FN>
%18 = tt.make_tensor_ptr %arg2, [%c768_i64, %c32000_i64], [%c1_i64, %c768_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf8E4M3FN>>
%19 = tt.advance %18, [%10, %13] : <tensor<32x64xf8E4M3FN>>
tt.store %19, %17 : !tt.ptr<tensor<32x64xf8E4M3FN>>
tt.return
}
}
```
</details>
This leads to a failing assertion:
```
#0 0x000073413786d9fc in pthread_kill () from /lib/x86_64-linux-gnu/libc.so.6
#1 0x0000734137819476 in raise () from /lib/x86_64-linux-gnu/libc.so.6
#2 0x00007341377ff7f3 in abort () from /lib/x86_64-linux-gnu/libc.so.6
#3 0x00007341377ff71b in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#4 0x0000734137810e96 in __assert_fail () from /lib/x86_64-linux-gnu/libc.so.6
#5 0x000057d936b1777b in mlir::triton::gpu::(anonymous namespace)::FpToFpOpConversion::createDestOps (this=0x733d08425cc0, op=..., adaptor=..., rewriter=..., elemTy=..., operands=..., loc=...)
at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp:500
#6 0x000057d936b17195 in mlir::triton::gpu::ElementwiseOpConversionBase<mlir::triton::FpToFpOp, mlir::triton::gpu::(anonymous namespace)::FpToFpOpConversion>::matchAndRewrite (this=0x733d08425cc0, op=..., adaptor=..., rewriter=...)
at external/triton/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h:188
[...]
#29 0x000057d93fa6cade in mlir::PassManager::run (this=0x733e80fba158, op=0x733d080bbc20) at external/llvm-project/mlir/lib/Pass/Pass.cpp:885
#30 0x000057d9363f6b1b in xla::gpu::CompileTritonToLLVM (hlo_config=..., hlo_module_name="gemm_fusion_dot.320", device_info=..., block_level_parameters=..., triton_module=..., llvm_module=0x733d0816d6a0, mlir_context=..., is_xla_fusion=true, emit_kernel=true)
at xla/backends/gpu/codegen/triton/fusion_emitter.cc:1627
#31 0x000057d9363f5a5d in xla::gpu::TritonWrapper (fn_name="gemm_fusion_dot_320_impl", fusion=0x733d080a31c0, cc=std::variant<stream_executor::CudaComputeCapability, stream_executor::RocmComputeCapability> [index 0] = {...}, device_info=..., block_level_parameters=...,
llvm_module=0x733d0816d6a0, mlir_context=...) at xla/backends/gpu/codegen/triton/fusion_emitter.cc:1531
```
This IR fails Triton compilation for two reasons:
* First, it hits an assertion that a rounding strategy must be specified when the destination type is FP8.
* After adding a rounding strategy, it runs into a second issue: no method for converting FP8 <-> FP8 is implemented.
To work around these two issues, I propose going through FP16 whenever both the source and destination types are FP8.
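For the example above, this means emitting two `tt.fp_to_fp` instructions instead of one: a widening FP8 -> FP16 conversion (which needs no rounding mode), followed by a narrowing FP16 -> FP8 conversion with `rounding = rtne`. A minimal sketch of the resulting IR (the SSA names are illustrative, not actual emitter output):
```
%24 = tt.fp_to_fp %20 : tensor<32x64xf8E5M2> -> tensor<32x64xf16>
%25 = tt.fp_to_fp %24, rounding = rtne : tensor<32x64xf16> -> tensor<32x64xf8E4M3FN>
```
Since both f8E5M2 and f8E4M3FN values are exactly representable in FP16, the widening step is lossless and only the final FP16 -> FP8 step needs a rounding mode.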
Copybara import of the project:
--
afd3929099fc4d1045275ca3210e0bc727a2b906 by Kasper Nielsen <kasper0406@gmail.com>:
Fix fused fp8 <-> fp8 conversions
--
66340aa808f58e5dc6ab1c2e06790ceccde95540 by Kasper Nielsen <kasper0406@gmail.com>:
Add unit tests and refactor duplicated code
--
07ae307879eff24ad2f85607e94503deda1074e4 by Kasper Nielsen <kasper0406@gmail.com>:
Run clang-format
--
fe967ff94ffc5f34f07bff142b5d10d81d5e4dce by Kasper Nielsen <kasper0406@gmail.com>:
Fix support conversion tests
Merging this change closes tensorflow#24114
PiperOrigin-RevId: 7414736481
File tree: 6 files changed (+93, -130), all under third_party/xla/xla/backends/gpu/codegen/triton