diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 5910aa3f7f2da..f3ffbd0f5a027 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -327,8 +327,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor let hasVerifier = 1; } -def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>, - AllElementTypesMatch<["value", "TensorDesc"]>]> { +def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllElementTypesMatch<["value", "TensorDesc"]>]> { let summary = "stores a n-D block register region back to memory, currently only supports 2D"; let description = [{ diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 9d3c4366a7bd5..15c435f1fa257 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -73,6 +73,39 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) { kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; } +// Validation of nd instruction arguments is successful if any of these are +// true: +// - tensor descriptor and the output vector shapes exactly match. +// - tensor descriptor has a sg_map attribute and the distributed vector shape +// matches the tensor descriptor shape when scaled using sg_map factors on +// each dimension. +static bool isArgShapesValid(ArrayRef descShape, + ArrayRef valShape, SGMapAttr sgMap) { + if (descShape == valShape) { + if (!sgMap) + return true; + + // This can be relaxed if necessary by supporting non-2d shape distribution; + // until the constraints are defined, this lives here instead of the tensor + // descriptor type. 
+ return valShape.size() == sgMap.getWiLayout().size(); + } + + if (!sgMap) + return false; + + if (valShape.size() != descShape.size()) + return false; + + for (const auto &[factor, dim, expected] : + llvm::zip_equal(sgMap.getWiLayout(), valShape, descShape)) { + if (factor * dim != expected) + return false; + } + + return true; +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -210,13 +243,13 @@ LogicalResult PrefetchNdOp::verify() { return emitOpError("Expects a non-scattered TensorDesc.\n"); if (!isReadHintOrNone(getL1HintAttr())) - return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isReadHintOrNone(getL2HintAttr())) - return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) - return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return emitOpError("invalid l3_hint: ") << getL3HintAttr(); return success(); } @@ -238,13 +271,13 @@ LogicalResult LoadNdOp::verify() { return emitOpError("Invalid result, it should be a VectorType.\n"); if (!isReadHintOrNone(getL1HintAttr())) - return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isReadHintOrNone(getL2HintAttr())) - return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) - return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return emitOpError("invalid l3_hint: ") << getL3HintAttr(); auto array_len = tdescTy.getArrayLength(); auto tdescShape = getShapeOf(tdescTy); @@ -280,8 +313,9 @@ LogicalResult LoadNdOp::verify() { auto it = tdescShape.begin(); tdescShape.insert(it, array_len); } + auto sgMap = 
tdescTy.getSGMapAttr(); - if (tdescShape != valueShape) + if (!isArgShapesValid(tdescShape, valueShape, sgMap)) return emitOpError() << "Result shape doesn't match TensorDesc shape." << "The expected shape is " << makeString(tdescShape) << ". But the given shape is " @@ -303,17 +337,26 @@ LogicalResult StoreNdOp::verify() { return emitOpError("Expects a non-scattered TensorDesc.\n"); if (!valTy) - return emitOpError("Exepcting a VectorType result.\n"); + return emitOpError("Expecting a VectorType result.\n"); if (!isWriteHintOrNone(getL1HintAttr())) - return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isWriteHintOrNone(getL2HintAttr())) - return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isWriteHintOrNone(getL3HintAttr())) - return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return emitOpError("invalid l3_hint: ") << getL3HintAttr(); + + auto tdescShape = getShapeOf(dstTy); + auto valueShape = getShapeOf(valTy); + auto sgMap = dstTy.getSGMapAttr(); + if (!isArgShapesValid(tdescShape, valueShape, sgMap)) + return emitOpError() << "Result shape doesn't match TensorDesc shape." + << "The expected shape is " << makeString(tdescShape) + << ". 
But the given shape is " + << makeString(valueShape) << ".\n"; return success(); } @@ -423,13 +466,13 @@ LogicalResult PrefetchOp::verify() { return emitOpError("Expects a scattered TensorDesc.\n"); if (!isReadHintOrNone(getL1HintAttr())) - return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isReadHintOrNone(getL2HintAttr())) - return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) - return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return emitOpError("invalid l3_hint: ") << getL3HintAttr(); return success(); } @@ -446,13 +489,13 @@ LogicalResult LoadGatherOp::verify() { return emitOpError("Expects a scattered TensorDesc.\n"); if (!isReadHintOrNone(getL1HintAttr())) - return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isReadHintOrNone(getL2HintAttr())) - return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) - return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return emitOpError("invalid l3_hint: ") << getL3HintAttr(); auto tdescElemTy = tdescTy.getElementType(); auto valueElemTy = getElementType(); @@ -490,13 +533,13 @@ LogicalResult StoreScatterOp::verify() { return emitOpError("Expects a scattered TensorDesc.\n"); if (!isWriteHintOrNone(getL1HintAttr())) - return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + return emitOpError("invalid l1_hint: ") << getL1HintAttr(); if (!isWriteHintOrNone(getL2HintAttr())) - return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isWriteHintOrNone(getL3HintAttr())) - return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return emitOpError("invalid l3_hint: ") << 
getL3HintAttr(); auto maskTy = getMaskType(); auto valueTy = getValueType(); diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir index a4587faa3345c..d7174a489888a 100644 --- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir @@ -86,6 +86,17 @@ gpu.func @test_load_nd_vc_2(%src: memref<8x16xf16>) { gpu.return } +// load_nd args may have different shapes, validated against sg_map +// CHECK: func @test_load_nd_vc_3(%[[arg0:.*]]: memref<24x32xf32>) { +gpu.func @test_load_nd_vc_3(%src: memref<24x32xf32>) { + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x1xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x1xf32> + gpu.return +} + // CHECK: func @test_store_nd_vc(%[[arg0:.*]]: memref<24x32xf16>) { gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) { // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16> @@ -108,6 +119,19 @@ gpu.func @test_store_nd_vc_2(%dst: memref<24x32xf16>) { gpu.return } +// store_nd args may have different shapes, validated against sg_map +// CHECK: func @test_store_nd_vc_3(%[[arg0:.*]]: memref<24x32xf16>) { +gpu.func @test_store_nd_vc_3(%src: memref<24x32xf16>) { + // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x2xf16> + %1 = arith.constant dense<1.0>: vector<24x2xf16> + // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map> + %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> + !xegpu.tensor_desc<24x32xf16, 
#xegpu.sg_map> + // CHECK: xegpu.store_nd %[[C]], %[[R0]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map> + xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: vector<24x2xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.sg_map> + gpu.return +} + // CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) { gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) { // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir index f8a0d95bd70a2..7816bff0582f8 100644 --- a/mlir/test/Dialect/XeGPU/invalid.mlir +++ b/mlir/test/Dialect/XeGPU/invalid.mlir @@ -32,7 +32,7 @@ func.func @test_create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) { // ----- func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16> - // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint}} + // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint}} xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<8x16xf16> return } @@ -51,7 +51,7 @@ func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) { // ----- func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) { %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> - // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint}} + // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint}} %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> return @@ -77,11 +77,29 @@ func.func @test_load_nd_vc_3(%src: memref<8x16xf16>) { return } +// ----- +func.func @test_load_nd_vc_4(%src: memref<24x32xf32>) { + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + 
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x2xf32> + return +} + +// ----- +func.func @test_load_nd_vc_5(%src: memref<24x32xf32>) { + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> + !xegpu.tensor_desc<16xf32, #xegpu.sg_map> + // expected-error@+1 {{Result shape doesn't match TensorDesc shape.}} + %2 = xegpu.load_nd %1: !xegpu.tensor_desc<16xf32, #xegpu.sg_map> -> vector<16xf32> + return +} + // ----- func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) { %1 = arith.constant dense<1.0>: vector<24x32xf16> %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16> - // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint}} + // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint}} xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16> return } @@ -147,7 +165,7 @@ func.func @test_prefetch_vc_2(%src: ui64) { %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex> %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint}} + // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint}} xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> return } @@ -168,7 +186,7 @@ func.func @test_load_gather_vc_2(%src: ui64) { %0 = arith.constant dense<1>: vector<4xi1> %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint}} + // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint}} %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<4x2xf32, 
#xegpu.scatter_tdesc_attr>, vector<4xi1> -> vector<4x2xf32> @@ -193,7 +211,7 @@ func.func @test_store_scatter_vc_2(%src: ui64) { %1 = arith.constant dense<2.9>: vector<4x2xf32> %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr> - // expected-error@+1 {{invlid l1_hint: #xegpu.cache_hint}} + // expected-error@+1 {{invalid l1_hint: #xegpu.cache_hint}} xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint}> : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr>, vector<4xi1> return