Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HLSL] Implementation of dot intrinsic #81190

Merged
merged 6 commits into from
Feb 26, 2024

Conversation

farzonl
Copy link
Member

@farzonl farzonl commented Feb 8, 2024

This change implements #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product typdef defined hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.

In IntrinsicsDirectX.td we define the llvmIR for the dot product.
A few goals were in mind for this IR. First it should operate on only
vectors. Second the return type should be the vector element type. Third
the second parameter vector should be of the same size as the first
parameter. Finally a dot b should be the same as b dot a.

In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot
product though is language specific intrinsic and so is guarded behind getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp

EmitHLSLBuiltinExp dot product intrinsics makes a destinction
between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply.

Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language specific semantic validation that can be expanded for other hlsl specific intrinsics.

clang/lib/Headers/hlsl/hlsl_intrinsics.h Show resolved Hide resolved
clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super excited to see this coming along.

clang/include/clang/Basic/Builtins.td Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
clang/lib/Headers/hlsl/hlsl_intrinsics.h Show resolved Hide resolved
clang/lib/Headers/hlsl/hlsl_intrinsics.h Show resolved Hide resolved
llvm/include/llvm/IR/IntrinsicsDirectX.td Show resolved Hide resolved
llvm/include/llvm/IR/IntrinsicsDirectX.td Outdated Show resolved Hide resolved
@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch 3 times, most recently from 4103940 to 73cc9fd Compare February 13, 2024 02:40
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl Outdated Show resolved Hide resolved
clang/test/CodeGenHLSL/builtins/dot.hlsl Outdated Show resolved Hide resolved
@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch from 73cc9fd to 23c277d Compare February 16, 2024 00:09
Copy link

github-actions bot commented Feb 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch 3 times, most recently from f6dbbfb to bebb3ad Compare February 16, 2024 16:56
Copy link
Contributor

@dmpots dmpots left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@farzonl farzonl marked this pull request as ready for review February 16, 2024 17:12
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen backend:DirectX HLSL HLSL Language Support llvm:ir labels Feb 16, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-backend-directx
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-hlsl

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

Changes

This change implements #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product typdef defined hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.

In IntrinsicsDirectX.td we define the llvmIR for the dot product.
A few goals were in mind for this IR. First it should operate on only
vectors. Second the return type should be the vector element type. Third
the second parameter vector should be of the same size as the first
parameter. Finally a dot b should be the same as b dot a.

In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot
product though is language specific intrinsic and so is guarded behind getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp

EmitHLSLBuiltinExp dot product intrinsics makes a destinction
between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply.

Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language specific semantic validation that can be expanded for other hlsl specific intrinsics.


Patch is 23.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81190.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2)
  • (modified) clang/include/clang/Sema/Sema.h (+1)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+49)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+92)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+79-5)
  • (added) clang/test/CodeGenHLSL/builtins/dot.hlsl (+216)
  • (added) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+46)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+5)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index df74026c5d2d50..771c4f5d4121f4 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void*(unsigned char)";
 }
 
+def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_dot"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fb44495199ee56..9957f4e743e9cb 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10267,6 +10267,8 @@ def err_vec_builtin_non_vector : Error<
  "first two arguments to %0 must be vectors">;
 def err_vec_builtin_incompatible_vector : Error<
   "first two arguments to %0 must have the same type">;
+def err_vec_builtin_incompatible_size : Error<
+  "first two arguments to %0 must have the same size">;
 def err_vsx_builtin_nonconstant_argument : Error<
   "argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;
 
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e9cd42ae777df5..3557db56905ff8 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14055,6 +14055,7 @@ class Sema final {
   bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                    CallExpr *TheCall);
   bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
   bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
   bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                      CallExpr *TheCall);
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d454ccc1dd8613..4bad41a9be214e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -44,6 +44,7 @@
 #include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/IntrinsicsARM.h"
 #include "llvm/IR/IntrinsicsBPF.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/IntrinsicsHexagon.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
@@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
   }
 
+  // EmitHLSLBuiltinExpr will check getLangOpts().HLSL
+  if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
+    return RValue::get(V);
+
   if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
     return EmitHipStdParUnsupportedBuiltin(this, FD);
 
