Commit 6d2d287

[Precision] Introduce T.ieee_rsqrt and related high precision op (tile-ai#882)
* Add fast math operations for CUDA: exp, exp10, log, log2, log10, tan, cos, and sin (tile-ai#865).
* Refactor the fast math operation definitions for consistency and readability in the CUDA code; consolidate multiple definitions into single lines and improve formatting in the related test files.
* Remove unnecessary pass configurations for warp specialization and TMA lowering in the CUDA fast math tests; this simplifies the test setup while keeping the focus on fast math functionality.
* Update the fastmath tests to reflect that tl.* intrinsics generate no fastmath versions, and disable the cache in the main execution.
* Fix formatting of fastmath test comments for clarity on tl.* intrinsics behavior.
* Add a precision comparison tool for CUDA operations: a new Python script and CUDA source file that evaluate the accuracy of division, reciprocal, exponential, logarithmic, and trigonometric functions across CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes error statistics for each implementation against a double-precision reference; a README documents the comparison results.
* Add a precision comparison tool for CUDA operations (second iteration): implemented in Python and CUDA, it evaluates the accuracy of division, reciprocal, exponential, logarithmic, trigonometric, square root, and other operations across CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang, and collects error metrics for each operation in a comprehensive README.
* Add IEEE-compliant mathematical operations and refactor the fast math module: introduce ieee_add, ieee_sub, ieee_mul, ieee_fmaf, ieee_frcp, ieee_fsqrt, ieee_frsqrt, and ieee_fdiv to TileLang, remove the deprecated fastmath.py file, update the import paths, and extend the CUDA code generation to emit these operations with IEEE-compliant floating-point semantics.
* Remove debug output.
* Refactor the IEEE math tests for readability and consistency: adjust line breaks in test_ieee_math.py and test_mathops_fastmath.py, remove unnecessary comments, and streamline the main test execution.
* Update README.md to present the precision comparison results as structured markdown tables, improving readability of the error statistics for FP32 Precise, Triton, TileLang, and CUDA.
1 parent c10fffb commit 6d2d287
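
On the frontend, the new intrinsics surface as T.ieee_add, T.ieee_sub, T.ieee_mul, T.ieee_fmaf, T.ieee_frcp, T.ieee_fsqrt, T.ieee_frsqrt, and T.ieee_fdiv (see the tests further down). The Python wrappers themselves are not part of the diffs shown on this page, so the following is only a minimal sketch of how such a wrapper could look; the op name "tl.ieee_add" and the use of tir.call_intrin are assumptions, while the rounding modes and the "Invalid rounding mode" error text match what the new tests expect.

    # Minimal sketch, not the actual tilelang.language implementation.
    # Assumes the builtin from src/op/builtin.cc is registered as "tl.ieee_add"
    # and that importing tilelang makes it visible to tvm.ir.Op.get.
    from tvm import tir

    VALID_ROUNDING_MODES = ("rn", "rz", "ru", "rd")


    def ieee_add(x, y, rounding_mode="rn"):
        """IEEE-compliant float32 addition with an explicit CUDA rounding mode."""
        if rounding_mode not in VALID_ROUNDING_MODES:
            raise ValueError(f"Invalid rounding mode: {rounding_mode}")
        # The rounding mode travels as a trailing string argument, which is why
        # ieee_add is registered with num_inputs = 3.
        return tir.call_intrin("float32", "tl.ieee_add", x, y, rounding_mode)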

File tree

8 files changed: +1155 additions, -110 deletions


maint/precision/README.md

Lines changed: 119 additions & 109 deletions
Large diffs are not rendered by default.

src/op/builtin.cc

Lines changed: 29 additions & 0 deletions
@@ -66,6 +66,35 @@ TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr<TCallEffectKind>(
 TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+// high precision with IEEE-compliant
+TIR_DEFINE_TL_BUILTIN(ieee_add).set_num_inputs(3).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_sub).set_num_inputs(3).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_mul).set_num_inputs(3).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_fmaf).set_num_inputs(4).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_frcp).set_num_inputs(2).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_fsqrt)
+    .set_num_inputs(2)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kPure));
+
+TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
+
 TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
     .set_num_inputs(-1)
     .set_attr<TCallEffectKind>("TCallEffectKind",
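
Note the arity convention in these registrations: the rounding-mode string is passed as an ordinary argument, so it is counted in num_inputs (3 for the binary ops, 4 for ieee_fmaf, 2 for ieee_frcp/ieee_fsqrt, and 1 for ieee_frsqrt, which has no mode). Below is a hedged sketch of the TIR call expressions this implies; it assumes the ops are registered in the "tl." op namespace and that importing tilelang registers them.

    # Sketch only: builds the call expressions the registrations above imply.
    import tilelang  # noqa: F401 -- assumed to register the tl.* builtins on import
    from tvm import tir

    x = tir.Var("x", "float32")
    y = tir.Var("y", "float32")
    z = tir.Var("z", "float32")

    # The trailing "rn" is the rounding mode, hence num_inputs = 4 for ieee_fmaf.
    fma = tir.call_intrin("float32", "tl.ieee_fmaf", x, y, z, "rn")
    # ieee_frsqrt is round-to-nearest only, hence num_inputs = 1.
    rsq = tir.call_intrin("float32", "tl.ieee_frsqrt", x)
    print(fma)
    print(rsq)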

src/op/builtin.h

Lines changed: 26 additions & 0 deletions
@@ -90,15 +90,41 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
 DataType cuTensorMapType();
 
 // fast math related op
+// __exp(x) - fast exponential
 TVM_DLL const Op &__exp();
+// __exp10(x) - fast base-10 exponential
 TVM_DLL const Op &__exp10();
+// __log(x) - fast natural logarithm
 TVM_DLL const Op &__log();
+// __log2(x) - fast base-2 logarithm
 TVM_DLL const Op &__log2();
+// __log10(x) - fast base-10 logarithm
 TVM_DLL const Op &__log10();
+// __tan(x) - fast tangent
 TVM_DLL const Op &__tan();
+// __cos(x) - fast cosine
 TVM_DLL const Op &__cos();
+// __sin(x) - fast sine
 TVM_DLL const Op &__sin();
 
+// high precision with IEEE-compliant.
+// ieee_add(x, y, rounding_mode) - IEEE-compliant addition
+TVM_DLL const Op &ieee_add();
+// ieee_sub(x, y, rounding_mode) - IEEE-compliant subtraction
+TVM_DLL const Op &ieee_sub();
+// ieee_mul(x, y, rounding_mode) - IEEE-compliant multiplication
+TVM_DLL const Op &ieee_mul();
+// ieee_fmaf(x, y, z, rounding_mode) - IEEE-compliant fused multiply-add
+TVM_DLL const Op &ieee_fmaf();
+// ieee_frcp(x, rounding_mode) - IEEE-compliant reciprocal
+TVM_DLL const Op &ieee_frcp();
+// ieee_fsqrt(x, rounding_mode) - IEEE-compliant square root
+TVM_DLL const Op &ieee_fsqrt();
+// ieee_frsqrt(x) - IEEE-compliant reciprocal square root (rn only)
+TVM_DLL const Op &ieee_frsqrt();
+// ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division
+TVM_DLL const Op &ieee_fdiv();
+
 /*!
  * \brief tvm intrinsics for TMADescriptor creation for tiled load
  *
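
For reference, the rounding_mode strings documented above follow the standard CUDA intrinsic suffixes, which correspond to the four IEEE 754 rounding directions. The snippet below is a plain restatement of that convention, not code from this commit.

    # CUDA rounding-mode suffixes used by the ieee_* ops and their IEEE 754 meaning.
    ROUNDING_MODES = {
        "rn": "round to nearest, ties to even",
        "rz": "round toward zero (truncate)",
        "ru": "round toward +infinity (round up)",
        "rd": "round toward -infinity (round down)",
    }
    # ieee_frsqrt takes no mode argument because CUDA only provides the
    # round-to-nearest variant (__frsqrt_rn) for reciprocal square root.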

src/target/codegen_cuda.cc

Lines changed: 56 additions & 0 deletions
@@ -94,6 +94,18 @@ struct CUDAFastMathTan : public CUDAMath {
   }
 };
 
+struct CUDAIEEEMath {
+  std::string operator()(DataType t, std::string name,
+                         std::string rounding_mode) const {
+    if (t.is_float() && t.bits() == 32) {
+      return "__" + name + "_" + rounding_mode;
+    } else if (t.is_float() && t.bits() == 64) {
+      return "__d" + name + "_" + rounding_mode;
+    }
+    return "";
+  }
+};
+
 static std::string GetFP8Type(DataType type) {
   std::stringstream stream;
   int32_t lanes = type.lanes();

@@ -1733,6 +1745,50 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     CUDAFastMath math_func;
     std::string func_name = math_func(op->dtype, "sin");
     os << func_name << "(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::ieee_add())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
+    std::string func_name = math_func(op->dtype, "fadd", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
+       << PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::ieee_sub())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
+    std::string func_name = math_func(op->dtype, "fsub", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
+       << PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::ieee_mul())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
+    std::string func_name = math_func(op->dtype, "fmul", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
+       << PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::ieee_fmaf())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[3])->value;
+    std::string func_name = math_func(op->dtype, "fmaf", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
+       << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")";
+  } else if (op->op.same_as(tl::ieee_frcp())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[1])->value;
+    std::string func_name = math_func(op->dtype, "frcp", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::ieee_fsqrt())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[1])->value;
+    std::string func_name = math_func(op->dtype, "fsqrt", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::ieee_frsqrt())) {
+    CUDAIEEEMath math_func;
+    std::string func_name = math_func(op->dtype, "frsqrt", "rn");
+    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::ieee_fdiv())) {
+    CUDAIEEEMath math_func;
+    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
+    std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
+    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
+       << PrintExpr(op->args[1]) << ")";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
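
The CUDAIEEEMath helper above builds the intrinsic name by concatenating a prefix, the operation name, and the rounding-mode suffix. Below is a small illustration of the float32 branch, a mirror of the C++ logic above for illustration only; the fp64 branch uses the "__d" prefix instead.

    # Mirrors the float32 branch of CUDAIEEEMath shown above, for illustration only.
    def cuda_intrinsic_name(name: str, rounding_mode: str) -> str:
        return "__" + name + "_" + rounding_mode


    # ieee_add(..., "rn") on float32 therefore lowers to __fadd_rn, ieee_fmaf(..., "rz")
    # to __fmaf_rz, ieee_fdiv(..., "rd") to __fdiv_rd, and ieee_frsqrt(...) to __frsqrt_rn,
    # all of which are documented CUDA device intrinsics.
    for op_name, mode in [("fadd", "rn"), ("fmaf", "rz"), ("fdiv", "rd"), ("frsqrt", "rn")]:
        print(cuda_intrinsic_name(op_name, mode))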
test_ieee_math.py (new file)

Lines changed: 237 additions & 0 deletions
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest


def run_ieee_math_test(mathop_name,
                       mathop_func,
                       rounding_mode="rn",
                       M=128,
                       N=128,
                       block_M=32,
                       block_N=32,
                       dtype="float32"):
    """
    Test IEEE-compliant math operations with specified rounding modes.
    """

    # Define the appropriate function based on operation type to avoid TVM parsing conflicts
    if mathop_name == "ieee_fmaf":

        @T.prim_func
        def main_func(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
                C: T.Tensor((M, N), dtype),
                D: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    D[by * block_M + i,
                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
                                                      B[by * block_M + i, bx * block_N + j],
                                                      C[by * block_M + i,
                                                        bx * block_N + j], rounding_mode)

        out_idx = [3]
        num_inputs = 3
    elif mathop_name in ["ieee_add", "ieee_sub", "ieee_mul", "ieee_fdiv"]:

        @T.prim_func
        def main_func(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
                C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    C[by * block_M + i,
                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
                                                      B[by * block_M + i,
                                                        bx * block_N + j], rounding_mode)

        out_idx = [2]
        num_inputs = 2
    else:  # Single argument operations

        @T.prim_func
        def main_func(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                for i, j in T.Parallel(block_M, block_N):
                    B[by * block_M + i,
                      bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j],
                                                      rounding_mode)

        out_idx = [1]
        num_inputs = 1

    # Test compilation
    kernel = tilelang.compile(
        main_func,
        out_idx=out_idx,
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        })

    print(f"\n=== Testing {mathop_name} with rounding mode {rounding_mode} ===")
    print(f"✓ {mathop_name} compilation test passed")

    # Test numerical execution
    torch_dtype = getattr(torch, dtype)
    a = torch.randn(M, N, device="cuda", dtype=torch_dtype)

    if num_inputs >= 2:
        b = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    if num_inputs == 3:
        c = torch.randn(M, N, device="cuda", dtype=torch_dtype)

    # Ensure positive values for functions that need them
    if mathop_name in ["ieee_frcp", "ieee_fsqrt"]:
        a = torch.abs(a) + 0.1
    elif mathop_name == "ieee_fdiv":
        b = torch.abs(b) + 0.1  # Avoid division by zero

    # Execute kernel
    try:
        if num_inputs == 1:
            result = kernel(a)
        elif num_inputs == 2:
            result = kernel(a, b)
        else:  # num_inputs == 3
            result = kernel(a, b, c)

        assert result is not None
        print(f"✓ {mathop_name} numerical execution test passed")
    except Exception as e:
        print(f"Warning: {mathop_name} execution failed: {e}")


def test_rounding_mode_validation():
    """Test that invalid rounding modes raise ValueError"""

    # Test with invalid rounding mode
    with pytest.raises(ValueError, match="Invalid rounding mode"):
        T.ieee_add(1.0, 2.0, "invalid_mode")

    with pytest.raises(ValueError, match="Invalid rounding mode"):
        T.ieee_mul(1.0, 2.0, "xy")

    with pytest.raises(ValueError, match="Invalid rounding mode"):
        T.ieee_fsqrt(4.0, "bad_mode")

    print("✓ Rounding mode validation test passed")


@tilelang.testing.requires_cuda
def test_ieee_add_all_rounding_modes():
    """Test IEEE addition with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_add", T.ieee_add, rounding_mode=mode)
        print(f"✓ ieee_add with {mode} passed")


@tilelang.testing.requires_cuda
def test_ieee_sub_all_rounding_modes():
    """Test IEEE subtraction with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_sub", T.ieee_sub, rounding_mode=mode)
        print(f"✓ ieee_sub with {mode} passed")


@tilelang.testing.requires_cuda
def test_ieee_mul_all_rounding_modes():
    """Test IEEE multiplication with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_mul", T.ieee_mul, rounding_mode=mode)
        print(f"✓ ieee_mul with {mode} passed")


@tilelang.testing.requires_cuda
def test_ieee_fmaf_all_rounding_modes():
    """Test IEEE fused multiply-add with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_fmaf", T.ieee_fmaf, rounding_mode=mode)
        print(f"✓ ieee_fmaf with {mode} passed")


@tilelang.testing.requires_cuda
def test_ieee_frcp_all_rounding_modes():
    """Test IEEE reciprocal with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_frcp", T.ieee_frcp, rounding_mode=mode)
        print(f"✓ ieee_frcp with {mode} passed")


@tilelang.testing.requires_cuda
def test_ieee_fsqrt_all_rounding_modes():
    """Test IEEE square root with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_fsqrt", T.ieee_fsqrt, rounding_mode=mode)
        print(f"✓ ieee_fsqrt with {mode} passed")


@tilelang.testing.requires_cuda
def test_ieee_frsqrt_rn_only():
    """Test IEEE reciprocal square root (round to nearest only)"""

    @T.prim_func
    def main(
            A: T.Tensor((128, 128), "float32"),
            B: T.Tensor((128, 128), "float32"),
    ):
        with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
            for i, j in T.Parallel(32, 32):
                B[by * 32 + i, bx * 32 + j] = T.ieee_frsqrt(A[by * 32 + i, bx * 32 + j])

    kernel = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        })

    print("\n=== Testing ieee_frsqrt (rn only) ===")
    print("✓ ieee_frsqrt compilation test passed")

    # Test numerical execution
    a = torch.abs(torch.randn(128, 128, device="cuda", dtype=torch.float32)) + 0.1

    try:
        result = kernel(a)
        assert result is not None
        print("✓ ieee_frsqrt numerical execution test passed")
    except Exception as e:
        print(f"Warning: ieee_frsqrt execution failed: {e}")


@tilelang.testing.requires_cuda
def test_ieee_fdiv_all_rounding_modes():
    """Test IEEE division with all rounding modes"""
    rounding_modes = ["rn", "rz", "ru", "rd"]

    for mode in rounding_modes:
        run_ieee_math_test("ieee_fdiv", T.ieee_fdiv, rounding_mode=mode)
        print(f"✓ ieee_fdiv with {mode} passed")


if __name__ == "__main__":
    tilelang.testing.main()
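
The tests above only assert that compilation and execution succeed. A numerical check against a PyTorch reference can be layered on the same pattern; the sketch below reuses the compile options from the tests and picks T.ieee_fdiv with round-to-nearest, where a close match to PyTorch's float32 division is expected. The tolerances are illustrative and not part of the commit.

    # Minimal sketch of a numerical check, following the test pattern above.
    import tilelang
    import tilelang.language as T
    import torch


    @T.prim_func
    def fdiv_kernel(
            A: T.Tensor((128, 128), "float32"),
            B: T.Tensor((128, 128), "float32"),
            C: T.Tensor((128, 128), "float32"),
    ):
        with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
            for i, j in T.Parallel(32, 32):
                C[by * 32 + i, bx * 32 + j] = T.ieee_fdiv(A[by * 32 + i, bx * 32 + j],
                                                          B[by * 32 + i, bx * 32 + j], "rn")


    kernel = tilelang.compile(
        fdiv_kernel,
        out_idx=[2],
        target="cuda",
        pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False})

    a = torch.randn(128, 128, device="cuda")
    b = torch.abs(torch.randn(128, 128, device="cuda")) + 0.1  # avoid division by zero
    c = kernel(a, b)
    # __fdiv_rn is correctly rounded, so it should match torch's float32 division closely.
    torch.testing.assert_close(c, a / b, rtol=1e-6, atol=1e-6)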
