Commit feef9ef

[Enhancement] Enhance Cast operations Vectorization (#1156)
* Enhance Cast vectorized
* Add Parallel vectorized cast test
* code lint
* merge newest commit
1 parent 198f22b commit feef9ef

3 files changed: +140, -5 lines

src/target/codegen_cuda.cc

Lines changed: 96 additions & 0 deletions
@@ -919,6 +919,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // half8 -> float8
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[2] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[3] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
     }
   } else if (from_ty.is_float() && target_ty.is_float16()) {
     // Use __float22half2_rn for vectorized conversion (float2 -> half2)
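Note: for reference, a minimal sketch of the CUDA this new 8-lane branch emits, assuming the half8 source lives in a 16-byte register pack (here a uint4) and writes to a hypothetical float8 aggregate; the helper and variable names are illustrative, not the codegen's actual identifiers.

#include <cuda_fp16.h>

// Stand-in for the 8-lane float vector type declared in generated kernels.
struct float8 { float2 data[4]; };

__device__ void cast_half8_to_float8(uint4 v_src, float8 &v_dst) {
  // Four half2 -> float2 conversions cover all eight lanes.
  ((float2 *)(&v_dst))[0] = __half22float2(*(half2 *)(&v_src));
  ((float2 *)(&v_dst))[1] = __half22float2(*((half2 *)(&v_src) + 1));
  ((float2 *)(&v_dst))[2] = __half22float2(*((half2 *)(&v_src) + 2));
  ((float2 *)(&v_dst))[3] = __half22float2(*((half2 *)(&v_src) + 3));
}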
@@ -939,6 +955,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> half8
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[2] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[3] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
     }
   }
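Note: the float-to-half direction mirrors the sketch above, using the round-to-nearest intrinsic; same illustrative names.

#include <cuda_fp16.h>

struct float8 { float2 data[4]; };  // illustrative 8-lane float aggregate

__device__ void cast_float8_to_half8(float8 v_src, uint4 &v_dst) {
  // Four float2 -> half2 conversions, round-to-nearest-even.
  ((half2 *)(&v_dst))[0] = __float22half2_rn(*(float2 *)(&v_src));
  ((half2 *)(&v_dst))[1] = __float22half2_rn(*((float2 *)(&v_src) + 1));
  ((half2 *)(&v_dst))[2] = __float22half2_rn(*((float2 *)(&v_src) + 2));
  ((half2 *)(&v_dst))[3] = __float22half2_rn(*((float2 *)(&v_src) + 3));
}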

@@ -965,6 +997,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // bfloat162x4 -> float8
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[2] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+2));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[3] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+3));\n";
+      os << sret;
+      return;
     }
   } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
     // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
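Note: an equivalent sketch for the bfloat16 source path, again with made-up helper and variable names; __bfloat1622float2 comes from cuda_bf16.h.

#include <cuda_bf16.h>

struct float8 { float2 data[4]; };  // illustrative 8-lane float aggregate

__device__ void cast_bf16x8_to_float8(uint4 v_src, float8 &v_dst) {
  // Four __nv_bfloat162 -> float2 conversions cover all eight lanes.
  __nv_bfloat162 *s = reinterpret_cast<__nv_bfloat162 *>(&v_src);
  ((float2 *)(&v_dst))[0] = __bfloat1622float2(s[0]);
  ((float2 *)(&v_dst))[1] = __bfloat1622float2(s[1]);
  ((float2 *)(&v_dst))[2] = __bfloat1622float2(s[2]);
  ((float2 *)(&v_dst))[3] = __bfloat1622float2(s[3]);
}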
@@ -985,6 +1037,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> bfloat162x4
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
     }
   }
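Note: and the reverse direction, packing eight floats into four __nv_bfloat162 values with round-to-nearest (names illustrative).

#include <cuda_bf16.h>

struct float8 { float2 data[4]; };  // illustrative 8-lane float aggregate

__device__ void cast_float8_to_bf16x8(float8 v_src, uint4 &v_dst) {
  // Four float2 -> __nv_bfloat162 conversions, round-to-nearest-even.
  float2 *s = (float2 *)(&v_src);
  reinterpret_cast<__nv_bfloat162 *>(&v_dst)[0] = __float22bfloat162_rn(s[0]);
  reinterpret_cast<__nv_bfloat162 *>(&v_dst)[1] = __float22bfloat162_rn(s[1]);
  reinterpret_cast<__nv_bfloat162 *>(&v_dst)[2] = __float22bfloat162_rn(s[2]);
  reinterpret_cast<__nv_bfloat162 *>(&v_dst)[3] = __float22bfloat162_rn(s[3]);
}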

@@ -1019,6 +1087,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << ");\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> fp8x8
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+2), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+3), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
     }
   }
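Note: a sketch of what the fp8 branch emits for an e4m3 target; for e5m2 targets the codegen substitutes __NV_E5M2. __nv_cvt_float2_to_fp8x2, __NV_SATFINITE, and __nv_fp8x2_storage_t are the cuda_fp8.h conversion API; the helper name and aggregates are illustrative.

#include <cuda_fp8.h>

struct float8 { float2 data[4]; };               // illustrative 8-lane float aggregate
struct fp8x8 { __nv_fp8x2_storage_t data[4]; };  // illustrative 8-lane fp8 aggregate

__device__ void cast_float8_to_fp8x8(float8 v_src, fp8x8 &v_dst) {
  float2 *s = (float2 *)(&v_src);
  // Four float2 -> fp8x2 conversions; out-of-range values saturate to finite.
  v_dst.data[0] = __nv_cvt_float2_to_fp8x2(s[0], __NV_SATFINITE, __NV_E4M3);
  v_dst.data[1] = __nv_cvt_float2_to_fp8x2(s[1], __NV_SATFINITE, __NV_E4M3);
  v_dst.data[2] = __nv_cvt_float2_to_fp8x2(s[2], __NV_SATFINITE, __NV_E4M3);
  v_dst.data[3] = __nv_cvt_float2_to_fp8x2(s[3], __NV_SATFINITE, __NV_E4M3);
}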

src/transform/layout_inference.cc

Lines changed: 15 additions & 2 deletions
@@ -597,7 +597,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
       }
     }
     // Update the best plan if this one uses fewer registers
-    if (reg_num < min_reg_num) {
+    if (reg_num < min_reg_num ||
+        (reg_num == min_reg_num &&
+         attempt_infer_root < min_reg_num_infer_root)) {
       best_infer_list =
           BackupInferList(); // Use backup to avoid moving out infer_list_
       best_layout_map = tmp_layout_map;
@@ -787,7 +789,18 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
       }
     });
 
-    if (has_non_local && !has_reducer) {
+    // If a cast operation exists, vectorization may still be required
+    bool has_cast_operations = false;
+    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+      if (const auto *store = obj.as<BufferStoreNode>()) {
+        // Check if this is a non-reducer store with Cast operation
+        if (store->value.as<CastNode>()) {
+          has_cast_operations = true;
+        }
+      }
+    });
+
+    if ((has_non_local || has_cast_operations) && !has_reducer) {
       for_node = VectorizeLoop(for_node);
     }
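Note: with this change, a loop whose stores are cast expressions is routed through VectorizeLoop even when it only touches register fragments, whereas previously only has_non_local triggered vectorization; the parallel_vectorized_cast_kernel test added below exercises exactly that case.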

testing/python/language/test_tilelang_language_vectorized_cast.py

Lines changed: 29 additions & 3 deletions
@@ -17,15 +17,36 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
 
     @T.prim_func
     def main(
-            A: T.Tensor[(M), dtype_A],  # noqa: F821
-            B: T.Tensor[(M), dtype_B],  # noqa: F821
+            A: T.Tensor[(M,), dtype_A],  # noqa: F821
+            B: T.Tensor[(M,), dtype_B],  # noqa: F821
     ):
         with T.Kernel(1, threads=128):
             T.copy(A, B)
 
     return main
 
 
+@tilelang.jit
+def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(M,), dtype_A],  # noqa: F821
+            B: T.Tensor[(M,), dtype_B],  # noqa: F821
+    ):
+        with T.Kernel(1, threads=128):
+            A_local = T.alloc_fragment((M,), dtype_A)
+            B_local = T.alloc_fragment((M,), dtype_B)
+
+            T.copy(A, A_local)
+            for i in T.Parallel(M):
+                B_local[i] = A_local[i]
+            T.copy(B_local, B)
+
+    return main
+
+
 def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
     """Run the vectorized cast kernel and check the correctness.
     Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
 
     M = 128 * lanes
     kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
+    kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
 
     A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
     B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
 
     kernel(A, B)
+    kernel_parallel(A, C)
 
     torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)
 
     code = kernel.get_kernel_source()
+    code_parallel = kernel_parallel.get_kernel_source()
 
-    assert check_str in code, \
+    assert check_str in code and check_str in code_parallel, \
         f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"