@@ -64,18 +64,20 @@ func.func @dot_product_i32() {
6464 %vector_b = arith.constant dense <314 > : vector <[4 ]xi32 >
6565 %vector_c = arith.constant dense <0 > : vector <[4 ]xi32 >
6666
67- // The result of this dot-product will depend
68- // on the vector length, so we are unable to verify it.
67+ // DOT PRODUCT 1
6968 %dp1 = vector.contract #dotp_trait %vector_a , %vector_b , %acc
7069 : vector <[4 ]xi32 >, vector <[4 ]xi32 > into i32
71- // Dot product should be (123 * 314) * 4 * vscale, so ...
70+ // Dot product should be:
71+ // * val = (123 * 314) * 4 * vscale,
72+ // so ...
7273 %vscale = vector.vscale
7374 %vscale_i32 = arith.index_cast %vscale : index to i32
74- %dp1_divvl = arith.divui %dp1 , %vscale_i32 : i32
75- // ... %dp/% vscale = 123 * 314 * 4 = 154488
75+ %dp1_div = arith.divui %dp1 , %vscale_i32 : i32
76+ // ... val / vscale = 123 * 314 * 4 = 154488
7677 // DP: 154488
77- vector.print %dp1_divvl : i32
78+ vector.print %dp1_div : i32
7879
80+ // DOT PRODUCT 2
7981 // The result of this dot-product should be 0.
8082 %dp2 = vector.contract #dotp_trait %vector_a , %vector_c , %acc
8183 : vector <[4 ]xi32 >, vector <[4 ]xi32 > into i32
@@ -96,18 +98,27 @@ func.func @matvec_i32() {
9698 %vector_b = arith.constant dense <314 > : vector <[4 ]xi32 >
9799 %vector_c = arith.constant dense <0 > : vector <[4 ]xi32 >
98100
99- // The result of this matvec will depend on the vector length, so we are
100- // unable to verify it.
101- %dp1 = vector.contract #matvec_trait %vector_a , %vector_b , %acc
101+ // MATVEC 1
102+ %mv1 = vector.contract #matvec_trait %vector_a , %vector_b , %acc
102103 : vector <3 x[4 ]xi32 >, vector <[4 ]xi32 > into vector <3 xi32 >
103- // MV: {{[0-9]*}}, {{[0-9]*}}, {{[0-9]*}}
104- vector.print %dp1 : vector <3 xi32 >
105-
106- // The result of this matvc should be a vector of 0s.
107- %dp2 = vector.contract #matvec_trait %vector_a , %vector_c , %acc
104+ // Every element in the output vector is a result of a dot product, for
105+ // which:
106+ // val = (123 * 314) * 4 * vscale
107+ // so ...
108+ %vscale = vector.vscale
109+ %vscale_v = vector.splat %vscale : vector <3 xindex >
110+ %vscale_i32 = arith.index_cast %vscale_v : vector <3 xindex > to vector <3 xi32 >
111+ %mv1_div = arith.divui %mv1 , %vscale_i32 : vector <3 xi32 >
112+ // ... val / vscale = 123 * 314 * 4 = 154488
113+ // MV: 154488, 154488, 154488
114+ vector.print %mv1_div : vector <3 xi32 >
115+
116+ // MATVEC 2
117+ // The result of this matvec should be a vector of 0s.
118+ %mv2 = vector.contract #matvec_trait %vector_a , %vector_c , %acc
108119 : vector <3 x[4 ]xi32 >, vector <[4 ]xi32 > into vector <3 xi32 >
109120 // MV: 0, 0, 0
110- vector.print %dp2 : vector <3 xi32 >
121+ vector.print %mv2 : vector <3 xi32 >
111122
112123 // MV: SVE: END OF TEST OUTPUT
113124 vector.print str " SVE: END OF TEST OUTPUT"
0 commit comments