@@ -17896,6 +17901,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
+Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
+                                            const CallExpr *E) {
+  if (!getLangOpts().HLSL)
+    return nullptr;
+
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    llvm::Type *T0 = Op0->getType();
+    llvm::Type *T1 = Op1->getType();
+    if (!T0->isVectorTy() && !T1->isVectorTy()) {
+      if (T0->isFloatingPointTy()) {
+        return Builder.CreateFMul(Op0, Op1, "dx.dot");
+      }
+
+      if (T0->isIntegerTy()) {
+        return Builder.CreateMul(Op0, Op1, "dx.dot");
+      }
+      assert(
+          false &&
+          "Dot product on a scalar is only supported on integers and floats.");
+    }
+    assert(T0->isVectorTy() && T1->isVectorTy() &&
+           "Dot product of vector and scalar is not supported.");
+
+    // NOTE: this assert will need to be revisited after overload resoltion
+    //  PR merges.
+    assert(T0->getScalarType() == T1->getScalarType() &&
+           "Dot product of vectors need the same element types.");
+
+    auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
+    auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
+    assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
+           "Dot product requires vectors to be of the same size.");
+
+    return Builder.CreateIntrinsic(
+        /*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+  } break;
+  }
+  return nullptr;
+}
+
 Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
                                               const CallExpr *E) {
   llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index caa6a327550baa..3169802ffc2c7d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+  llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index f87ac977997962..a92f0d0849ba77 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -179,6 +179,98 @@ double3 cos(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
 double4 cos(double4);
 
+//===----------------------------------------------------------------------===//
+// dot product builtins
+//===----------------------------------------------------------------------===//
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half4, half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t, int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t2, int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t3, int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t4, int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t, uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t2, uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t3, uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double, double);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t4, uint64_t4);
+
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 8e763384774444..b5f3f8b2e96db7 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
 // not a valid type, emit an error message and return true. Otherwise return
 // false.
 static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                        QualType Ty) {
-  if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
+                                        QualType ArgTy, int ArgIndex) {
+  if (!ArgTy->getAs<VectorType>() &&
+      !ConstantMatrixType::isValidElementType(ArgTy)) {
     return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << 1 << /* vector, integer or float ty*/ 0 << Ty;
+           << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
   }
 
   return false;
@@ -2958,6 +2959,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
   }
   }
 
+  if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
+    return ExprError();
+  }
+
   // Since the target specific builtins for each arch overlap, only check those
   // of the arch we are compiling for.
   if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
