@@ -221,6 +221,76 @@ def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
221221 let hasVerifier = 1;
222222}
223223
224+ def XeVM_MatrixElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, F16, BF16]>;
225+
226+ /// Enum attribute of the different precision types.
227+ def XeVM_PrecisionTypeAttr : I32EnumAttr<"PrecisionType",
228+ "XeVM precision type",
229+ [
230+ I32EnumAttrCase<"UNUSED", 0, "unused">,
231+ I32EnumAttrCase<"U8", 1, "u8">,
232+ I32EnumAttrCase<"U4", 2, "u4">,
233+ I32EnumAttrCase<"U2", 3, "u2">,
234+ I32EnumAttrCase<"S8", 4, "i8">,
235+ I32EnumAttrCase<"S4", 5, "i4">,
236+ I32EnumAttrCase<"S2", 6, "i2">,
237+ I32EnumAttrCase<"BF8", 7, "bf8">,
238+ I32EnumAttrCase<"TF32", 8, "tf32">,
239+ I32EnumAttrCase<"BF16", 9, "bf16">,
240+ I32EnumAttrCase<"FP16", 10, "f16">
241+ ]> {
242+ let cppNamespace = "::mlir::xevm";
243+ }
244+
245+ def XeVM_DPASOp : XeVM_Op<"dpas">,
246+ Results<(outs FixedVectorOf<[XeVM_MatrixElemType]>:$d)>,
247+ Arguments<(ins
248+ FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$c,
249+ FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a,
250+ FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b,
251+ XeVM_PrecisionTypeAttr:$pa,
252+ XeVM_PrecisionTypeAttr:$pb,
253+ I32Attr:$rc
254+ )> {
255+
256+ let summary = "Matrix multiply-add";
257+
258+ let description = [{
259+ The `xevm.dpas` operation is a matrix multiplication plus accumulation:
260+
261+ D = C + A x B
262+
263+ where the A, B, C input matrices and the result D have shapes:
264+ D : MxN
265+ C : MxN
266+ A : MxK
267+ B : KxN
268+
269+ Shape restrictions:
270+ M : must be 1, 2, 4, or 8
271+ N : fixed execution size, must be 16
272+ K : systolic_depth * OPS_PER_CHAN
273+ OPS_PER_CHAN
274+ 1 : for TF32
275+ 2 : for 16-bit precision(BF, HF)
276+ 4 : for 8-bit precision (FP8, UB, B)
277+ 8 : for less-then 8 bit precision (U4/S4, U2/S2).
278+
279+ If systolic_depth is 8, K would be 8, 16, 32, or 64 (based on OPS_PER_CHAN).
280+ $a, $b, $c, $d - matrix A, B, C, D, respectively
281+ $pa, $pb - precision of matrix A and B resepectively
282+ $rc - repeat count
283+
284+ Further restrictions as well as more details can be found here:
285+ https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
286+ }];
287+
288+ let assemblyFormat = [{
289+ operands ` ` `{` `pa` `=` $pa `,` `pb` `=` $pb `,` `rc` `=` $rc `}` attr-dict `:` functional-type(operands, results)
290+ }];
291+
292+ // let hasVerifier = 1;
293+ }
224294
225295def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
226296 let description = [{
0 commit comments