Skip to content

Commit db573bc

Browse files
loislo and tensorflower-gardener
authored and committed
[XLA:GPU] Add explicit rounding of the F32 arguments of dot to TF32 if the dot algorithm is set to TF32.
Triton lowers the tf32 dot to an mma instruction that does not have an explicit rounding attribute for tf32 inputs. As a result, the precision of the tf32 dot is even worse than the BF16_BF16_F32 algorithm. Let's round the arguments explicitly when we have this execution sequence. PiperOrigin-RevId: 741459476
1 parent f0c813d commit db573bc

File tree

7 files changed

+242
-0
lines changed

7 files changed

+242
-0
lines changed

third_party/xla/xla/backends/gpu/codegen/triton/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ xla_test(
561561
"//xla/service/gpu/tests:gpu_codegen_test",
562562
"//xla/stream_executor:device_description",
563563
"//xla/stream_executor/cuda:cuda_compute_capability",
564+
"//xla/tests:test_utils",
564565
"//xla/tests:xla_internal_test_main", # fixdeps: keep
565566
"//xla/tsl/lib/core:status_test_util",
566567
"//xla/tsl/platform:statusor",
@@ -571,6 +572,7 @@ xla_test(
571572
"@com_google_absl//absl/status:statusor",
572573
"@com_google_absl//absl/strings",
573574
"@com_google_absl//absl/strings:str_format",
575+
"@com_google_absl//absl/types:span",
574576
"@com_google_googletest//:gtest",
575577
"@local_tsl//tsl/platform:path",
576578
],

third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline_cuda.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ absl::Status CreateTritonPipeline(mlir::OpPassManager* pm,
5353
const int ccAsInt = cc.major * 10 + cc.minor;
5454
const int threadsPerWarp = 32;
5555

56+
pm->addPass(mt_xla::CreateRoundF32ToTF32ForTf32DotRewritePass());
5657
if (is_xla_fusion) {
5758
pm->addPass(mt_xla::CreateInt4ToPackedInt4RewritePass());
5859
}

third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@ limitations under the License.
1515

1616
#include <algorithm>
1717
#include <cstddef>
18+
#include <cstdint>
19+
#include <cstdlib>
1820
#include <initializer_list>
21+
#include <iomanip>
22+
#include <ios>
1923
#include <iterator>
2024
#include <limits>
2125
#include <memory>
2226
#include <string>
27+
#include <unordered_map>
2328
#include <utility>
2429
#include <variant>
2530
#include <vector>
@@ -36,6 +41,7 @@ limitations under the License.
3641
#include "absl/strings/str_format.h"
3742
#include "absl/strings/str_replace.h"
3843
#include "absl/strings/string_view.h"
44+
#include "absl/types/span.h"
3945
#include "xla/autotuning.pb.h"
4046
#include "xla/backends/gpu/codegen/triton/kernel_name_tracer.h"
4147
#include "xla/backends/gpu/codegen/triton/test_utils.h"
@@ -53,6 +59,7 @@ limitations under the License.
5359
#include "xla/service/hlo_module_config.h"
5460
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
5561
#include "xla/stream_executor/device_description.h"
62+
#include "xla/tests/test_utils.h"
5663
#include "xla/tsl/lib/core/status_test_util.h"
5764
#include "xla/tsl/platform/statusor.h"
5865
#include "xla/xla.pb.h"
@@ -1474,6 +1481,129 @@ INSTANTIATE_TEST_SUITE_P(
14741481
PC::ALG_DOT_TF32_TF32_F32_X3, PC::ALG_DOT_F64_F64_F64, PC::ALG_UNSET}),
14751482
AlgorithmTestParamToString);
14761483