@@ -5158,6 +5163,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
   return false;
 }
 
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2)) {
+      return true;
+    }
+    Expr *Arg0 = TheCall->getArg(0);
+    QualType ArgTy0 = Arg0->getType();
+
+    Expr *Arg1 = TheCall->getArg(1);
+    QualType ArgTy1 = Arg1->getType();
+
+    auto *VecTy0 = ArgTy0->getAs<VectorType>();
+    auto *VecTy1 = ArgTy1->getAs<VectorType>();
+    SourceLocation BuiltinLoc = TheCall->getBeginLoc();
+
+    // if arg0 is bool then call Diag with err_builtin_invalid_arg_type
+    if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
+      return true;
+    }
+
+    // if arg1 is bool then call Diag with err_builtin_invalid_arg_type
+    if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
+      return true;
+    }
+
+    if (VecTy0 == nullptr && VecTy1 == nullptr) {
+      if (ArgTy0 != ArgTy1) {
+        return true;
+      } else {
+        return false;
+      }
+    }
+
+    if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
+        (VecTy0 != nullptr && VecTy1 == nullptr)) {
+
+      Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                         TheCall->getArg(1)->getEndLoc());
+      return true;
+    }
+
+    if (VecTy0->getElementType() != VecTy1->getElementType()) {
+      // Note: This case should never happen. If type promotion occurs
+      // then element types won't be different. This diag error is here
+      // b\c EmitHLSLBuiltinExpr asserts on this case.
+      Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                         TheCall->getArg(1)->getEndLoc());
+      return true;
+    }
+    if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
+      Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
+          << TheCall->getDirectCallee()
+          << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                         TheCall->getArg(1)->getEndLoc());
+      return true;
+    }
+    break;
+  }
+  }
+  return false;
+}
+
 bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
                                           CallExpr *TheCall) {
   // position of memory order and scope arguments in the builtin
@@ -19583,7 +19657,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
   TheCall->setArg(0, A.get());
   QualType TyA = A.get()->getType();
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
     return true;
 
   TheCall->setType(TyA);
@@ -19611,7 +19685,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
                 diag::err_typecheck_call_different_arg_types)
            << TyA << TyB;
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
     return true;
 
   TheCall->setArg(0, A.get());
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
new file mode 100644
index 00000000000000..9a895cd190ba9f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -0,0 +1,216 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// -fnative-half-type sets __HLSL_ENABLE_16_BIT
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %dx.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short ( int16_t p0, int16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short4 ( int16_t4 p0, int16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+#endif
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+int test_dot_int ( int p0, int p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int2 ( int2 p0, int2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int3 ( int3 p0, int3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int4 ( int4 p0, int4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint ( uint p0, uint p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint2 ( uint2 p0, uint2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint3 ( uint3 p0, uint3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint4 ( uint4 p0, uint4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long ( int64_t p0, int64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_uint2 ( int64_t2 p0, int64_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_uint3 ( int64_t3 p0, int64_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_uint4 ( int64_t4 p0, int64_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK:  %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_uint2 ( uint64_t2 p0, uint64_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_uint3 ( uint64_t3 p0, uint64_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_uint4 ( uint64_t4 p0, uint64_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = fmul float %0, %1
+// NO_HALF: ret float %dx.dot
+half test_dot_half ( half p0, half p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half2 ( half2 p0, half2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half3 ( half3 p0, half3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half4 ( half4 p0, half4 p1 ) {
+  return dot ( p0, p1 );
+}...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 16, 2024

@llvm/pr-subscribers-backend-x86

Author: Farzon Lotfi (farzonl)

Changes

This change implements #70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product typdef defined hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.

In IntrinsicsDirectX.td we define the llvmIR for the dot product.
A few goals were in mind for this IR. First it should operate on only
vectors. Second the return type should be the vector element type. Third
the second parameter vector should be of the same size as the first
parameter. Finally a dot b should be the same as b dot a.

In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot
product though is language specific intrinsic and so is guarded behind getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp

EmitHLSLBuiltinExp dot product intrinsics makes a destinction
between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply.

Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language specific semantic validation that can be expanded for other hlsl specific intrinsics.


Patch is 23.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81190.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2)
  • (modified) clang/include/clang/Sema/Sema.h (+1)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+49)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+92)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+79-5)
  • (added) clang/test/CodeGenHLSL/builtins/dot.hlsl (+216)
  • (added) clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl (+46)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+5)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index df74026c5d2d50..771c4f5d4121f4 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4524,6 +4524,12 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void*(unsigned char)";
 }
 
+def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_dot"];
+  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Prototype = "void(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fb44495199ee56..9957f4e743e9cb 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10267,6 +10267,8 @@ def err_vec_builtin_non_vector : Error<
  "first two arguments to %0 must be vectors">;
 def err_vec_builtin_incompatible_vector : Error<
   "first two arguments to %0 must have the same type">;
+def err_vec_builtin_incompatible_size : Error<
+  "first two arguments to %0 must have the same size">;
 def err_vsx_builtin_nonconstant_argument : Error<
   "argument %0 to %1 must be a 2-bit unsigned literal (i.e. 0, 1, 2 or 3)">;
 
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index e9cd42ae777df5..3557db56905ff8 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -14055,6 +14055,7 @@ class Sema final {
   bool CheckPPCBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                    CallExpr *TheCall);
   bool CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
+  bool CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
   bool CheckRISCVLMUL(CallExpr *TheCall, unsigned ArgNum);
   bool CheckRISCVBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
                                      CallExpr *TheCall);
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d454ccc1dd8613..4bad41a9be214e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -44,6 +44,7 @@
 #include "llvm/IR/IntrinsicsAMDGPU.h"
 #include "llvm/IR/IntrinsicsARM.h"
 #include "llvm/IR/IntrinsicsBPF.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
 #include "llvm/IR/IntrinsicsHexagon.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/IntrinsicsPowerPC.h"
@@ -5982,6 +5983,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     llvm_unreachable("Bad evaluation kind in EmitBuiltinExpr");
   }
 
+  // EmitHLSLBuiltinExpr will check getLangOpts().HLSL
+  if (Value *V = EmitHLSLBuiltinExpr(BuiltinID, E))
+    return RValue::get(V);
+
   if (getLangOpts().HIPStdPar && getLangOpts().CUDAIsDevice)
     return EmitHipStdParUnsupportedBuiltin(this, FD);
 
@@ -17896,6 +17901,50 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments,
   return Arg;
 }
 
+Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
+                                            const CallExpr *E) {
+  if (!getLangOpts().HLSL)
+    return nullptr;
+
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    Value *Op0 = EmitScalarExpr(E->getArg(0));
+    Value *Op1 = EmitScalarExpr(E->getArg(1));
+    llvm::Type *T0 = Op0->getType();
+    llvm::Type *T1 = Op1->getType();
+    if (!T0->isVectorTy() && !T1->isVectorTy()) {
+      if (T0->isFloatingPointTy()) {
+        return Builder.CreateFMul(Op0, Op1, "dx.dot");
+      }
+
+      if (T0->isIntegerTy()) {
+        return Builder.CreateMul(Op0, Op1, "dx.dot");
+      }
+      assert(
+          false &&
+          "Dot product on a scalar is only supported on integers and floats.");
+    }
+    assert(T0->isVectorTy() && T1->isVectorTy() &&
+           "Dot product of vector and scalar is not supported.");
+
+    // NOTE: this assert will need to be revisited after overload resoltion
+    //  PR merges.
+    assert(T0->getScalarType() == T1->getScalarType() &&
+           "Dot product of vectors need the same element types.");
+
+    auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
+    auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
+    assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
+           "Dot product requires vectors to be of the same size.");
+
+    return Builder.CreateIntrinsic(
+        /*ReturnType*/ T0->getScalarType(), Intrinsic::dx_dot,
+        ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot");
+  } break;
+  }
+  return nullptr;
+}
+
 Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
                                               const CallExpr *E) {
   llvm::AtomicOrdering AO = llvm::AtomicOrdering::SequentiallyConsistent;
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index caa6a327550baa..3169802ffc2c7d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4405,6 +4405,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitX86BuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
+  llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
                                            const CallExpr *E);
   llvm::Value *EmitSystemZBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index f87ac977997962..a92f0d0849ba77 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -179,6 +179,98 @@ double3 cos(double3);
 _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cos)
 double4 cos(double4);
 
+//===----------------------------------------------------------------------===//
+// dot product builtins
+//===----------------------------------------------------------------------===//
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half, half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half2, half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half3, half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+half dot(half4, half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t, int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t2, int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t3, int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int16_t dot(int16_t4, int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t, uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t2, uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t3, uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.2)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint16_t dot(uint16_t4, uint16_t4);
+#endif
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float, float);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float2, float2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float3, float3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+float dot(float4, float4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+double dot(double, double);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int, int);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int2, int2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int3, int3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int dot(int4, int4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint, uint);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint2, uint2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint3, uint3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint dot(uint4, uint4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t, int64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t2, int64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t3, int64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+int64_t dot(int64_t4, int64_t4);
+
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t, uint64_t);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t2, uint64_t2);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t3, uint64_t3);
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_dot)
+uint64_t dot(uint64_t4, uint64_t4);
+
 //===----------------------------------------------------------------------===//
 // floor builtins
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 8e763384774444..b5f3f8b2e96db7 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2120,10 +2120,11 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
 // not a valid type, emit an error message and return true. Otherwise return
 // false.
 static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                        QualType Ty) {
-  if (!Ty->getAs<VectorType>() && !ConstantMatrixType::isValidElementType(Ty)) {
+                                        QualType ArgTy, int ArgIndex) {
+  if (!ArgTy->getAs<VectorType>() &&
+      !ConstantMatrixType::isValidElementType(ArgTy)) {
     return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << 1 << /* vector, integer or float ty*/ 0 << Ty;
+           << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
   }
 
   return false;
@@ -2958,6 +2959,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
   }
   }
 
+  if (getLangOpts().HLSL && CheckHLSLBuiltinFunctionCall(BuiltinID, TheCall)) {
+    return ExprError();
+  }
+
   // Since the target specific builtins for each arch overlap, only check those
   // of the arch we are compiling for.
   if (Context.BuiltinInfo.isTSBuiltin(BuiltinID)) {
@@ -5158,6 +5163,75 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
   return false;
 }
 
+// Note: returning true in this case results in CheckBuiltinFunctionCall
+// returning an ExprError
+bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
+  switch (BuiltinID) {
+  case Builtin::BI__builtin_hlsl_dot: {
+    if (checkArgCount(*this, TheCall, 2)) {
+      return true;
+    }
+    Expr *Arg0 = TheCall->getArg(0);
+    QualType ArgTy0 = Arg0->getType();
+
+    Expr *Arg1 = TheCall->getArg(1);
+    QualType ArgTy1 = Arg1->getType();
+
+    auto *VecTy0 = ArgTy0->getAs<VectorType>();
+    auto *VecTy1 = ArgTy1->getAs<VectorType>();
+    SourceLocation BuiltinLoc = TheCall->getBeginLoc();
+
+    // if arg0 is bool then call Diag with err_builtin_invalid_arg_type
+    if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
+      return true;
+    }
+
+    // if arg1 is bool then call Diag with err_builtin_invalid_arg_type
+    if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
+      return true;
+    }
+
+    if (VecTy0 == nullptr && VecTy1 == nullptr) {
+      if (ArgTy0 != ArgTy1) {
+        return true;
+      } else {
+        return false;
+      }
+    }
+
+    if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
+        (VecTy0 != nullptr && VecTy1 == nullptr)) {
+
+      Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                         TheCall->getArg(1)->getEndLoc());
+      return true;
+    }
+
+    if (VecTy0->getElementType() != VecTy1->getElementType()) {
+      // Note: This case should never happen. If type promotion occurs
+      // then element types won't be different. This diag error is here
+      // b\c EmitHLSLBuiltinExpr asserts on this case.
+      Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
+          << TheCall->getDirectCallee()
+          << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                         TheCall->getArg(1)->getEndLoc());
+      return true;
+    }
+    if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
+      Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
+          << TheCall->getDirectCallee()
+          << SourceRange(TheCall->getArg(0)->getBeginLoc(),
+                         TheCall->getArg(1)->getEndLoc());
+      return true;
+    }
+    break;
+  }
+  }
+  return false;
+}
+
 bool Sema::CheckAMDGCNBuiltinFunctionCall(unsigned BuiltinID,
                                           CallExpr *TheCall) {
   // position of memory order and scope arguments in the builtin
@@ -19583,7 +19657,7 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
   TheCall->setArg(0, A.get());
   QualType TyA = A.get()->getType();
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
     return true;
 
   TheCall->setType(TyA);
@@ -19611,7 +19685,7 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
                 diag::err_typecheck_call_different_arg_types)
            << TyA << TyB;
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
     return true;
 
   TheCall->setArg(0, A.get());
diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl
new file mode 100644
index 00000000000000..9a895cd190ba9f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl
@@ -0,0 +1,216 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -fnative-half-type \
+// RUN:   -emit-llvm -disable-llvm-passes -O3 -o - | FileCheck %s \ 
+// RUN:   --check-prefixes=CHECK,NATIVE_HALF
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
+// RUN:   -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF
+
+// -fnative-half-type sets __HLSL_ENABLE_16_BIT
+#ifdef __HLSL_ENABLE_16_BIT
+// NATIVE_HALF: %dx.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short ( int16_t p0, int16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short2 ( int16_t2 p0, int16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short3 ( int16_t3 p0, int16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+int16_t test_dot_short4 ( int16_t4 p0, int16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = mul i16 %0, %1
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort ( uint16_t p0, uint16_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v2i16(<2 x i16> %0, <2 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort2 ( uint16_t2 p0, uint16_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v3i16(<3 x i16> %0, <3 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort3 ( uint16_t3 p0, uint16_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call i16 @llvm.dx.dot.v4i16(<4 x i16> %0, <4 x i16> %1)
+// NATIVE_HALF: ret i16 %dx.dot
+uint16_t test_dot_ushort4 ( uint16_t4 p0, uint16_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+#endif
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+int test_dot_int ( int p0, int p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int2 ( int2 p0, int2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int3 ( int3 p0, int3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+int test_dot_int4 ( int4 p0, int4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i32 %0, %1
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint ( uint p0, uint p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v2i32(<2 x i32> %0, <2 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint2 ( uint2 p0, uint2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v3i32(<3 x i32> %0, <3 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint3 ( uint3 p0, uint3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i32 @llvm.dx.dot.v4i32(<4 x i32> %0, <4 x i32> %1)
+// CHECK: ret i32 %dx.dot
+uint test_dot_uint4 ( uint4 p0, uint4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_long ( int64_t p0, int64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_uint2 ( int64_t2 p0, int64_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_uint3 ( int64_t3 p0, int64_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+int64_t test_dot_uint4 ( int64_t4 p0, int64_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK:  %dx.dot = mul i64 %0, %1
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_ulong ( uint64_t p0, uint64_t p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v2i64(<2 x i64> %0, <2 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_uint2 ( uint64_t2 p0, uint64_t2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v3i64(<3 x i64> %0, <3 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_uint3 ( uint64_t3 p0, uint64_t3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// CHECK: %dx.dot = call i64 @llvm.dx.dot.v4i64(<4 x i64> %0, <4 x i64> %1)
+// CHECK: ret i64 %dx.dot
+uint64_t test_dot_uint4 ( uint64_t4 p0, uint64_t4 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = fmul half %0, %1
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = fmul float %0, %1
+// NO_HALF: ret float %dx.dot
+half test_dot_half ( half p0, half p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half2 ( half2 p0, half2 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half3 ( half3 p0, half3 p1 ) {
+  return dot ( p0, p1 );
+}
+
+// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1)
+// NATIVE_HALF: ret half %dx.dot
+// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1)
+// NO_HALF: ret float %dx.dot
+half test_dot_half4 ( half4 p0, half4 p1 ) {
+  return dot ( p0, p1 );
+}...
[truncated]

@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch 2 times, most recently from 586ee65 to 2255c53 Compare February 20, 2024 15:57
@farzonl farzonl marked this pull request as draft February 21, 2024 19:53
@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch from 2255c53 to 3524651 Compare February 21, 2024 20:22
@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch 5 times, most recently from f60cfd0 to f40562c Compare February 23, 2024 03:59
@farzonl farzonl marked this pull request as ready for review February 23, 2024 05:55
clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
clang/lib/Headers/hlsl/hlsl_intrinsics.h Show resolved Hide resolved
clang/lib/Sema/SemaChecking.cpp Outdated Show resolved Hide resolved
clang/lib/Sema/SemaChecking.cpp Show resolved Hide resolved
clang/lib/Sema/SemaChecking.cpp Outdated Show resolved Hide resolved
clang/lib/Sema/SemaChecking.cpp Outdated Show resolved Hide resolved
clang/lib/Sema/SemaChecking.cpp Outdated Show resolved Hide resolved
This change implements llvm#70073

HLSL has a dot intrinsic defined here:
https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-dot

The intrinsic itself is defined as a HLSL_LANG LangBuiltin in Builtins.td.
This is used to associate all the dot product typdef defined hlsl_intrinsics.h
with a single intrinsic check in CGBuiltin.cpp & SemaChecking.cpp.

In IntrinsicsDirectX.td we define the llvmIR for the dot product.
A few goals were in mind for this IR. First it should operate on only
vectors. Second the return type should be the vector element type. Third
the second parameter vector should be of the same size as the first
parameter. Finally `a dot b` should be the same as `b dot a`.

In CGBuiltin.cpp hlsl has built on top of existing clang intrinsics via EmitBuiltinExpr. Dot
product though is language specific intrinsic and so is guarded behind getLangOpts().HLSL.
The call chain looks like this: EmitBuiltinExpr -> EmitHLSLBuiltinExp

EmitHLSLBuiltinExp dot product intrinsics makes a destinction
between vectors and scalars. This is because HLSL supports dot product on scalars which simplifies down to multiply.

Sema.h & SemaChecking.cpp saw the addition of CheckHLSLBuiltinFunctionCall, a language specific semantic validation that can be expanded for other hlsl specific intrinsics.
Add more robustness to SemaChecking
@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch from 1545cdf to 080b65f Compare February 23, 2024 23:53
Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few small notes, but otherwise looks good.

clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
clang/lib/CodeGen/CGBuiltin.cpp Outdated Show resolved Hide resolved
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl Outdated Show resolved Hide resolved
clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl Show resolved Hide resolved
@farzonl farzonl force-pushed the llvm-farzon/hlsl-dot-intrinsic branch from 3152962 to 37091d1 Compare February 24, 2024 01:19
@llvm-beanz llvm-beanz merged commit 82acec1 into llvm:main Feb 26, 2024
3 of 4 checks passed
@farzonl farzonl deleted the llvm-farzon/hlsl-dot-intrinsic branch February 29, 2024 18:45
@farzonl farzonl linked an issue Mar 12, 2024 that may be closed by this pull request
@farzonl farzonl self-assigned this Mar 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[HLSL] implement dot intrinsic
5 participants