diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 9d3c4366a7bd5..0c5a1ce0e96a3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -198,6 +198,22 @@ LogicalResult CreateNdDescOp::verify() { tdescMemorySpace == static_cast(MemorySpace::SLM)) return emitOpError("SLM is not supported for 2D Block TensorDesc.\n"); + if (auto attr = getType().getSGMapAttr()) { + auto wiLayout = attr.getWiLayout(); + auto wiData = attr.getWiData(); + if (wiData[0] < 1 || wiData[1] < 1 || (wiData[0] > 1 && wiData[1] > 1)) + return emitOpError() << "`wi_data` values must be >=1 and can only be >1 " + "along one dimension." + << "\n"; + auto tdescShape = getType().getShape(); + for (size_t i = 0; i < tdescShape.size(); i++) { + if (tdescShape[i] % wiLayout[i]) + return emitOpError() << "Work-items must uniformly divide a tile " + "(tdescShape[i] % wiLayout[i] == 0)" + << "\n"; + } + } + return success(); } @@ -250,6 +266,13 @@ LogicalResult LoadNdOp::verify() { auto tdescShape = getShapeOf(tdescTy); auto valueShape = getShapeOf(valueTy); + if (auto attr = getTensorDescType().getSGMapAttr()) { + auto wiLayout = attr.getWiLayout(); + for (size_t i = 0; i < tdescShape.size(); i++) { + tdescShape[i] /= wiLayout[i]; + } + } + if (getTranspose()) { auto trans = getTranspose().value(); diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir index a4587faa3345c..0f92e9cb68db6 100644 --- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir +++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir @@ -21,6 +21,33 @@ gpu.func @test_create_nd_tdesc_with_sg_map(%src: memref<24x32xf32>) { gpu.return } +// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map(%[[arg0:.*]]: memref<32x32xi8>) { +gpu.func @test_load_nd_tdesc_with_sg_map(%src: memref<32x32xi8>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xi8> -> !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> + // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, packed}> : !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> -> vector<8x1x4xi8> + %2 = xegpu.load_nd %1 <{packed, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>: !xegpu.tensor_desc<32x16xi8, #xegpu.sg_map> -> vector<8x1x4xi8> + gpu.return +} + +// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_2(%[[arg0:.*]]: memref<24x32xf32>) { +gpu.func @test_load_nd_tdesc_with_sg_map_2(%src: memref<24x32xf32>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> + // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x1xf32> + %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map> -> vector<8x1xf32> + gpu.return +} + +// CHECK: gpu.func @test_load_nd_tdesc_with_sg_map_3(%[[arg0:.*]]: memref<32x32xf32>) { +gpu.func @test_load_nd_tdesc_with_sg_map_3(%src: memref<32x32xf32>) { + // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map> + %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<32x32xf32> -> !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map> + // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[REG]] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, transpose = array}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map> -> vector<8x1xf32> + %2 = xegpu.load_nd %1 <{transpose = array, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x8xf32, #xegpu.sg_map> -> vector<8x1xf32> + gpu.return +} + // CHECK: gpu.func @test_create_nd_tdesc_vc_2(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) { gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) { //CHECK: %[[C:.*]] = arith.constant 1 : index