Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DirectX] Use scalar arguments for @llvm.dx.dot intrinsics #134570

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

bogner
Copy link
Contributor

@bogner bogner commented Apr 7, 2025

The dx.dot2, dot3, and dot4 intrinsics exist purely to lower dx.fdot, and they map exactly to the DXIL ops of the same name. Using vectors for their arguments adds unnecessary complexity and causes us to have vector operations that are not trivial to lower post-scalarizer.

Similarly, the dx.dot2add intrinsic is overly generic for something that only needs to lower to a single dot2AddHalf DXIL op. Update its signature to match the operation it lowers to.

Fixes #134569.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support llvm:ir labels Apr 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 7, 2025

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-clang

Author: Justin Bogner (bogner)

Changes

The dx.dot2, dot3, and dot4 intrinsics exist purely to lower dx.fdot, and they map exactly to the DXIL ops of the same name. Using vectors for their arguments adds unnecessary complexity and causes us to have vector operations that are not trivial to lower post-scalarizer.

Similarly, the dx.dot2add intrinsic is overly generic for something that only needs to lower to a single dot2AddHalf DXIL op. Update its signature to match the operation it lowers to.

Fixes #134569.


Patch is 34.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134570.diff

12 Files Affected:

  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+8-3)
  • (modified) clang/test/CodeGenHLSL/builtins/dot2add.hlsl (+50-10)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+24-15)
  • (modified) llvm/lib/Target/DirectX/DXIL.td (+1-2)
  • (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+10-3)
  • (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (-58)
  • (modified) llvm/test/CodeGen/DirectX/dot2_error.ll (+3-2)
  • (modified) llvm/test/CodeGen/DirectX/dot2add.ll (+8-3)
  • (modified) llvm/test/CodeGen/DirectX/dot3_error.ll (+4-2)
  • (modified) llvm/test/CodeGen/DirectX/dot4_error.ll (+5-2)
  • (modified) llvm/test/CodeGen/DirectX/fdot.ll (+43-43)
  • (modified) llvm/test/CodeGen/DirectX/normalize.ll (+6-6)
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 27d1c69439944..4e92be0664f71 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -385,12 +385,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
            "Intrinsic dot2add is only allowed for dxil architecture");
     Value *A = EmitScalarExpr(E->getArg(0));
     Value *B = EmitScalarExpr(E->getArg(1));
-    Value *C = EmitScalarExpr(E->getArg(2));
+    Value *Acc = EmitScalarExpr(E->getArg(2));
+
+    Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
+    Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1));
+    Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0));
+    Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1));
 
     Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
-        "dx.dot2add");
+        /*ReturnType=*/Acc->getType(), ID,
+        ArrayRef<Value *>{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add");
   }
   case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
     Value *A = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
index 2464607dd636c..c345e17476e08 100644
--- a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
@@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 775d325feeb14..b1a27311e2a9c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -76,18 +76,27 @@ def int_dx_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>,
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 
-def int_dx_dot2 :
-    DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
-    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, Commutative] >;
-def int_dx_dot3 :
-    DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
-    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, Commutative] >;
-def int_dx_dot4 :
-    DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
-    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
-    [IntrNoMem, Commutative] >;
+def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                        [
+                                          llvm_anyfloat_ty, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>
+                                        ],
+                                        [IntrNoMem, Commutative]>;
+def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                        [
+                                          llvm_anyfloat_ty, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>
+                                        ],
+                                        [IntrNoMem, Commutative]>;
+def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                        [
+                                          llvm_anyfloat_ty, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>
+                                        ],
+                                        [IntrNoMem, Commutative]>;
 def int_dx_fdot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
@@ -100,9 +109,9 @@ def int_dx_udot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
-def int_dx_dot2add : 
-    DefaultAttrsIntrinsic<[llvm_float_ty], 
-    [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty], 
+def int_dx_dot2add :
+    DefaultAttrsIntrinsic<[llvm_float_ty],
+    [llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty],
     [IntrNoMem, Commutative]>;
 def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
 def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index b1e7406ead675..645105ade72b6 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
 }
 
 def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
-  let Doc = "dot product of 2 vectors of half having size = 2, returns "
-            "float";
+  let Doc = "2D half dot product with accumulate to float";
   let intrinsics = [IntrinSelect<int_dx_dot2add>];
   let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
   let result = FloatTy;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e44d3b70eb657..53ffcc3ebbdbe 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -169,7 +169,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
   assert(ATy->getScalarType()->isFloatingPointTy());
 
   Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
-  switch (AVec->getNumElements()) {
+  int NumElts = AVec->getNumElements();
+  switch (NumElts) {
   case 2:
     DotIntrinsic = Intrinsic::dx_dot2;
     break;
@@ -185,8 +186,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
         /* gen_crash_diag=*/false);
     return nullptr;
   }
