Skip to content

Commit c93760b

Browse files
committed
[DirectX] Use scalar arguments for @llvm.dx.dot intrinsics
The `dx.dot2`, `dot3`, and `dot4` intrinsics exist purely to lower `dx.fdot`, and they map exactly to the DXIL ops of the same name. Using vectors for their arguments adds unnecessary complexity and causes us to have vector operations that are not trivial to lower post-scalarizer. Similarly, the `dx.dot2add` intrinsic is overly generic for something that only needs to lower to a single `dot2AddHalf` DXIL op. Update its signature to match the operation it lowers to. Fixes llvm#134569.
1 parent ac42b08 commit c93760b

File tree

12 files changed

+162
-149
lines changed

12 files changed

+162
-149
lines changed

Diff for: clang/lib/CodeGen/TargetBuiltins/DirectX.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,17 @@ Value *CodeGenFunction::EmitDirectXBuiltinExpr(unsigned BuiltinID,
2525
case DirectX::BI__builtin_dx_dot2add: {
2626
Value *A = EmitScalarExpr(E->getArg(0));
2727
Value *B = EmitScalarExpr(E->getArg(1));
28-
Value *C = EmitScalarExpr(E->getArg(2));
28+
Value *Acc = EmitScalarExpr(E->getArg(2));
29+
30+
Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
31+
Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1));
32+
Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0));
33+
Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1));
2934

3035
Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
3136
return Builder.CreateIntrinsic(
32-
/*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
33-
"dx.dot2add");
37+
/*ReturnType=*/Acc->getType(), ID,
38+
ArrayRef<Value *>{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add");
3439
}
3540
}
3641
return nullptr;

Diff for: clang/test/CodeGenHLSL/builtins/dot2add.hlsl

+50-10
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float p3) {
1313
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
1414
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
1515
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
16-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
16+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
17+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
18+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
19+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
20+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
1721
// CHECK: ret float %[[RES]]
1822
return dot2add(p1, p2, p3);
1923
}
@@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) {
2529
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
2630
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
2731
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
28-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
32+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
33+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
34+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
35+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
36+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
2937
// CHECK: ret float %[[RES]]
3038
return dot2add(p1, p2, p3);
3139
}
@@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) {
3745
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
3846
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
3947
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
40-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
48+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
49+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
50+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
51+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
52+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
4153
// CHECK: ret float %[[RES]]
4254
return dot2add(p1, p2, p3);
4355
}
@@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) {
4961
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
5062
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
5163
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
52-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
64+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
65+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
66+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
67+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
68+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
5369
// CHECK: ret float %[[RES]]
5470
return dot2add(p1, p2, p3);
5571
}
@@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) {
6278
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
6379
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
6480
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
65-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
81+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
82+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
83+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
84+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
85+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
6686
// CHECK: ret float %[[RES]]
6787
return dot2add(p1, p2, p3);
6888
}
@@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) {
7595
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
7696
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
7797
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
78-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
98+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
99+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
100+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
101+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
102+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
79103
// CHECK: ret float %[[RES]]
80104
return dot2add(p1, p2, p3);
81105
}
@@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) {
88112
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
89113
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
90114
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
91-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
115+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
116+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
117+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
118+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
119+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
92120
// CHECK: ret float %[[RES]]
93121
return dot2add(p1, p2, p3);
94122
}
@@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) {
101129
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
102130
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
103131
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
104-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
132+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
133+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
134+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
135+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
136+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
105137
// CHECK: ret float %[[RES]]
106138
return dot2add(p1, p2, p3);
107139
}
@@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) {
114146
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
115147
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
116148
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
117-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
149+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
150+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
151+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
152+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
153+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
118154
// CHECK: ret float %[[RES]]
119155
return dot2add(p1, p2, p3);
120156
}
@@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) {
129165
// CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float
130166
// CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4
131167
// CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]]
132-
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
168+
// CHECK-DXIL: %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
169+
// CHECK-DXIL: %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
170+
// CHECK-DXIL: %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
171+
// CHECK-DXIL: %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
172+
// CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float %{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
133173
// CHECK: ret float %[[RES]]
134174
return dot2add(p1, p2, p3);
135175
}