1484+
class PrecisionTestsForTriton : public TritonAlgorithmTest,
1485+
public NumericTestsArguments,
1486+
public WithParamInterface<PC::Algorithm> {
1487+
public:
1488+
PrecisionTestsForTriton() : TritonAlgorithmTest() {
1489+
algorithm_ = AlgorithmToString(GetParam());
1490+
}
1491+
1492+
std::string test_hlo_text() const {
1493+
return absl::StrReplaceAll(kHloText, {{"${test_name}", HloModuleTestName()},
1494+
{"${algorithm}", algorithm_}});
1495+
}
1496+
std::string reference_hlo_text() const {
1497+
return absl::StrReplaceAll(kHloText, {{"${test_name}", HloModuleTestName()},
1498+
{"${algorithm}", "dot_f32_f32_f32"}});
1499+
}
1500+
1501+
absl::string_view algorithm() const { return algorithm_; }
1502+
1503+
static constexpr absl::string_view kPattern = R"(CHECK: __triton_gemm)";
1504+
1505+
absl::StatusOr<std::unique_ptr<HloModule>> GetModule(
1506+
const std::string& hlo_text) {
1507+
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1508+
GetOptimizedModule(hlo_text));
1509+
auto module_text = module->ToString();
1510+
TF_ASSIGN_OR_RETURN(auto ok, RunFileCheck(module_text, kPattern));
1511+
if (!ok) {
1512+
return absl::InternalError(
1513+
"The module does not contain the pattern __triton_gemm.");
1514+
}
1515+
return module;
1516+
}
1517+
1518+
private:
1519+
static constexpr absl::string_view kHloText = R"(
1520+
HloModule ${test_name}
1521+
1522+
ENTRY main {
1523+
p0 = f32[1024,1024]{1,0} parameter(0)
1524+
p1 = f32[1024,1024]{1,0} parameter(1)
1525+
ROOT %dot = f32[1024,1024]{1,0} dot(p0, p1),
1526+
lhs_contracting_dims={1},
1527+
rhs_contracting_dims={0},
1528+
algorithm=${algorithm}
1529+
}
1530+
)";
1531+
std::string algorithm_;
1532+
};
1533+
1534+
// Compares the mean relative error of the algorithm under test (vs. an F32
// reference dot) against a per-algorithm upper bound.
TEST_P(PrecisionTestsForTriton, PrecisionCheck) {
  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
    GTEST_SKIP() << "Precision tests is unknown for ROCM.";
  }

  TF_ASSERT_OK_AND_ASSIGN(auto test_module, GetModule(test_hlo_text()));
  TF_ASSERT_OK_AND_ASSIGN(auto ref_module, GetModule(reference_hlo_text()));

  // Prepare deterministic (fixed-seed) arguments shared by both runs.
  absl::StatusOr<std::vector<Literal>> fake_arguments = MakeFakeArguments(
      test_module.get(), /*pseudo_random=*/true, /*use_large_range=*/false,
      /*treat_gte_as_data_formatting=*/false, 23);
  // gtest-friendly failure instead of CHECK_OK, which aborts the whole
  // test binary.
  TF_ASSERT_OK(fake_arguments.status());

  // Make all inputs non-negative so the reference outputs stay away from
  // zero and the relative error below is well defined.
  for (Literal& literal : *fake_arguments) {
    literal.MutableEachCell<float>(
        [](absl::Span<const int64_t> /*indices*/, float value) {
          return std::abs(value);
        });
  }
  // The literals are already mutable; no const_cast needed.
  std::vector<Literal*> fake_argument_ptrs;
  absl::c_transform(*fake_arguments, std::back_inserter(fake_argument_ptrs),
                    [](Literal& literal) { return &literal; });

  // Run the test and reference modules on identical inputs.
  TF_ASSERT_OK_AND_ASSIGN(
      auto test_result,
      test_runner().Execute(std::move(test_module), fake_argument_ptrs, false));
  TF_ASSERT_OK_AND_ASSIGN(
      auto ref_result,
      test_runner().Execute(std::move(ref_module), fake_argument_ptrs, false));

  // Mean absolute and mean relative errors over all output elements.
  absl::Span<const float> test_data = test_result.data<float>();
  absl::Span<const float> ref_data = ref_result.data<float>();
  float abs_error = 0.0f;
  float rel_error = 0.0f;
  for (size_t i = 0; i < test_data.size(); ++i) {
    abs_error += std::abs(test_data[i] - ref_data[i]);
    rel_error += std::abs((test_data[i] - ref_data[i]) / ref_data[i]);
  }
  abs_error /= test_data.size();
  rel_error /= test_data.size();

  // Per-algorithm upper bounds for the mean relative error.
  static const std::unordered_map<PC::Algorithm, float> kMaxMeanRelError = {
      {PC::ALG_DOT_BF16_BF16_F32, 6e-5},
      {PC::ALG_DOT_TF32_TF32_F32, 2e-5},
      {PC::ALG_DOT_BF16_BF16_F32_X3, 2e-5},
      {PC::ALG_DOT_BF16_BF16_F32_X6, 4e-7},
      {PC::ALG_DOT_BF16_BF16_F32_X9, 4e-7},
      {PC::ALG_DOT_TF32_TF32_F32_X3, 5e-7}};

  // Look the bound up with find(), and do it BEFORE any logging: the original
  // code logged max_mean_rel_error[GetParam()] first, and operator[]
  // default-inserted a 0.0f bound for an unknown algorithm, which made the
  // "No precision test" assertion below unreachable.
  const auto bound_it = kMaxMeanRelError.find(GetParam());
  ASSERT_TRUE(bound_it != kMaxMeanRelError.end())
      << "No precision test for algorithm " << algorithm();

  // Informational output belongs at INFO, not ERROR.
  LOG(INFO) << "mean(abs_error): " << abs_error;
  LOG(INFO) << "mean(rel_error): " << std::fixed << std::setprecision(9)
            << rel_error;
  LOG(INFO) << "max_mean_rel_error: " << std::fixed << std::setprecision(9)
            << bound_it->second;

  EXPECT_LT(rel_error, bound_it->second) << "mean(rel_error) is too high.";
}
1597+
1598+
// Instantiate the precision suite for every algorithm that has a bound in
// the test above. ::testing::Values(...) is equivalent to ValuesIn({...}).
INSTANTIATE_TEST_SUITE_P(
    PrecisionTestsForTriton, PrecisionTestsForTriton,
    ::testing::Values(PC::ALG_DOT_TF32_TF32_F32, PC::ALG_DOT_TF32_TF32_F32_X3,
                      PC::ALG_DOT_BF16_BF16_F32, PC::ALG_DOT_BF16_BF16_F32_X3,
                      PC::ALG_DOT_BF16_BF16_F32_X6,
                      PC::ALG_DOT_BF16_BF16_F32_X9),
    AlgorithmTestParamToString);
