From 3524651e17a34851615cfbe6b6a078a7d4a9fe79 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi Date: Tue, 20 Feb 2024 10:56:04 -0500 Subject: [PATCH] Add tests for call directly to builtin Add more robustness to SemaChecking --- clang/include/clang/Basic/Builtins.td | 2 +- clang/include/clang/Sema/Sema.h | 2 + clang/lib/CodeGen/CGBuiltin.cpp | 10 +- clang/lib/Sema/SemaChecking.cpp | 225 ++++++++++++++----- clang/test/CodeGenHLSL/builtins/dot.hlsl | 170 ++++++++++++++ clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl | 50 ++--- 6 files changed, 376 insertions(+), 83 deletions(-) diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 771c4f5d4121f4..e3432f7925ba14 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4526,7 +4526,7 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> { def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_dot"]; - let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Attributes = [NoThrow, Const]; let Prototype = "void(...)"; } diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 6fe10fad45daff..3841ea3f06f757 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -14121,6 +14121,8 @@ class Sema final { bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc); + bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res); + bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall); bool SemaBuiltinElementwiseMath(CallExpr *TheCall); bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall); bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall); diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 4bad41a9be214e..7e993feb588f4c 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -17914,26 +17914,28 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, llvm::Type *T1 = Op1->getType(); if (!T0->isVectorTy() && !T1->isVectorTy()) { if (T0->isFloatingPointTy()) { - return Builder.CreateFMul(Op0, Op1, "dx.dot"); + return Builder.CreateFMul(Op0, Op1, "dx.dot"); } if (T0->isIntegerTy()) { - return Builder.CreateMul(Op0, Op1, "dx.dot"); + return Builder.CreateMul(Op0, Op1, "dx.dot"); } + // Bools should have been promoted assert( false && "Dot product on a scalar is only supported on integers and floats."); } + // A VectorSplat should have happened assert(T0->isVectorTy() && T1->isVectorTy() && "Dot product of vector and scalar is not supported."); - // NOTE: this assert will need to be revisited after overload resoltion - // PR merges. + // A vector sext or sitofp should have happened assert(T0->getScalarType() == T1->getScalarType() && "Dot product of vectors need the same element types."); auto *VecTy0 = E->getArg(0)->getType()->getAs(); auto *VecTy1 = E->getArg(1)->getType()->getAs(); + // A HLSLVectorTruncation should have happend assert(VecTy0->getNumElements() == VecTy1->getNumElements() && "Dot product requires vectors to be of the same size."); diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 0933e39eb931ba..d3b9706ab313eb 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -5163,69 +5163,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) { return false; } -// Note: returning true in this case results in CheckBuiltinFunctionCall -// returning an ExprError -bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { - switch (BuiltinID) { - case Builtin::BI__builtin_hlsl_dot: { - if (checkArgCount(*this, TheCall, 2)) { +bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) { + unsigned NumArgs = TheCall->getNumArgs(); + + for (unsigned i = 0; i < NumArgs; ++i) { + ExprResult A = TheCall->getArg(i); + if (!A.get()->getType()->isBooleanType()) + return false; + } + // if we got here all args are bool + for (unsigned i = 0; i < NumArgs; ++i) { + ExprResult A = TheCall->getArg(i); + ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy, + Sema::AA_Converting); + if (ResA.isInvalid()) return true; - } - Expr *Arg0 = TheCall->getArg(0); - QualType ArgTy0 = Arg0->getType(); + TheCall->setArg(0, ResA.get()); + } + return false; +} - Expr *Arg1 = TheCall->getArg(1); - QualType ArgTy1 = Arg1->getType(); +int overloadOrder(Sema *S, QualType ArgTyA) { + auto kind = ArgTyA->getAs()->getKind(); + switch (kind) { + case BuiltinType::Short: + case BuiltinType::UShort: + return 1; + case BuiltinType::Int: + case BuiltinType::UInt: + return 2; + case BuiltinType::Long: + case BuiltinType::ULong: + return 3; + case BuiltinType::LongLong: + case BuiltinType::ULongLong: + return 4; + case BuiltinType::Float16: + case BuiltinType::Half: + return 5; + case BuiltinType::Float: + return 6; + default: + break; + } + return 0; +} - auto *VecTy0 = ArgTy0->getAs(); - auto *VecTy1 = ArgTy1->getAs(); - SourceLocation BuiltinLoc = TheCall->getBeginLoc(); +QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) { + auto *VecTyA = ArgTyA->getAs(); + auto *VecTyB = ArgTyB->getAs(); + QualType VecTyAElem = VecTyA->getElementType(); + QualType VecTyBElem = VecTyB->getElementType(); + int vecAElemWidth = overloadOrder(S, VecTyAElem); + int vecBElemWidth = overloadOrder(S, VecTyBElem); + return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB; +} - // if arg0 is bool then call Diag with err_builtin_invalid_arg_type - if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) { - return true; - } +void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() > 1); + ExprResult A = TheCall->getArg(0); + ExprResult B = TheCall->getArg(1); + QualType ArgTyA = A.get()->getType(); + QualType ArgTyB = B.get()->getType(); - // if arg1 is bool then call Diag with err_builtin_invalid_arg_type - if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) { - return true; + auto *VecTyA = ArgTyA->getAs(); + auto *VecTyB = ArgTyB->getAs(); + if (VecTyA == nullptr && VecTyB == nullptr) + return; + if (VecTyA == nullptr || VecTyB == nullptr) + return; + if (VecTyA->getNumElements() == VecTyB->getNumElements()) + return; + + Expr *LargerArg = B.get(); + Expr *SmallerArg = A.get(); + int largerIndex = 1; + if (VecTyA->getNumElements() > VecTyB->getNumElements()) { + LargerArg = A.get(); + SmallerArg = B.get(); + largerIndex = 0; + } + S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation) + << LargerArg->getType() << SmallerArg->getType() + << LargerArg->getSourceRange() << SmallerArg->getSourceRange(); + ExprResult ResLargerArg = S->ImpCastExprToType( + LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation); + TheCall->setArg(largerIndex, ResLargerArg.get()); + return; +} + +bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() > 1); + ExprResult A = TheCall->getArg(0); + ExprResult B = TheCall->getArg(1); + QualType ArgTyA = A.get()->getType(); + QualType ArgTyB = B.get()->getType(); + + auto *VecTyA = ArgTyA->getAs(); + auto *VecTyB = ArgTyB->getAs(); + if (VecTyA == nullptr && VecTyB == nullptr) + return false; + if (VecTyA && VecTyB) { + if (VecTyA->getElementType() == VecTyB->getElementType()) { + TheCall->setType(VecTyA->getElementType()); + return false; + } + SourceLocation BuiltinLoc = TheCall->getBeginLoc(); + QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB); + if (CastType == ArgTyA) { + ExprResult ResB = S->SemaConvertVectorExpr( + B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc, + B.get()->getBeginLoc()); + TheCall->setArg(1, ResB.get()); + TheCall->setType(VecTyA->getElementType()); + return false; } - if (VecTy0 == nullptr && VecTy1 == nullptr) { - if (ArgTy0 != ArgTy1) { - return true; - } else { - return false; - } + if (CastType == ArgTyB) { + ExprResult ResA = S->SemaConvertVectorExpr( + A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc, + A.get()->getBeginLoc()); + TheCall->setArg(0, ResA.get()); + TheCall->setType(VecTyB->getElementType()); + return false; } + return false; + } - if ((VecTy0 == nullptr && VecTy1 != nullptr) || - (VecTy0 != nullptr && VecTy1 == nullptr)) { + if (VecTyB) { + // Convert to the vector result type + ExprResult ResA = A; + if (VecTyB->getElementType() != ArgTyA) + ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(), + CK_FloatingCast); + ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat); + TheCall->setArg(0, ResA.get()); + } + if (VecTyA) { + ExprResult ResB = B; + if (VecTyA->getElementType() != ArgTyB) + ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(), + CK_FloatingCast); + ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat); + TheCall->setArg(1, ResB.get()); + } + return false; +} - Diag(BuiltinLoc, diag::err_vec_builtin_non_vector) - << TheCall->getDirectCallee() - << SourceRange(TheCall->getArg(0)->getBeginLoc(), - TheCall->getArg(1)->getEndLoc()); +// Note: returning true in this case results in CheckBuiltinFunctionCall +// returning an ExprError +bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { + switch (BuiltinID) { + case Builtin::BI__builtin_hlsl_dot: { + if (checkArgCount(*this, TheCall, 2)) return true; - } - - if (VecTy0->getElementType() != VecTy1->getElementType()) { - // Note: This case should never happen. If type promotion occurs - // then element types won't be different. This diag error is here - // b\c EmitHLSLBuiltinExpr asserts on this case. - Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector) - << TheCall->getDirectCallee() - << SourceRange(TheCall->getArg(0)->getBeginLoc(), - TheCall->getArg(1)->getEndLoc()); + if (PromoteBoolsToInt(this, TheCall)) return true; - } - if (VecTy0->getNumElements() != VecTy1->getNumElements()) { - Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size) - << TheCall->getDirectCallee() - << SourceRange(TheCall->getArg(0)->getBeginLoc(), - TheCall->getArg(1)->getEndLoc()); + if (PromoteVectorElementCallArgs(this, TheCall)) + return true; + PromoteVectorArgTruncation(this, TheCall); + if (SemaBuiltinVectorToScalarMath(TheCall)) return true; - } break; } } @@ -19669,6 +19767,29 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) { } bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { + QualType Res; + bool result = SemaBuiltinVectorMath(TheCall, Res); + if (result) + return true; + TheCall->setType(Res); + return false; +} + +bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) { + QualType Res; + bool result = SemaBuiltinVectorMath(TheCall, Res); + if (result) + return true; + + if (auto *VecTy0 = Res->getAs()) { + TheCall->setType(VecTy0->getElementType()); + } else { + TheCall->setType(Res); + } + return false; +} + +bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) { if (checkArgCount(*this, TheCall, 2)) return true; @@ -19676,8 +19797,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { ExprResult B = TheCall->getArg(1); // Do standard promotions between the two arguments, returning their common // type. - QualType Res = - UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison); + Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison); if (A.isInvalid() || B.isInvalid()) return true; @@ -19694,7 +19814,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { TheCall->setArg(0, A.get()); TheCall->setArg(1, B.get()); - TheCall->setType(Res); return false; } diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl index 9a895cd190ba9f..b2cd3b6302af6a 100644 --- a/clang/test/CodeGenHLSL/builtins/dot.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -55,6 +55,34 @@ uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) { uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) { return dot ( p0, p1 ); } + +// NATIVE_HALF: %conv = sitofp <2 x i16> %1 to <2 x float> +// NATIVE_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %conv) +// NATIVE_HALF: ret float %dx.dot +float test_builtin_dot_vec_int16_to_float_promotion( float2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} + +// NATIVE_HALF: %conv = sitofp <2 x i16> %1 to <2 x half> +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %conv) +// NATIVE_HALF: ret half %dx.dot +half test_builtin_dot_vec_int16_to_half_promotion( half2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} + +// NATIVE_HALF: %conv = sext <2 x i16> %1 to <2 x i32> +// NATIVE_HALF: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %conv) +// NATIVE_HALF: ret i32 %dx.dot +int test_builtin_dot_vec_int16_to_int_promotion( int2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} + +// NATIVE_HALF: %conv = sext <2 x i16> %1 to <2 x i64> +// NATIVE_HALF: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %conv) +// NATIVE_HALF: ret i64 %dx.dot +int64_t test_builtin_dot_vec_int16_to_int64_promotion( int64_t2 p0, int16_t2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} #endif // CHECK: %dx.dot = mul i32 %0, %1 @@ -184,6 +212,13 @@ half test_dot_half3 ( half3 p0, half3 p1 ) { half test_dot_half4 ( half4 p0, half4 p1 ) { return dot ( p0, p1 ); } +// NATIVE_HALF: %conv = fpext <2 x half> %1 to <2 x float> +// NATIVE_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %conv) +// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_builtin_dot_vec_half_to_float_promotion( float2 p0, half2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} // CHECK: %dx.dot = fmul float %0, %1 // CHECK: ret float %dx.dot @@ -209,8 +244,143 @@ float test_dot_float4 ( float4 p0, float4 p1) { return dot ( p0, p1 ); } +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float2_splat ( float p0, float2 p1 ) { + return dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float3_splat ( float p0, float3 p1 ) { + return dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float4_splat ( float p0, float4 p1 ) { + return dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_builtin_dot_float2_splat ( float p0, float2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1) +// CHECK: ret float %dx.dot +float test_builtin_dot_float3_splat ( float p0, float3 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1) +// CHECK: ret float %dx.dot +float test_builtin_dot_float4_splat ( float p0, float4 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} + +// CHECK: %conv = sitofp i32 %1 to float +// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0 +// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat) +// CHECK: ret float %dx.dot +float test_dot_float2_int_splat ( float2 p0, int p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK: %conv = sitofp i32 %1 to float +// CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0 +// CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat) +// CHECK: ret float %dx.dot +float test_builtin_dot_float2_int_splat ( float2 p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %conv = sitofp i32 %1 to float +// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0 +// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat) +// CHECK: ret float %dx.dot +float test_dot_float3_int_splat ( float3 p0, int p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK: %conv = sitofp i32 %1 to float +// CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0 +// CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat) +// CHECK: ret float %dx.dot +float test_builtin_dot_float3_int_splat ( float3 p0, int p1 ) { + return dot ( p0, p1 ); +} + // CHECK: %dx.dot = fmul double %0, %1 // CHECK: ret double %dx.dot double test_dot_double ( double p0, double p1 ) { return dot ( p0, p1 ); } + +// CHECK: %conv = zext i1 %tobool to i32 +// CHECK: %dx.dot = mul i32 %conv, %1 +// CHECK: ret i32 %dx.dot +int test_dot_bool_scalar_arg0_type_promotion ( bool p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %conv = zext i1 %tobool to i32 +// CHECK: %dx.dot = mul i32 %0, %conv +// CHECK: ret i32 %dx.dot +int test_dot_bool_scalar_arg1_type_promotion ( int p0, bool p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %conv1 = uitofp i1 %tobool to double +// CHECK: %dx.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: ret float %conv2 +float builtin_bool_to_float_type_promotion ( float p0, bool p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK: %conv = uitofp i1 %tobool to double +// CHECK: %conv1 = fpext float %1 to double +// CHECK: %dx.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: ret float %conv2 +float builtin_bool_to_float_arg1_type_promotion ( bool p0, float p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK: %conv = zext i1 %tobool to i32 +// CHECK: %conv3 = zext i1 %tobool2 to i32 +// CHECK: %dx.dot = mul i32 %conv, %conv3 +// CHECK: ret i32 %dx.dot +int test_builtin_dot_bool_type_promotion ( bool p0, bool p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK: %conv = fpext float %0 to double +// CHECK: %conv1 = sitofp i32 %1 to double +// CHECK: dx.dot = fmul double %conv, %conv1 +// CHECK: %conv2 = fptrunc double %dx.dot to float +// CHECK: ret float %conv2 +float test_builtin_dot_int_to_float_promotion ( float p0, int p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-error@-1 {{call to 'dot' is ambiguous}} +} + +// CHECK: %conv = sitofp <2 x i32> %0 to <2 x float> +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %conv, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_builtin_dot_vec_int_to_float_promotion ( int2 p0, float2 p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); +} + +// CHECK: %conv = sext <2 x i32> %1 to <2 x i64> +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %conv) +// CHECK: ret i64 %dx.dot +int64_t test_builtin_dot_vec_int_to_int64_promotion( int64_t2 p0, int2 p1 ) { + return __builtin_hlsl_dot( p0, p1 ); +} \ No newline at end of file diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl index a5acb400ab9c7b..2f1a833f5ca364 100644 --- a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl @@ -1,46 +1,46 @@ // RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ // RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ // RUN: -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected -// NOTE: This test is marked XFAIL because when overload resolution merges -// NOTE: test_dot_element_type_mismatch & test_dot_scalar_mismatch will have different behavior -// XFAIL: * -float test_first_arg_is_not_vector ( float p0, float2 p1 ) { - return dot ( p0, p1 ); - // expected-error@-1 {{first two arguments to 'dot' must be vectors}} +float test_no_second_arg ( float2 p0) { + return __builtin_hlsl_dot ( p0 ); + // expected-error@-1 {{too few arguments to function call, expected 2, have 1}} } -float test_second_arg_is_not_vector ( float2 p0, float p1 ) { - return dot ( p0, p1 ); - // expected-error@-1 {{first two arguments to 'dot' must be vectors}} +float test_too_many_arg ( float2 p0) { + return __builtin_hlsl_dot ( p0, p0, p0 ); + // expected-error@-1 {{too many arguments to function call, expected 2, have 3}} } -int test_dot_unsupported_scalar_arg0 ( bool p0, int p1 ) { - return dot ( p0, p1 ); - // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}} +//NOTE: eventually behavior should match builtin +float test_dot_no_second_arg ( float2 p0) { + return dot ( p0 ); + // expected-error@-1 {{no matching function for call to 'dot'}} } -int test_dot_unsupported_scalar_arg1 ( int p0, bool p1 ) { +float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) { return dot ( p0, p1 ); - // expected-error@-1 {{2nd argument must be a vector, integer or floating point type (was 'bool')}} + // expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}} } -float test_dot_scalar_mismatch ( float p0, int p1 ) { - return dot ( p0, p1 ); - // expected-error@-1 {{call to 'dot' is ambiguous}} +float test_dot_builtin_vector_size_mismatch ( float3 p0, float2 p1 ) { + return __builtin_hlsl_dot ( p0, p1 ); + // expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector') to 'float2' (aka 'vector')}} } -float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) { - return dot ( p0, p1 ); - // expected-error@-1 {{first two arguments to 'dot' must have the same size}} -} -float test__no_second_arg ( float2 p0) { - return dot ( p0 ); - // expected-error@-1 {{no matching function for call to 'dot'}} +//NOTE: this case runs into the same problem as the below example +//int Fn1(int p0, int p1); +//int Fn1(float p0, float p1); +//int test_dot_scalar_mismatch ( float p0, int p1 ) { +// return Fn1( p0, p1 ); +//} +float test_dot_scalar_mismatch ( float p0, int p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{call to 'dot' is ambiguous}} } float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) { return dot ( p0, p1 ); // expected-error@-1 {{call to 'dot' is ambiguous}} -} +} \ No newline at end of file