@@ -919,6 +919,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
919919 << " __half22float2(*((half2*)(&(" << src << " ))+1));\n " ;
920920 os << sret;
921921 return ;
922+ } else if (from_ty.lanes () == 8 && target_ty.lanes () == 8 ) {
923+ // half8 -> float8
924+ PrintIndent ();
925+ stream << " ((float2*)(&" << sret << " ))[0] = "
926+ << " __half22float2(*(half2*)(&(" << src << " )));\n " ;
927+ PrintIndent ();
928+ stream << " ((float2*)(&" << sret << " ))[1] = "
929+ << " __half22float2(*((half2*)(&(" << src << " ))+1));\n " ;
930+ PrintIndent ();
931+ stream << " ((float2*)(&" << sret << " ))[2] = "
932+ << " __half22float2(*((half2*)(&(" << src << " ))+2));\n " ;
933+ PrintIndent ();
934+ stream << " ((float2*)(&" << sret << " ))[3] = "
935+ << " __half22float2(*((half2*)(&(" << src << " ))+3));\n " ;
936+ os << sret;
937+ return ;
922938 }
923939 } else if (from_ty.is_float () && target_ty.is_float16 ()) {
924940 // Use __float22half2_rn for vectorized conversion (float2 -> half2)
@@ -939,6 +955,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
939955 << " __float22half2_rn(*((float2*)(&(" << src << " ))+1));\n " ;
940956 os << sret;
941957 return ;
958+ } else if (from_ty.lanes () == 8 && target_ty.lanes () == 8 ) {
959+ // float8 -> half8
960+ PrintIndent ();
961+ stream << " ((half2*)(&" << sret << " ))[0] = "
962+ << " __float22half2_rn(*(float2*)(&(" << src << " )));\n " ;
963+ PrintIndent ();
964+ stream << " ((half2*)(&" << sret << " ))[1] = "
965+ << " __float22half2_rn(*((float2*)(&(" << src << " ))+1));\n " ;
966+ PrintIndent ();
967+ stream << " ((half2*)(&" << sret << " ))[2] = "
968+ << " __float22half2_rn(*((float2*)(&(" << src << " ))+2));\n " ;
969+ PrintIndent ();
970+ stream << " ((half2*)(&" << sret << " ))[3] = "
971+ << " __float22half2_rn(*((float2*)(&(" << src << " ))+3));\n " ;
972+ os << sret;
973+ return ;
942974 }
943975 }
944976
@@ -965,6 +997,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
965997 << src << " ))+1));\n " ;
966998 os << sret;
967999 return ;
1000+ } else if (from_ty.lanes () == 8 && target_ty.lanes () == 8 ) {
1001+ // bfloat162x4 -> float8
1002+ PrintIndent ();
1003+ stream << " ((float2*)(&" << sret << " ))[0] = "
1004+ << " __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
1005+ << src << " )));\n " ;
1006+ PrintIndent ();
1007+ stream << " ((float2*)(&" << sret << " ))[1] = "
1008+ << " __bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
1009+ << src << " ))+1));\n " ;
1010+ PrintIndent ();
1011+ stream << " ((float2*)(&" << sret << " ))[2] = "
1012+ << " __bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
1013+ << src << " ))+2));\n " ;
1014+ PrintIndent ();
1015+ stream << " ((float2*)(&" << sret << " ))[3] = "
1016+ << " __bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
1017+ << src << " ))+3));\n " ;
1018+ os << sret;
1019+ return ;
9681020 }
9691021 } else if (from_ty.is_float () && target_ty.is_bfloat16 ()) {
9701022 // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
@@ -985,6 +1037,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
9851037 << " __float22bfloat162_rn(*((float2*)(&(" << src << " ))+1));\n " ;
9861038 os << sret;
9871039 return ;
1040+ } else if (from_ty.lanes () == 8 && target_ty.lanes () == 8 ) {
1041+ // float8 -> bfloat162x4
1042+ PrintIndent ();
1043+ stream << " (reinterpret_cast<__nv_bfloat162*>(&" << sret << " ))[0] = "
1044+ << " __float22bfloat162_rn(*(float2*)(&(" << src << " )));\n " ;
1045+ PrintIndent ();
1046+ stream << " (reinterpret_cast<__nv_bfloat162*>(&" << sret << " ))[1] = "
1047+ << " __float22bfloat162_rn(*((float2*)(&(" << src << " ))+1));\n " ;
1048+ PrintIndent ();
1049+ stream << " (reinterpret_cast<__nv_bfloat162*>(&" << sret << " ))[2] = "
1050+ << " __float22bfloat162_rn(*((float2*)(&(" << src << " ))+2));\n " ;
1051+ PrintIndent ();
1052+ stream << " (reinterpret_cast<__nv_bfloat162*>(&" << sret << " ))[3] = "
1053+ << " __float22bfloat162_rn(*((float2*)(&(" << src << " ))+3));\n " ;
1054+ os << sret;
1055+ return ;
9881056 }
9891057 }
9901058
@@ -1019,6 +1087,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
10191087 << " );\n " ;
10201088 os << sret;
10211089 return ;
1090+ } else if (from_ty.lanes () == 8 && target_ty.lanes () == 8 ) {
1091+ // float8 -> fp8x8
1092+ PrintIndent ();
1093+ stream << " ((__nv_fp8x2_storage_t*)(&" << sret << " ))[0] = "
1094+ << " __nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
1095+ << " )), __NV_SATFINITE, "
1096+ << (target_ty.is_float8_e4m3 () ? " __NV_E4M3" : " __NV_E5M2" )
1097+ << " );\n " ;
1098+ PrintIndent ();
1099+ stream << " ((__nv_fp8x2_storage_t*)(&" << sret << " ))[1] = "
1100+ << " __nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
1101+ << " ))+1), __NV_SATFINITE, "
1102+ << (target_ty.is_float8_e4m3 () ? " __NV_E4M3" : " __NV_E5M2" )
1103+ << " );\n " ;
1104+ PrintIndent ();
1105+ stream << " ((__nv_fp8x2_storage_t*)(&" << sret << " ))[2] = "
1106+ << " __nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
1107+ << " ))+2), __NV_SATFINITE, "
1108+ << (target_ty.is_float8_e4m3 () ? " __NV_E4M3" : " __NV_E5M2" )
1109+ << " );\n " ;
1110+ PrintIndent ();
1111+ stream << " ((__nv_fp8x2_storage_t*)(&" << sret << " ))[3] = "
1112+ << " __nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
1113+ << " ))+3), __NV_SATFINITE, "
1114+ << (target_ty.is_float8_e4m3 () ? " __NV_E4M3" : " __NV_E5M2" )
1115+ << " );\n " ;
1116+ os << sret;
1117+ return ;
10221118 }
10231119 }
10241120
0 commit comments