Diff for: llvm/include/llvm/IR/IntrinsicsDirectX.td

+24-15
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,27 @@ def int_dx_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>,
7676
def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
7777
def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
7878

79-
def int_dx_dot2 :
80-
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
81-
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
82-
[IntrNoMem, Commutative] >;
83-
def int_dx_dot3 :
84-
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
85-
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
86-
[IntrNoMem, Commutative] >;
87-
def int_dx_dot4 :
88-
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
89-
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
90-
[IntrNoMem, Commutative] >;
79+
def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
80+
[
81+
llvm_anyfloat_ty, LLVMMatchType<0>,
82+
LLVMMatchType<0>, LLVMMatchType<0>
83+
],
84+
[IntrNoMem, Commutative]>;
85+
def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
86+
[
87+
llvm_anyfloat_ty, LLVMMatchType<0>,
88+
LLVMMatchType<0>, LLVMMatchType<0>,
89+
LLVMMatchType<0>, LLVMMatchType<0>
90+
],
91+
[IntrNoMem, Commutative]>;
92+
def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
93+
[
94+
llvm_anyfloat_ty, LLVMMatchType<0>,
95+
LLVMMatchType<0>, LLVMMatchType<0>,
96+
LLVMMatchType<0>, LLVMMatchType<0>,
97+
LLVMMatchType<0>, LLVMMatchType<0>
98+
],
99+
[IntrNoMem, Commutative]>;
91100
def int_dx_fdot :
92101
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
93102
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
@@ -100,9 +109,9 @@ def int_dx_udot :
100109
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
101110
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
102111
[IntrNoMem, Commutative] >;
103-
def int_dx_dot2add :
104-
DefaultAttrsIntrinsic<[llvm_float_ty],
105-
[llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty],
112+
def int_dx_dot2add :
113+
DefaultAttrsIntrinsic<[llvm_float_ty],
114+
[llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty],
106115
[IntrNoMem, Commutative]>;
107116
def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
108117
def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;

Diff for: llvm/lib/Target/DirectX/DXIL.td

+1-2
Original file line numberDiff line numberDiff line change
@@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
10781078
}
10791079

10801080
def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
1081-
let Doc = "dot product of 2 vectors of half having size = 2, returns "
1082-
"float";
1081+
let Doc = "2D half dot product with accumulate to float";
10831082
let intrinsics = [IntrinSelect<int_dx_dot2add>];
10841083
let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
10851084
let result = FloatTy;

Diff for: llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

+10-3
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
169169
assert(ATy->getScalarType()->isFloatingPointTy());
170170

171171
Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
172-
switch (AVec->getNumElements()) {
172+
int NumElts = AVec->getNumElements();
173+
switch (NumElts) {
173174
case 2:
174175
DotIntrinsic = Intrinsic::dx_dot2;
175176
break;
@@ -185,8 +186,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
185186
/* gen_crash_diag=*/false);
186187
return nullptr;
187188
}
188-
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
189-
ArrayRef<Value *>{A, B}, nullptr, "dot");
189+
190+
SmallVector<Value *> Args;
191+
for (int I = 0; I < NumElts; ++I)
192+
Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
193+
for (int I = 0; I < NumElts; ++I)
194+
Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
195+
return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
196+
nullptr, "dot");
190197
}
191198

192199
// Create the appropriate DXIL float dot intrinsic for the operands of Orig

Diff for: llvm/lib/Target/DirectX/DXILOpLowering.cpp

-58
Original file line numberDiff line numberDiff line change
@@ -33,52 +33,6 @@
3333
using namespace llvm;
3434
using namespace llvm::dxil;
3535

