Commit 396cca4

[Feature] Refactor bf16 conversion operations and remove legacy compile flags

1 parent be16c15 commit 396cca4

File tree

7 files changed: +62 additions, -92 deletions


examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 4 additions & 8 deletions
@@ -23,8 +23,7 @@ def get_bwd_configs():
     out_idx=[3, 4],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
         batch,
         heads,
@@ -143,8 +142,7 @@ def flash_fwd(
     out_idx=[2],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -183,8 +181,7 @@ def make_dq_layout(dQ):
     out_idx=[1],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -208,8 +205,7 @@ def flash_bwd_post(
 @tilelang.jit(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd(batch,
                   heads,
                   seq_len,

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ def get_configs():
     out_idx=[3],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
         batch,
         heads,

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 4 additions & 8 deletions
@@ -23,8 +23,7 @@ def get_bwd_configs():
     out_idx=[3, 4],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
         batch,
         heads,
@@ -140,8 +139,7 @@ def flash_fwd(
     out_idx=[2],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -180,8 +178,7 @@ def make_dq_layout(dQ):
     out_idx=[1],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -205,8 +202,7 @@ def flash_bwd_post(
 @tilelang.jit(
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd(
         batch,
         heads,

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 1 addition & 2 deletions
@@ -21,8 +21,7 @@ def get_configs():
     out_idx=[3],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
         batch,
         heads,

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 2 deletions
@@ -22,8 +22,7 @@ def get_configs():
     out_idx=[3],
     pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
         batch,
         heads,

src/target/codegen_cuda.cc

Lines changed: 42 additions & 60 deletions
@@ -942,6 +942,48 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     }
   }

+  // Handle conversion between bfloat16 and float32
+  if (from_ty.is_bfloat16() && target_ty.is_float()) {
+    // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // bfloat162 -> float2
+      PrintIndent();
+      stream << sret << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // bfloat162x2 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
+    // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> bfloat162
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret << ")) = __float22bfloat162_rn(*(float2*)(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> bfloat162x2
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
   // Handle conversion from float32 to float8 (E4M3/E5M2)
   if (from_ty.is_float() &&
       (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
@@ -974,63 +1016,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     }
   }

-  // Handle bfloat16 special cases with supported ops
-  // NOTE(wt): Currently bf16 related ops don't support lanes=4,
-  // We should add this in the future.
-  bool used_bf16_op = false;
-  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
-    std::ostringstream func_name;
-    if (from_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (from_ty.is_float()) {
-      func_name << "float";
-    }
-    if (from_ty.lanes() > 1) {
-      func_name << from_ty.lanes();
-    }
-    func_name << "2";
-    if (target_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (target_ty.is_float()) {
-      func_name << "float";
-    } else if (target_ty == DataType::Int(16)) {
-      func_name << "int16";
-    }
-    if (target_ty.lanes() > 1) {
-      func_name << target_ty.lanes();
-    }
-
-    auto fname = func_name.str();
-    if (bf16_supported_ops_.count(fname)) {
-      used_bf16_op = true;
-      stream << "#ifdef ENABLE_BF16\n";
-      PrintIndent();
-      stream << "reinterpret_cast<";
-      if (target_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(target_ty.element_of(), stream);
-      }
-      if (target_ty.lanes() > 1) {
-        stream << target_ty.lanes();
-      }
-      stream << " &>(" << sret << ") = fastertransformer::" << fname
-             << "(reinterpret_cast<";
-      if (from_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(from_ty.element_of(), stream);
-      }
-      if (from_ty.lanes() > 1) {
-        stream << from_ty.lanes();
-      }
-      stream << " const &>(" << src << "));\n";
-      stream << "#else\n";
-      // bf16 cases don't need early return, as we use elementwise cast as
-      // fallback
-    }
-  }
-
   // Fallback: elementwise cast
   for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
     std::ostringstream val;
@@ -1042,9 +1027,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     PrintVecElemStore(sret, target_ty, i, val.str());
   }

-  if (used_bf16_op) {
-    stream << "#endif\n";
-  }
   os << sret;
 }

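For context, here is a minimal standalone CUDA sketch (not part of this commit) of the conversion pattern the generated code now relies on: __bfloat1622float2 widens a packed __nv_bfloat162 to a float2, __float22bfloat162_rn narrows a float2 back with round-to-nearest-even, and a 4-lane vector is handled as two packed halves, mirroring the two-statement sequence emitted above. The kernel and helper names below are illustrative, not taken from the repository.

#include <cstdio>
#include <cuda_bf16.h>

// Illustrative helpers following the same intrinsic pattern the codegen now
// emits for bf16 <-> fp32 casts.
__device__ __forceinline__ float2 bf16x2_to_float2(__nv_bfloat162 v) {
  return __bfloat1622float2(v);  // lanes == 2: widen both lanes at once
}

__device__ __forceinline__ __nv_bfloat162 float2_to_bf16x2(float2 v) {
  return __float22bfloat162_rn(v);  // lanes == 2: narrow both lanes, round-to-nearest-even
}

// lanes == 4: split the float4 into two packed __nv_bfloat162 halves and back,
// mirroring the two assignments the codegen produces for 4-lane casts.
__global__ void bf16_roundtrip(const float4 *in, float4 *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n)
    return;
  float4 v = in[i];
  __nv_bfloat162 lo = float2_to_bf16x2(make_float2(v.x, v.y));
  __nv_bfloat162 hi = float2_to_bf16x2(make_float2(v.z, v.w));
  float2 flo = bf16x2_to_float2(lo);
  float2 fhi = bf16x2_to_float2(hi);
  out[i] = make_float4(flo.x, flo.y, fhi.x, fhi.y);
}

int main() {
  const int n = 256;
  float4 h_in[n], h_out[n];
  for (int i = 0; i < n; ++i)
    h_in[i] = make_float4(0.1f * i, 0.2f * i, 0.3f * i, 0.4f * i);

  float4 *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float4));
  cudaMalloc(&d_out, n * sizeof(float4));
  cudaMemcpy(d_in, h_in, n * sizeof(float4), cudaMemcpyHostToDevice);

  bf16_roundtrip<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(float4), cudaMemcpyDeviceToHost);

  // bf16 keeps roughly 8 bits of mantissa, so the round trip is approximate.
  printf("in = %f, after bf16 round trip = %f\n", h_in[10].x, h_out[10].x);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Because these intrinsics come from cuda_bf16.h rather than the fastertransformer helpers that the removed #ifdef ENABLE_BF16 path called, the generated code no longer depends on the -DENABLE_BF16 define, which is why the example files above drop it from compile_flags.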
testing/python/language/test_tilelang_language_vectorized_cast.py

Lines changed: 9 additions & 10 deletions
@@ -11,7 +11,7 @@
 }


-@tilelang.jit(compile_flags=['-DENABLE_BF16'])
+@tilelang.jit
 def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     assert M % 256 == 0

@@ -57,28 +57,27 @@ def test_vectorized_cast():
     run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
     run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)

-    # # fp16 -> fp32
+    # fp16 -> fp32
     run_vectorized_cast("float16", "float32", "__half22float2", 2)
     run_vectorized_cast("float16", "float32", "__half22float2", 4)

-    # # fp32 -> fp8_e4m3
+    # fp32 -> fp8_e4m3
     run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
     run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)

-    # # fp32 -> fp8_e5m2
+    # fp32 -> fp8_e5m2
     run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
     run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)

     # fp32 -> bf16
-    # NOTE(wt): currently bf16 related ops don't support lanes=4,
-    # We will add this in the future.
-    run_vectorized_cast("float32", "bfloat16", "fastertransformer", 2)
-    # run_vectorized_cast("float32", "bfloat16", "fastertransformer", 4)
+    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
+    run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)

     # bf16 -> fp32
-    run_vectorized_cast("bfloat16", "float32", "fastertransformer", 2)
-    # run_vectorized_cast("bfloat16", "float32", "fastertransformer", 4)
+    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
+    run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)


 if __name__ == "__main__":
     tilelang.testing.main()
+

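As a companion to the updated expectations above, here is a compile-only CUDA sketch (assumed, not taken from the repository) contrasting the two lowering styles involved: the elementwise fallback that codegen_cuda.cc keeps as a last resort converts one lane at a time, while the vectorized path converts a packed pair per call, which is the intrinsic name (__float22bfloat162_rn / __bfloat1622float2) the test now looks for in the generated source for both 2 and 4 lanes.

#include <cuda_bf16.h>

// Elementwise style: one __float2bfloat16_rn per lane (the generic fallback).
__global__ void cast_elementwise(const float *in, __nv_bfloat16 *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = __float2bfloat16_rn(in[i]);
}

// Vectorized style: one __float22bfloat162_rn per packed pair of lanes.
__global__ void cast_packed(const float2 *in, __nv_bfloat162 *out, int n_pairs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs)
    out[i] = __float22bfloat162_rn(in[i]);
}

// Reverse direction, packed: one __bfloat1622float2 per __nv_bfloat162.
__global__ void widen_packed(const __nv_bfloat162 *in, float2 *out, int n_pairs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs)
    out[i] = __bfloat1622float2(in[i]);
}

A 4-lane cast simply issues two of the packed conversions, so the lanes=2 and lanes=4 cases in the test resolve to the same intrinsic name.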