1606+
14771607
} // namespace
14781608
} // namespace gpu
14791609
} // namespace xla

third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ cc_library(
3737
"generalize_kernel_signature.cc",
3838
"int4_passes.cc",
3939
"prevent_mmav3_loop_unrolling_pass.cc",
40+
"round_f32_to_tf32_for_tf32_dot_pass.cc",
4041
"triton_xla_extract_insert_to_triton_pass.cc",
4142
],
4243
hdrs = ["passes.h"],

third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ std::unique_ptr<mlir::Pass> CreateTritonXLAExtractInsertToTritonPass(
3838
std::unique_ptr<mlir::Pass> CreateGeneralizeKernelSignaturePass();
3939
std::unique_ptr<mlir::Pass> CreatePreventMmaV3LoopUnrollingPass();
4040
std::unique_ptr<mlir::Pass> CreateInt4ToPackedInt4RewritePass();
41+
std::unique_ptr<mlir::Pass> CreateRoundF32ToTF32ForTf32DotRewritePass();
4142

4243
// Returns true if the `op` contains an operation in it's regions that satisfies
4344
// the `fn`.

third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,16 @@ def LoadInt4RewritePass
7070
let constructor = "CreateInt4ToPackedInt4RewritePass()";
7171
}
7272

73+
def RoundF32ToTF32ForTf32DotRewritePass
74+
: Pass<"round-f32-to-tf32-for-tf32-dot-rewrite", "mlir::ModuleOp"> {
75+
let summary = "dot with tf32 algorithm requires explicit rounding.";
76+
let description = [{
77+
This pass adds explicit rounding from f32 to tf32 for the dot with tf32 algorithm.
78+
This is required because mma instruction does not have explicit rounding and
79+
by default does truncation. As a result, the dot with tf32 algorithm has too
80+
small precision. It is even less than for the dot with BF16 arguments.
81+
}];
82+
let constructor = "CreateRoundF32ToTF32ForTf32DotRewritePass()";
83+
}
84+
7385
#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_TD_
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/* Copyright 2025 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <memory>
17+
#include <utility>
18+
19+
#include "mlir/IR/BuiltinAttributes.h"
20+
#include "mlir/IR/BuiltinTypeInterfaces.h"
21+
#include "mlir/IR/BuiltinTypes.h"
22+
#include "mlir/IR/Operation.h"
23+
#include "mlir/IR/OperationSupport.h"
24+
#include "mlir/IR/PatternMatch.h"
25+
#include "mlir/IR/Types.h"
26+
#include "mlir/IR/Value.h"
27+
#include "mlir/Pass/Pass.h"
28+
#include "mlir/Support/LLVM.h"
29+
#include "mlir/Support/LogicalResult.h"
30+
#include "mlir/Transforms/DialectConversion.h"
31+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32+
#include "triton/Dialect/Triton/IR/Dialect.h"
33+
34+
namespace mlir::triton::xla {
35+
36+
namespace mt = ::mlir::triton;
37+
38+
#define GEN_PASS_DEF_ROUNDF32TOTF32FORTF32DOTREWRITEPASS
39+
#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc"
40+
41+
namespace {
42+
43+
// Rewrites tt.dot ops that run in TF32 precision with F32 operands so that
// both operands are explicitly rounded to TF32 (round-to-nearest via the PTX
// `cvt.rna.tf32.f32` instruction) before the dot. Without this, the mma
// lowering truncates F32 -> TF32, losing precision.
class Tf32DotPattern : public OpRewritePattern<mt::DotOp> {
 public:
  // The inheriting constructor is sufficient; the original also declared a
  // redundant explicit MLIRContext* constructor.
  using OpRewritePattern<mt::DotOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mt::DotOp op, PatternRewriter &rewriter) const override {
    // Marker attribute set on the rewritten dot. The replacement dot still
    // matches this pattern, so without the marker the greedy driver would
    // rewrite it again and never terminate.
    constexpr auto kTf32ArgsRounded = "tf32_arguments_rounded";
    if (op.getInputPrecision() != mt::InputPrecision::TF32) return failure();
    if (!op.getA().getType().getElementType().isF32()) return failure();
    if (!op.getB().getType().getElementType().isF32()) return failure();
    if (op->hasAttr(kTf32ArgsRounded)) return failure();

    // Emits inline PTX performing a round-to-nearest F32 -> TF32 conversion;
    // the result keeps the F32 storage type expected by the dot.
    auto f32_to_tf32 = [&](Value value) -> Value {
      return rewriter
          .create<ElementwiseInlineAsmOp>(
              op.getLoc(), value.getType(), "cvt.rna.tf32.f32 $0, $1;", "=r,r",
              /*isPure=*/true, /*pack=*/1, ArrayRef<Value>{value})
          ->getResult(0);
    };
    Value lhs = f32_to_tf32(op.getA());
    Value rhs = f32_to_tf32(op.getB());
    auto dot = rewriter.replaceOpWithNewOp<mt::DotOp>(
        op, op.getC().getType(), lhs, rhs, op.getC(), mt::InputPrecision::TF32,
        /*maxNumImpreciseAcc=*/0);
    dot->setAttr(kTf32ArgsRounded, rewriter.getUnitAttr());

    return success();
  }
};
75+
76+
struct RoundF32ToTF32ForTf32DotRewritePass
77+
: public impl::RoundF32ToTF32ForTf32DotRewritePassBase<
78+
RoundF32ToTF32ForTf32DotRewritePass> {
79+
void runOnOperation() override {
80+
auto module = getOperation();
81+
RewritePatternSet patterns(&getContext(),
82+
std::make_unique<Tf32DotPattern>(&getContext()));
83+
if (failed(applyPatternsGreedily(module, std::move(patterns)))) {
84+
signalPassFailure();
85+
}
86+
}
87+
};
88+
89+
} // namespace
90+
91+
std::unique_ptr<Pass> CreateRoundF32ToTF32ForTf32DotRewritePass() {
92+
return std::make_unique<RoundF32ToTF32ForTf32DotRewritePass>();
93+
}
94+
95+
} // namespace mlir::triton::xla

0 commit comments

Comments
 (0)