Skip to content

Commit af7b9d6

Browse files
vmaksimosys-ce-bb
authored andcommitted
1 parent e98c1dd commit af7b9d6

File tree

2 files changed

+187
-4
lines changed

2 files changed

+187
-4
lines changed

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,10 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
624624
assert(getValueType(Op1)->getVectorComponentCount() ==
625625
getValueType(Op2)->getVectorComponentCount() &&
626626
"Inconsistent Vector component width");
627+
} else if (getValueType(Op1)->isTypeCooperativeMatrixKHR()) {
628+
Op1Ty = getValueType(Op1)->getVectorComponentType();
629+
Op2Ty = getValueType(Op2)->getVectorComponentType();
630+
assert(Op1Ty == Op2Ty && "Inconsistent Cooperative matrix types");
627631
} else {
628632
Op1Ty = getValueType(Op1);
629633
Op2Ty = getValueType(Op2);
@@ -1502,10 +1506,13 @@ class SPIRVUnary : public SPIRVInstTemplateBase {
15021506
return;
15031507
if (isGenericNegateOpCode(OpCode)) {
15041508
SPIRVType *ResTy =
1505-
Type->isTypeVector() ? Type->getVectorComponentType() : Type;
1506-
SPIRVType *OpTy = Type->isTypeVector()
1507-
? getValueType(Op)->getVectorComponentType()
1508-
: getValueType(Op);
1509+
Type->isTypeVector() || Type->isTypeCooperativeMatrixKHR()
1510+
? Type->getVectorComponentType()
1511+
: Type;
1512+
SPIRVType *OpTy =
1513+
Type->isTypeVector() || Type->isTypeCooperativeMatrixKHR()
1514+
? getValueType(Op)->getVectorComponentType()
1515+
: getValueType(Op);
15091516

15101517
(void)ResTy;
15111518
(void)OpTy;
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; TODO: Validation is disabled till the moment the tools in CI are updated (passes locally)
7+
; R/UN: spirv-val %t.spv
8+
9+
; TODO: come up with an approach and implement reverse translation
10+
; R/UN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
11+
; R/UN: llvm-dis %t.rev.bc
12+
; R/UN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM
13+
14+
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
15+
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixTypeInt:]] [[#TypeInt]]
16+
; CHECK-SPIRV: TypeFloat [[#TypeFloat:]] 32
17+
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#MatrixTypeFloat:]] [[#TypeFloat]]
18+
19+
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024"
20+
target triple = "spir-unknown-unknown"
21+
22+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixIn:]] [[#]] {{$}}
23+
; CHECK-SPIRV: SNegate [[#MatrixTypeInt]] [[#]] [[#MatrixIn]]
24+
define spir_kernel void @testSNegate(i32 %a) #0 !kernel_arg_addr_space !10 !kernel_arg_access_qual !11 !kernel_arg_type !12 !kernel_arg_type_qual !9 !kernel_arg_base_type !12 {
25+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
26+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegate(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1)
27+
ret void
28+
}
29+
30+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixIn:]] [[#]] {{$}}
31+
; CHECK-SPIRV: FNegate [[#MatrixTypeFloat]] [[#]] [[#MatrixIn]]
32+
; CHECK-LLVM: fneg
33+
define spir_kernel void @testFNeg(float %a) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !9 {
34+
entry:
35+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
36+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z15__spirv_FNegate(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0)
37+
ret void
38+
}
39+
40+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
41+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
42+
; CHECK-SPIRV: IAdd [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
43+
define spir_kernel void @testIAdd(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
44+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
45+
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
46+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IAdd(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
47+
ret void
48+
}
49+
50+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
51+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
52+
; CHECK-SPIRV: ISub [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
53+
define spir_kernel void @testISub(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
54+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
55+
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
56+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_ISub(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
57+
ret void
58+
}
59+
60+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
61+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
62+
; CHECK-SPIRV: IMul [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
63+
define spir_kernel void @testIMul(i32 %a, i32 %b) #0 !kernel_arg_addr_space !4 !kernel_arg_access_qual !5 !kernel_arg_type !6 !kernel_arg_type_qual !7 !kernel_arg_base_type !6 {
64+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
65+
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
66+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IMul(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
67+
ret void
68+
}
69+
70+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
71+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
72+
; CHECK-SPIRV: SDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
73+
define void @testSDiv(i32 %a, i32 %b) {
74+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
75+
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
76+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_SDiv(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
77+
ret void
78+
}
79+
80+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixA:]] [[#]] {{$}}
81+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeInt]] [[#MatrixB:]] [[#]] {{$}}
82+
; CHECK-SPIRV: UDiv [[#MatrixTypeInt]] [[#]] [[#MatrixA]] [[#MatrixB]]
83+
define void @testUDiv(i32 %a, i32 %b) {
84+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
85+
%2 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 0)
86+
%call = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_UDiv(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %1, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) %2)
87+
ret void
88+
}
89+
90+
91+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
92+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
93+
; CHECK-SPIRV: FAdd [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
94+
define spir_kernel void @testFAdd(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
95+
entry:
96+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
97+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
98+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
99+
ret void
100+
}
101+
102+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
103+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
104+
; CHECK-SPIRV: FSub [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
105+
define spir_kernel void @testFSub(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
106+
entry:
107+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
108+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
109+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
110+
ret void
111+
}
112+
113+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
114+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
115+
; CHECK-SPIRV: FMul [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
116+
define spir_kernel void @testFMul(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
117+
entry:
118+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
119+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
120+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
121+
ret void
122+
}
123+
124+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixA:]] [[#]] {{$}}
125+
; CHECK-SPIRV: CompositeConstruct [[#MatrixTypeFloat]] [[#MatrixB:]] [[#]] {{$}}
126+
; CHECK-SPIRV: FDiv [[#MatrixTypeFloat]] [[#]] [[#MatrixA]] [[#MatrixB]]
127+
define spir_kernel void @testFDiv(float %a, float %b) local_unnamed_addr #0 !kernel_arg_addr_space !2 !kernel_arg_access_qual !3 !kernel_arg_type !4 !kernel_arg_base_type !4 !kernel_arg_type_qual !5 {
128+
entry:
129+
%0 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
130+
%1 = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float 0.000000e+00)
131+
%call = call spir_func target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %0, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) %1)
132+
ret void
133+
}
134+
135+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructFloat(float noundef)
136+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z26__spirv_CompositeConstructInt32(i32 noundef)
137+
138+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z15__spirv_FNegate(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
139+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z15__spirv_SNegate(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef)
140+
141+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IAdd(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef)
142+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_ISub(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef)
143+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_IMul(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef)
144+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_SDiv(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef)
145+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) @_Z12__spirv_UDiv(target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 3) noundef)
146+
147+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FAdd(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
148+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FSub(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
149+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FMul(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
150+
declare spir_func noundef target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) @_Z12__spirv_FDiv(target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef, target("spirv.CooperativeMatrixKHR", float, 3, 12, 12, 3) noundef)
151+
152+
153+
attributes #0 = { nounwind }
154+
155+
!spirv.MemoryModel = !{!0}
156+
!opencl.enable.FP_CONTRACT = !{}
157+
!spirv.Source = !{!1}
158+
!opencl.spir.version = !{!0}
159+
!opencl.ocl.version = !{!0}
160+
!opencl.used.extensions = !{!2}
161+
!opencl.used.optional.core.features = !{!2}
162+
!spirv.Generator = !{!3}
163+
164+
!0 = !{i32 1, i32 2}
165+
!1 = !{i32 3, i32 102000}
166+
!2 = !{}
167+
!3 = !{i16 7, i16 0}
168+
!4 = !{i32 0, i32 0}
169+
!5 = !{!"none", !"none"}
170+
!6 = !{!"int", !"int"}
171+
!7 = !{!"", !""}
172+
!8 = !{!2, !2}
173+
!9 = !{!""}
174+
!10 = !{i32 0}
175+
!11 = !{!"none"}
176+
!12 = !{!"int"}

0 commit comments

Comments
 (0)