-  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
-                                 ArrayRef<Value *>{A, B}, nullptr, "dot");
+
+  SmallVector<Value *> Args;
+  for (int I = 0; I < NumElts; ++I)
+    Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
+  for (int I = 0; I < NumElts; ++I)
+    Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
+  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
+                                 nullptr, "dot");
 }
 
 // Create the appropriate DXIL float dot intrinsic for the operands of Orig
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 41a9426998826..4574e5f7bbd96 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -33,52 +33,6 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static bool isVectorArgExpansion(Function &F) {
-  switch (F.getIntrinsicID()) {
-  case Intrinsic::dx_dot2:
-  case Intrinsic::dx_dot3:
-  case Intrinsic::dx_dot4:
-    return true;
-  }
-  return false;
-}
-
-static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
-  SmallVector<Value *> ExtractedElements;
-  auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
-  for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
-    Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
-    Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
-    ExtractedElements.push_back(ExtractedElement);
-  }
-  return ExtractedElements;
-}
-
-static SmallVector<Value *>
-argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
-  assert(NumOperands > 0);
-  Value *Arg0 = Orig->getOperand(0);
-  [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
-  assert(VecArg0);
-  SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
-  for (unsigned I = 1; I < NumOperands; ++I) {
-    Value *Arg = Orig->getOperand(I);
-    [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
-    assert(VecArg);
-    assert(VecArg0->getElementType() == VecArg->getElementType());
-    assert(VecArg0->getNumElements() == VecArg->getNumElements());
-    auto NextOperandList = populateOperands(Arg, Builder);
-    NewOperands.append(NextOperandList.begin(), NextOperandList.end());
-  }
-  return NewOperands;
-}
-
-static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
-                                             IRBuilder<> &Builder) {
-  // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
-  return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
-}
-
 namespace {
 class OpLowerer {
   Module &M;
@@ -150,9 +104,6 @@ class OpLowerer {
   [[nodiscard]] bool
   replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
                         ArrayRef<IntrinArgSelect> ArgSelects) {
-    bool IsVectorArgExpansion = isVectorArgExpansion(F);
-    assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
-           "Cann't do vector arg expansion when using arg selects.");
     return replaceFunction(F, [&](CallInst *CI) -> Error {
       OpBuilder.getIRB().SetInsertPoint(CI);
       SmallVector<Value *> Args;
@@ -170,15 +121,6 @@ class OpLowerer {
             break;
           }
         }
-      } else if (IsVectorArgExpansion) {
-        Args = argVectorFlatten(CI, OpBuilder.getIRB());
-      } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
-        // arg[NumOperands-1] is a pointer and is not needed by our flattening.
-        // arg[NumOperands-2] also does not need to be flattened because it is a
-        // scalar.
-        unsigned NumOperands = CI->getNumOperands() - 2;
-        Args.push_back(CI->getArgOperand(NumOperands));
-        Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
       } else {
         Args.append(CI->arg_begin(), CI->arg_end());
       }
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
index 97b025d36f018..f2167aa516057 100644
--- a/llvm/test/CodeGen/DirectX/dot2_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -4,8 +4,9 @@
 ; CHECK: in function dot_double2
 ; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type
 
...
[truncated]

Value *C = EmitScalarExpr(E->getArg(2));
Value *Acc = EmitScalarExpr(E->getArg(2));

Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not sure how i feel about scalarization in the frontend. This seems like a one off so have a minor concern that changes like this will make the vectorized DXIL coming in 6.9 have special cases in the frontend. This is a concern for the future so take it with a grain of salt.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would argue that this really is a special case. dot2add can only possibly take two-element vectors (it's right in the name), so we don't need to worry about scalarizing vectors of different lengths. It's also unlikely that we'll be adding DXIL operations for anything other than this version that works with half. If this operation were something more generic I'd be a little more hesitant to do this logic early, but I don't think we need to worry here.

ArrayRef<Value *>{A, B}, nullptr, "dot");

SmallVector<Value *> Args;
for (int I = 0; I < NumElts; ++I)
Copy link
Member

@farzonl farzonl Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have the new DirectX target builtins I had another thought on how to do this that would eliminate our convert one intrinsic to another intrinsic trick we are doing here. Just expose dot2, dot3, and dot4 in the frontend. If we are comfortable with doing scalarization in the frontend for dot2add then does it make sense to do that for these fdot cases?

Downside is we would need sema check for these 3 new cases which seems like overkill and maybe disqualifies this idea.

You may also have to call a builtin from a builtin via FunctionDecl *lookupBuiltinFunction(..) to avoid having to change the HLSL headers which is maybe another disqualifier.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think having a generic fdot for the frontend to emit makes sense. The implementation being per vector size is a bit odd and wouldn't scale, so if we did ever have a >4 element fdot I feel like it would be nicer for the backend to handle it rather than need the frontend to contend with all of these details.

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks reasonable

The `dx.dot2`, `dot3`, and `dot4` intrinsics exist purely to lower
`dx.fdot`, and they map exactly to the DXIL ops of the same name. Using
vectors for their arguments adds unnecessary complexity and causes us to
have vector operations that are not trivial to lower post-scalarizer.

Similarly, the `dx.dot2add` intrinsic is overly generic for something
that only needs to lower to a single `dot2AddHalf` DXIL op. Update its
signature to match the operation it lowers to.

Fixes llvm#134569.
@bogner bogner force-pushed the 2025-04-07-scalar-dot branch from 7f8833f to 1fa97b4 Compare April 8, 2025 00:27
@farzonl
Copy link
Member

farzonl commented Apr 8, 2025

There is a dot2add.c test case you will need to update. Target builtins are suppose to work across language modes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX clang:codegen IR generation bugs: mangling, exceptions, etc. clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

[DirectX] The dx.dot2/3/4 and dx.dot2add intrinsics leave vectors around post-scalarization
3 participants