-
Notifications
You must be signed in to change notification settings - Fork 13.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DirectX] Use scalar arguments for @llvm.dx.dot intrinsics #134570
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang Author: Justin Bogner (bogner) ChangesThe Similarly, the Fixes #134569. Patch is 34.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/134570.diff 12 Files Affected:
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 27d1c69439944..4e92be0664f71 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -385,12 +385,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
"Intrinsic dot2add is only allowed for dxil architecture");
Value *A = EmitScalarExpr(E->getArg(0));
Value *B = EmitScalarExpr(E->getArg(1));
- Value *C = EmitScalarExpr(E->getArg(2));
+ Value *Acc = EmitScalarExpr(E->getArg(2));
+
+ Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
+ Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1));
+ Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0));
+ Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1));
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
return Builder.CreateIntrinsic(
- /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
- "dx.dot2add");
+ /*ReturnType=*/Acc->getType(), ID,
+ ArrayRef<Value *>{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add");
}
case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
Value *A = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
index 2464607dd636c..c345e17476e08 100644
--- a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
@@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
@@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) {
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
- // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+ // CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+ // CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+ // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
// CHECK: ret float %[[RES]]
return dot2add(p1, p2, p3);
}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 775d325feeb14..b1a27311e2a9c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -76,18 +76,27 @@ def int_dx_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>,
def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
-def int_dx_dot2 :
- DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
- [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, Commutative] >;
-def int_dx_dot3 :
- DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
- [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, Commutative] >;
-def int_dx_dot4 :
- DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
- [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
- [IntrNoMem, Commutative] >;
+def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+ [
+ llvm_anyfloat_ty, LLVMMatchType<0>,
+ LLVMMatchType<0>, LLVMMatchType<0>
+ ],
+ [IntrNoMem, Commutative]>;
+def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+ [
+ llvm_anyfloat_ty, LLVMMatchType<0>,
+ LLVMMatchType<0>, LLVMMatchType<0>,
+ LLVMMatchType<0>, LLVMMatchType<0>
+ ],
+ [IntrNoMem, Commutative]>;
+def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+ [
+ llvm_anyfloat_ty, LLVMMatchType<0>,
+ LLVMMatchType<0>, LLVMMatchType<0>,
+ LLVMMatchType<0>, LLVMMatchType<0>,
+ LLVMMatchType<0>, LLVMMatchType<0>
+ ],
+ [IntrNoMem, Commutative]>;
def int_dx_fdot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
@@ -100,9 +109,9 @@ def int_dx_udot :
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
-def int_dx_dot2add :
- DefaultAttrsIntrinsic<[llvm_float_ty],
- [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
+def int_dx_dot2add :
+ DefaultAttrsIntrinsic<[llvm_float_ty],
+ [llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty],
[IntrNoMem, Commutative]>;
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index b1e7406ead675..645105ade72b6 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
}
def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
- let Doc = "dot product of 2 vectors of half having size = 2, returns "
- "float";
+ let Doc = "2D half dot product with accumulate to float";
let intrinsics = [IntrinSelect<int_dx_dot2add>];
let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
let result = FloatTy;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e44d3b70eb657..53ffcc3ebbdbe 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -169,7 +169,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
assert(ATy->getScalarType()->isFloatingPointTy());
Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
- switch (AVec->getNumElements()) {
+ int NumElts = AVec->getNumElements();
+ switch (NumElts) {
case 2:
DotIntrinsic = Intrinsic::dx_dot2;
break;
@@ -185,8 +186,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
/* gen_crash_diag=*/false);
return nullptr;
}
- return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
- ArrayRef<Value *>{A, B}, nullptr, "dot");
+
+ SmallVector<Value *> Args;
+ for (int I = 0; I < NumElts; ++I)
+ Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
+ for (int I = 0; I < NumElts; ++I)
+ Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
+ return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
+ nullptr, "dot");
}
// Create the appropriate DXIL float dot intrinsic for the operands of Orig
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 41a9426998826..4574e5f7bbd96 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -33,52 +33,6 @@
using namespace llvm;
using namespace llvm::dxil;
-static bool isVectorArgExpansion(Function &F) {
- switch (F.getIntrinsicID()) {
- case Intrinsic::dx_dot2:
- case Intrinsic::dx_dot3:
- case Intrinsic::dx_dot4:
- return true;
- }
- return false;
-}
-
-static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
- SmallVector<Value *> ExtractedElements;
- auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
- for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
- Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
- Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
- ExtractedElements.push_back(ExtractedElement);
- }
- return ExtractedElements;
-}
-
-static SmallVector<Value *>
-argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
- assert(NumOperands > 0);
- Value *Arg0 = Orig->getOperand(0);
- [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
- assert(VecArg0);
- SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
- for (unsigned I = 1; I < NumOperands; ++I) {
- Value *Arg = Orig->getOperand(I);
- [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
- assert(VecArg);
- assert(VecArg0->getElementType() == VecArg->getElementType());
- assert(VecArg0->getNumElements() == VecArg->getNumElements());
- auto NextOperandList = populateOperands(Arg, Builder);
- NewOperands.append(NextOperandList.begin(), NextOperandList.end());
- }
- return NewOperands;
-}
-
-static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
- IRBuilder<> &Builder) {
- // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
- return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
-}
-
namespace {
class OpLowerer {
Module &M;
@@ -150,9 +104,6 @@ class OpLowerer {
[[nodiscard]] bool
replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
ArrayRef<IntrinArgSelect> ArgSelects) {
- bool IsVectorArgExpansion = isVectorArgExpansion(F);
- assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
- "Cann't do vector arg expansion when using arg selects.");
return replaceFunction(F, [&](CallInst *CI) -> Error {
OpBuilder.getIRB().SetInsertPoint(CI);
SmallVector<Value *> Args;
@@ -170,15 +121,6 @@ class OpLowerer {
break;
}
}
- } else if (IsVectorArgExpansion) {
- Args = argVectorFlatten(CI, OpBuilder.getIRB());
- } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
- // arg[NumOperands-1] is a pointer and is not needed by our flattening.
- // arg[NumOperands-2] also does not need to be flattened because it is a
- // scalar.
- unsigned NumOperands = CI->getNumOperands() - 2;
- Args.push_back(CI->getArgOperand(NumOperands));
- Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
} else {
Args.append(CI->arg_begin(), CI->arg_end());
}
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll
index 97b025d36f018..f2167aa516057 100644
--- a/llvm/test/CodeGen/DirectX/dot2_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -4,8 +4,9 @@
; CHECK: in function dot_double2
; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type
...
[truncated]
|
clang/lib/CodeGen/CGHLSLBuiltins.cpp
Outdated
Value *C = EmitScalarExpr(E->getArg(2)); | ||
Value *Acc = EmitScalarExpr(E->getArg(2)); | ||
|
||
Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
im not sure how i feel about scalarization in the frontend. This seems like a one off so have a minor concern that changes like this will make the vectorized DXIL coming in 6.9 have special cases in the frontend. This is a concern for the future so take it with a grain of salt.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would argue that this really is a special case. dot2add
can only possibly take two-element vectors (it's right in the name), so we don't need to worry about scalarizing vectors of different lengths. It's also unlikely that we'll be adding DXIL operations for anything other than this version that works with half
. If this operation were something more generic I'd be a little more hesitant to do this logic early, but I don't think we need to worry here.
ArrayRef<Value *>{A, B}, nullptr, "dot"); | ||
|
||
SmallVector<Value *> Args; | ||
for (int I = 0; I < NumElts; ++I) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that we have the new DirectX target builtins I had another thought on how to do this that would eliminate our convert one intrinsic to another intrinsic trick we are doing here. Just expose dot2, dot3, and dot4 in the frontend. If we are comfortable with doing scalarization in the frontend for dot2add then does it make sense to do that for these fdot cases?
Downside is we would need sema check for these 3 new cases which seems like overkill and maybe disqualifies this idea.
You may also have to call a builtin from a builtin via FunctionDecl *lookupBuiltinFunction(..)
to avoid having to change the HLSL headers which is maybe another disqualifier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still think having a generic fdot
for the frontend to emit makes sense. The implementation being per vector size is a bit odd and wouldn't scale, so if we did ever have a >4 element fdot
I feel like it would be nicer for the backend to handle it rather than need the frontend to contend with all of these details.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks reasonable
The `dx.dot2`, `dot3`, and `dot4` intrinsics exist purely to lower `dx.fdot`, and they map exactly to the DXIL ops of the same name. Using vectors for their arguments adds unnecessary complexity and causes us to have vector operations that are not trivial to lower post-scalarizer. Similarly, the `dx.dot2add` intrinsic is overly generic for something that only needs to lower to a single `dot2AddHalf` DXIL op. Update its signature to match the operation it lowers to. Fixes llvm#134569.
7f8833f
to
1fa97b4
Compare
There is a |
The
dx.dot2
,dot3
, anddot4
intrinsics exist purely to lowerdx.fdot
, and they map exactly to the DXIL ops of the same name. Using vectors for their arguments adds unnecessary complexity and causes us to have vector operations that are not trivial to lower post-scalarizer.Similarly, the
dx.dot2add
intrinsic is overly generic for something that only needs to lower to a singledot2AddHalf
DXIL op. Update its signature to match the operation it lowers to.Fixes #134569.