
Commit 50d4d5d

[Enhancement] Support lanes=4 cases and add unit test for vectorized cast

1 parent 0dd40d4

2 files changed: +129, -0 lines

src/target/codegen_cuda.cc

Lines changed: 39 additions & 0 deletions
@@ -904,19 +904,41 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   if (from_ty.is_float16() && target_ty.is_float()) {
     // Use __half22float2 for vectorized conversion (half2 -> float2)
     if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // half2 -> float2
       PrintIndent();
       stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // half4 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
     }
   } else if (from_ty.is_float() && target_ty.is_float16()) {
     // Use __float22half2_rn for vectorized conversion (float2 -> half2)
     if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> half2
       PrintIndent();
       stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
              << src << ")));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> half4
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
     }
   }
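For reference, here is a minimal standalone sketch (not part of the commit) of the CUDA that the new lanes=4 half -> float path emits. The kernel and the h4/f4 variables are hypothetical stand-ins for the generated src and sret buffers:

#include <cuda_fp16.h>

// Hypothetical kernel wrapping the two statements the codegen emits for a
// half4 -> float4 cast: each half2 pair is widened with __half22float2.
__global__ void cast_half4_to_float4(const __half *in, float4 *out) {
  // CUDA has no built-in half4 type, so the generated code treats the
  // source as two adjacent half2 values; load them here as one 64-bit uint2.
  uint2 h4 = *reinterpret_cast<const uint2 *>(in);
  float4 f4;
  ((float2 *)(&f4))[0] = __half22float2(*(half2 *)(&h4));
  ((float2 *)(&f4))[1] = __half22float2(*((half2 *)(&h4) + 1));
  *out = f4;
}

The same two-step pattern, with __float22half2_rn, covers the float4 -> half4 direction added below it.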
@@ -926,6 +948,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
     // (float2 -> fp8x2)
     if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> fp8x2
       PrintIndent();
       stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
              << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
@@ -934,10 +957,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << ");\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> fp8x4
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
     }
   }

   // Handle bfloat16 special cases with supported ops
+  // NOTE(wt): Currently bf16 related ops don't support lanes=4;
+  // we should add this in the future.
   bool used_bf16_op = false;
   if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
     std::ostringstream func_name;
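Likewise, a hedged sketch of what the new lanes=4 fp32 -> fp8 path emits, wrapped in a standalone kernel; the f4/p4 names are illustrative, and __NV_E4M3 stands in for the is_float8_e4m3() ternary in the codegen:

#include <cuda_fp8.h>

// Hypothetical kernel wrapping the two statements the codegen emits for a
// float4 -> fp8x4 cast: each float2 half is packed by
// __nv_cvt_float2_to_fp8x2 with saturate-to-finite handling.
__global__ void cast_float4_to_fp8x4(const float4 *in,
                                     __nv_fp8x4_storage_t *out) {
  float4 f4 = *in;
  __nv_fp8x4_storage_t p4;  // 32-bit backing store for four packed fp8 values
  ((__nv_fp8x2_storage_t *)(&p4))[0] =
      __nv_cvt_float2_to_fp8x2(*(float2 *)(&f4), __NV_SATFINITE, __NV_E4M3);
  ((__nv_fp8x2_storage_t *)(&p4))[1] =
      __nv_cvt_float2_to_fp8x2(*((float2 *)(&f4) + 1), __NV_SATFINITE,
                               __NV_E4M3);
  *out = p4;
}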
Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
import torch
import tilelang.testing
import tilelang.language as T


str2dtype = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float8_e4m3": torch.float8_e4m3fn,
    "float8_e5m2": torch.float8_e5m2,
}


@tilelang.jit(compile_flags=['-DENABLE_BF16'])
def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
    assert M % 256 == 0

    @T.prim_func
    def main(
            A: T.Tensor[(M), dtype_A],  # noqa: F821
            B: T.Tensor[(M), dtype_B],  # noqa: F821
    ):
        with T.Kernel(1, threads=128):
            T.copy(A, B)

    return main


def run_vectorized_cast(
    src_dtype_str: str,
    dst_dtype_str: str,
    check_str: str,
    lanes: int = 2
):
    """Run the vectorized cast kernel and check its correctness.

    Args:
        src_dtype_str: The source data type string.
        dst_dtype_str: The destination data type string.
        check_str: Substring that must appear in the generated kernel
            source; used to ensure the vectorized cast path was taken.
        lanes: The number of lanes of the source and destination data
            types; the tensor size M is derived as 128 * lanes.
    """

    M = 128 * lanes
    kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)

    A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()

    kernel(A, B)

    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)

    code = kernel.get_kernel_source()

    assert check_str in code, \
        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"


def test_vectorized_cast():
    # fp32 -> fp16
    run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
    run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)

    # fp16 -> fp32
    run_vectorized_cast("float16", "float32", "__half22float2", 2)
    run_vectorized_cast("float16", "float32", "__half22float2", 4)

    # fp32 -> fp8_e4m3
    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
    run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)

    # fp32 -> fp8_e5m2
    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
    run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)

    # fp32 -> bf16
    # NOTE(wt): currently bf16 related ops don't support lanes=4;
    # we will add this in the future.
    run_vectorized_cast("float32", "bfloat16", "fastertransformer", 2)
    # run_vectorized_cast("float32", "bfloat16", "fastertransformer", 4)

    # bf16 -> fp32
    run_vectorized_cast("bfloat16", "float32", "fastertransformer", 2)
    # run_vectorized_cast("bfloat16", "float32", "fastertransformer", 4)


if __name__ == "__main__":
    tilelang.testing.main()
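For context on the check_str assertion: kernel.get_kernel_source() returns the generated CUDA, which contains the intrinsic emitted by the codegen change above whenever the vectorized path is taken. A hypothetical excerpt for the float32 -> float16, lanes=2 case (the buffer names are illustrative, not taken from actual output):

// Schematic line from the generated kernel source that
// check_str="__float22half2_rn" would match:
*(half2*)(&(B_local[0])) = __float22half2_rn(*(float2*)(&(A_local[0])));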
