diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index df74026c5d2d50..771c4f5d4121f4 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> { let Prototype = "void*(unsigned char)"; } +def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_dot"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "void(...)"; +} + // Builtins for XRay. def XRayCustomEvent : Builtin { let Spellings = ["__xray_customevent"]; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 6e3cebc311eeb9..46f5424038d8b0 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10267,6 +10267,8 @@ def err_vec_builtin_non_vector : Error< "first two arguments to %0 must be vectors">; def err_vec_builtin_incompatible_vector : Error< "first two arguments to %0 must have the same type">; +def err_vec_builtin_incompatible_size : Error< + "first two arguments to %0 must have the same size">; def err_vsx_builtin_nonconstant_argument : Error< "argument %0 to %1 must be a 2-bit unsigned literal (i.e. 
0, 1, 2 or 3)">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index e9cd42ae777df5..3557db56905ff8 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -14055,6 +14055,7 @@ class Sema final { bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, CallExpr *TheCall); bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall); + bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall); bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum); bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, CallExpr *TheCall); diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index d454ccc1dd8613..4bad41a9be214e 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -44,6 +44,7 @@ #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IntrinsicsBPF.h" +#include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/IntrinsicsHexagon.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsPowerPC.h" @@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr"); } + // EmitHLSLBuiltinExpr will check getLangOpts().HLSL + if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E)) + return RValue::get(V); + if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice) return EmitHipStdParUnsupportedBuiltin(this, FD); @@ -17896,6 +17901,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments, return Arg; } +Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, + const CallExpr *E) { + if (!getLangOpts().HLSL) + return nullptr; + + switch (BuiltinID) { + case Builtin::BI__builtin_hlsl_dot: { + Value *Op0 = EmitScalarExpr(E->getArg(0)); + Value *Op1 = EmitScalarExpr(E->getArg(1)); + llvm::Type *T0 = Op0->getType(); + 
llvm::Type *T1 = Op1->getType(); + if (!T0->isVectorTy() && !T1->isVectorTy()) { + if (T0->isFloatingPointTy()) { + return Builder.CreateFMul(Op0, Op1, "dx.dot"); + } + + if (T0->isIntegerTy()) { + return Builder.CreateMul(Op0, Op1, "dx.dot"); + } + assert( + false && + "Dot product on a scalar is only supported on integers and floats."); + } + assert(T0->isVectorTy() && T1->isVectorTy() && + "Dot product of vector and scalar is not supported."); + + // NOTE: this assert will need to be revisited after overload resolution + // PR merges. + assert(T0->getScalarType() == T1->getScalarType() && + "Dot product of vectors need the same element types."); + + auto *VecTy0 = E->getArg(0)->getType()->getAs(); + auto *VecTy1 = E->getArg(1)->getType()->getAs(); + assert(VecTy0->getNumElements() == VecTy1->getNumElements() && + "Dot product requires vectors to be of the same size."); + + return Builder.CreateIntrinsic( + /*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot, + ArrayRef{Op0, Op1}, nullptr, "dx.dot"); + } break; + } + return nullptr; +} + Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E) { llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent; diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index caa6a327550baa..3169802ffc2c7d 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache { llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E); + llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E); llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx, const CallExpr *E); llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E); diff --git 
a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index f87ac977997962..a92f0d0849ba77 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -179,6 +179,98 @@ double3 cos(double3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos) double4 cos(double4); +//===----------------------------------------------------------------------===// +// dot product builtins +//===----------------------------------------------------------------------===// +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half, half); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half2, half2); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half3, half3); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +half dot(half4, half4); + +#ifdef __HLSL_ENABLE_16_BIT +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t, int16_t); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t2, int16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t3, int16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int16_t dot(int16_t4, int16_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t, uint16_t); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t2, uint16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t3, uint16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint16_t dot(uint16_t4, uint16_t4); +#endif + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float, float); 
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float2, float2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float3, float3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +float dot(float4, float4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +double dot(double, double); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int, int); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int2, int2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int3, int3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int dot(int4, int4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint, uint); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint2, uint2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint3, uint3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint dot(uint4, uint4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t, int64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t2, int64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t3, int64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +int64_t dot(int64_t4, int64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t, uint64_t); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t2, uint64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t3, uint64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot) +uint64_t dot(uint64_t4, uint64_t4); + //===----------------------------------------------------------------------===// // floor builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 8e763384774444..b5f3f8b2e96db7 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID, // not a valid type, emit an error message and return true. Otherwise return // false. 
static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc, - QualType Ty) { - if (!Ty->getAs() && !ConstantMatrixType::isValidElementType(Ty)) { + QualType ArgTy, int ArgIndex) { + if (!ArgTy->getAs() && + !ConstantMatrixType::isValidElementType(ArgTy)) { return S.Diag(Loc, diag::err_builtin_invalid_arg_type) - << 1 << /* vector, integer or float ty*/ 0 << Ty; + << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy; } return false; @@ -2958,6 +2959,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, } } + if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) { + return ExprError(); + } + // Since the target specific builtins for each arch overlap, only check those // of the arch we are compiling for. if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) { @@ -5158,6 +5163,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) { return false; } +// Note: returning true in this case results in CheckBuiltinFunctionCall +// returning an ExprError +bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { + switch (BuiltinID) { + case Builtin::BI__builtin_hlsl_dot: { + if (checkArgCount(*this, TheCall, 2)) { + return true; + } + Expr *Arg0 = TheCall->getArg(0); + QualType ArgTy0 = Arg0->getType(); + + Expr *Arg1 = TheCall->getArg(1); + QualType ArgTy1 = Arg1->getType(); + + auto *VecTy0 = ArgTy0->getAs(); + auto *VecTy1 = ArgTy1->getAs(); + SourceLocation BuiltinLoc = TheCall->getBeginLoc(); + + // if arg0 is bool then call Diag with err_builtin_invalid_arg_type + if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) { + return true; + } + + // if arg1 is bool then call Diag with err_builtin_invalid_arg_type + if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) { + return true; + } + + if (VecTy0 == nullptr && VecTy1 == nullptr) { + if (ArgTy0 != ArgTy1) { + return true; + } else { + return false; + } + } + + if ((VecTy0 == nullptr && 
VecTy1 != nullptr) || + (VecTy0 != nullptr && VecTy1 == nullptr)) { + + Diag(BuiltinLoc, diag::err_vec_builtin_non_vector) + << TheCall->getDirectCallee() + << SourceRange(TheCall->getArg(0)->getBeginLoc(), + TheCall->getArg(1)->getEndLoc()); + return true; + } + + if (VecTy0->getElementType() != VecTy1->getElementType()) { + // Note: This case should never happen. If type promotion occurs + // then element types won't be different. This diag error is here + // because EmitHLSLBuiltinExpr asserts on this case. + Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector) + << TheCall->getDirectCallee() + << SourceRange(TheCall->getArg(0)->getBeginLoc(), + TheCall->getArg(1)->getEndLoc()); + return true; + } + if (VecTy0->getNumElements() != VecTy1->getNumElements()) { + Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size) + << TheCall->getDirectCallee() + << SourceRange(TheCall->getArg(0)->getBeginLoc(), + TheCall->getArg(1)->getEndLoc()); + return true; + } + break; + } + } + return false; +} + bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { // position of memory order and scope arguments in the builtin @@ -19583,7 +19657,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) { TheCall->setArg(0, A.get()); QualType TyA = A.get()->getType(); - if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA)) + if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1)) return true; TheCall->setType(TyA); @@ -19611,7 +19685,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) { diag::err_typecheck_call_different_arg_types) << TyA << TyB; - if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA)) + if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1)) return true; TheCall->setArg(0, A.get()); diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl new file mode 100644 index 00000000000000..9a895cd190ba9f --- /dev/null 
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -0,0 +1,216 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ +// RUN: -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \ +// RUN: --check-prefixes=CHECK,NATIVE_HALF +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF + +// -fnative-half-type sets __HLSL_ENABLE_16_BIT +#ifdef __HLSL_ENABLE_16_BIT +// NATIVE_HALF: %dx.dot = mul i16 %0, %1 +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short ( int16_t p0, int16_t p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +int16_t test_dot_short4 ( int16_t4 p0, int16_t4 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = mul i16 %0, %1 +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call i16 
@llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1) +// NATIVE_HALF: ret i16 %dx.dot +uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) { + return dot ( p0, p1 ); +} +#endif + +// CHECK: %dx.dot = mul i32 %0, %1 +// CHECK: ret i32 %dx.dot +int test_dot_int ( int p0, int p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: ret i32 %dx.dot +int test_dot_int2 ( int2 p0, int2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: ret i32 %dx.dot +int test_dot_int3 ( int3 p0, int3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: ret i32 %dx.dot +int test_dot_int4 ( int4 p0, int4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = mul i32 %0, %1 +// CHECK: ret i32 %dx.dot +uint test_dot_uint ( uint p0, uint p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1) +// CHECK: ret i32 %dx.dot +uint test_dot_uint2 ( uint2 p0, uint2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1) +// CHECK: ret i32 %dx.dot +uint test_dot_uint3 ( uint3 p0, uint3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1) +// CHECK: ret i32 %dx.dot +uint test_dot_uint4 ( uint4 p0, uint4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = mul i64 %0, %1 +// CHECK: ret i64 %dx.dot +int64_t test_dot_long ( int64_t p0, int64_t p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: ret i64 %dx.dot +int64_t test_dot_uint2 ( int64_t2 p0, int64_t2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: ret i64 %dx.dot +int64_t test_dot_uint3 ( int64_t3 p0, int64_t3 
p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: ret i64 %dx.dot +int64_t test_dot_uint4 ( int64_t4 p0, int64_t4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = mul i64 %0, %1 +// CHECK: ret i64 %dx.dot +uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1) +// CHECK: ret i64 %dx.dot +uint64_t test_dot_uint2 ( uint64_t2 p0, uint64_t2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1) +// CHECK: ret i64 %dx.dot +uint64_t test_dot_uint3 ( uint64_t3 p0, uint64_t3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1) +// CHECK: ret i64 %dx.dot +uint64_t test_dot_uint4 ( uint64_t4 p0, uint64_t4 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = fmul half %0, %1 +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = fmul float %0, %1 +// NO_HALF: ret float %dx.dot +half test_dot_half ( half p0, half p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1) +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// NO_HALF: ret float %dx.dot +half test_dot_half2 ( half2 p0, half2 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1) +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// NO_HALF: ret float %dx.dot +half test_dot_half3 ( half3 p0, half3 p1 ) { + return dot ( p0, p1 ); +} + +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1) +// NATIVE_HALF: ret half %dx.dot +// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> 
%1) +// NO_HALF: ret float %dx.dot +half test_dot_half4 ( half4 p0, half4 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = fmul float %0, %1 +// CHECK: ret float %dx.dot +float test_dot_float ( float p0, float p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float2 ( float2 p0, float2 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float3 ( float3 p0, float3 p1 ) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// CHECK: ret float %dx.dot +float test_dot_float4 ( float4 p0, float4 p1) { + return dot ( p0, p1 ); +} + +// CHECK: %dx.dot = fmul double %0, %1 +// CHECK: ret double %dx.dot +double test_dot_double ( double p0, double p1 ) { + return dot ( p0, p1 ); +} diff --git a/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl new file mode 100644 index 00000000000000..a5acb400ab9c7b --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl @@ -0,0 +1,46 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \ +// RUN: -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected +// NOTE: This test is marked XFAIL because when overload resolution merges +// NOTE: test_dot_element_type_mismatch & test_dot_scalar_mismatch will have different behavior +// XFAIL: * + +float test_first_arg_is_not_vector ( float p0, float2 p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to 'dot' must be vectors}} +} + +float test_second_arg_is_not_vector ( float2 p0, float p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to 'dot' must be vectors}} +} + +int test_dot_unsupported_scalar_arg0 ( 
bool p0, int p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}} +} + +int test_dot_unsupported_scalar_arg1 ( int p0, bool p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{2nd argument must be a vector, integer or floating point type (was 'bool')}} +} + +float test_dot_scalar_mismatch ( float p0, int p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{call to 'dot' is ambiguous}} +} + +float test_dot_vector_size_mismatch ( float3 p0, float2 p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{first two arguments to 'dot' must have the same size}} +} + +float test__no_second_arg ( float2 p0) { + return dot ( p0 ); + // expected-error@-1 {{no matching function for call to 'dot'}} +} + +float test_dot_element_type_mismatch ( int2 p0, float2 p1 ) { + return dot ( p0, p1 ); + // expected-error@-1 {{call to 'dot' is ambiguous}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 2fe4fdfd5953be..c192d4b84417c9 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -19,4 +19,9 @@ def int_dx_flattened_thread_id_in_group : Intrinsic<[llvm_i32_ty], [], [IntrNoMe def int_dx_create_handle : ClangBuiltin<"__builtin_hlsl_create_handle">, Intrinsic<[ llvm_ptr_ty ], [llvm_i8_ty], [IntrWillReturn]>; + +def int_dx_dot : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; }