36-
static bool isVectorArgExpansion(Function &F) {
37-
switch (F.getIntrinsicID()) {
38-
case Intrinsic::dx_dot2:
39-
case Intrinsic::dx_dot3:
40-
case Intrinsic::dx_dot4:
41-
return true;
42-
}
43-
return false;
44-
}
45-
46-
static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
47-
SmallVector<Value *> ExtractedElements;
48-
auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
49-
for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
50-
Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
51-
Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
52-
ExtractedElements.push_back(ExtractedElement);
53-
}
54-
return ExtractedElements;
55-
}
56-
57-
static SmallVector<Value *>
58-
argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
59-
assert(NumOperands > 0);
60-
Value *Arg0 = Orig->getOperand(0);
61-
[[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
62-
assert(VecArg0);
63-
SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
64-
for (unsigned I = 1; I < NumOperands; ++I) {
65-
Value *Arg = Orig->getOperand(I);
66-
[[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
67-
assert(VecArg);
68-
assert(VecArg0->getElementType() == VecArg->getElementType());
69-
assert(VecArg0->getNumElements() == VecArg->getNumElements());
70-
auto NextOperandList = populateOperands(Arg, Builder);
71-
NewOperands.append(NextOperandList.begin(), NextOperandList.end());
72-
}
73-
return NewOperands;
74-
}
75-
76-
static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
77-
IRBuilder<> &Builder) {
78-
// Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
79-
return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
80-
}
81-
8236
namespace {
8337
class OpLowerer {
8438
Module &M;
@@ -150,9 +104,6 @@ class OpLowerer {
150104
[[nodiscard]] bool
151105
replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
152106
ArrayRef<IntrinArgSelect> ArgSelects) {
153-
bool IsVectorArgExpansion = isVectorArgExpansion(F);
154-
assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
155-
"Cann't do vector arg expansion when using arg selects.");
156107
return replaceFunction(F, [&](CallInst *CI) -> Error {
157108
OpBuilder.getIRB().SetInsertPoint(CI);
158109
SmallVector<Value *> Args;
@@ -170,15 +121,6 @@ class OpLowerer {
170121
break;
171122
}
172123
}
173-
} else if (IsVectorArgExpansion) {
174-
Args = argVectorFlatten(CI, OpBuilder.getIRB());
175-
} else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
176-
// arg[NumOperands-1] is a pointer and is not needed by our flattening.
177-
// arg[NumOperands-2] also does not need to be flattened because it is a
178-
// scalar.
179-
unsigned NumOperands = CI->getNumOperands() - 2;
180-
Args.push_back(CI->getArgOperand(NumOperands));
181-
Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
182124
} else {
183125
Args.append(CI->arg_begin(), CI->arg_end());
184126
}

Diff for: llvm/test/CodeGen/DirectX/dot2_error.ll

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
; CHECK: in function dot_double2
55
; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type
66

7-
define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) {
7+
define noundef double @dot_double2(double noundef %a1, double noundef %a2,
8+
double noundef %b1, double noundef %b2) {
89
entry:
9-
%dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b)
10+
%dx.dot = call double @llvm.dx.dot2(double %a1, double %a2, double %b1, double %b2)
1011
ret double %dx.dot
1112
}

Diff for: llvm/test/CodeGen/DirectX/dot2add.ll

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
22

3-
define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) {
3+
define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %acc) {
44
entry:
5-
; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3)
6-
%ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c)
5+
%ax = extractelement <2 x half> %a, i32 0
6+
%ay = extractelement <2 x half> %a, i32 1
7+
%bx = extractelement <2 x half> %b, i32 0
8+
%by = extractelement <2 x half> %b, i32 1
9+
10+
; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %acc, half %ax, half %ay, half %bx, half %by)
11+
%ret = call float @llvm.dx.dot2add(float %acc, half %ax, half %ay, half %bx, half %by)
712
ret float %ret
813
}

0 commit comments

Comments
 (0)