From f7eee8847ebe967c170361eae93429f9ee339451 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Tue, 23 Sep 2025 14:37:58 +0000 Subject: [PATCH 1/4] [mlir][XeGPU][XeGPUUnroll] Support new syntax with offsets moved to load_nd/store_nd/prefetch_nd Signed-off-by: dchigarev --- .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 177 +++++++++++++----- ...xegpu-unroll-patterns-no-desc-offsets.mlir | 61 ++++++ 2 files changed, 186 insertions(+), 52 deletions(-) create mode 100644 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 29c9fcdfebcdb..cad7436f23762 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -121,54 +121,81 @@ struct UnrollPattern : public OpRewritePattern { xegpu::UnrollOptions options; }; +// Generic helper function for unrolling operations with offsets. +// +// Iterates over tile offsets within the tensor descriptor shape and calls +// the provided createOp function for each computed offset. This is used by +// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they +// have explicit offsets that need to be adjusted for each unrolled tile. 
+SmallVector computeUnrolledOffsets( + SmallVector mixedOffsets, xegpu::TensorDescType tdescTy, + ArrayRef targetShape, + const std::function)> &createOp, + Location loc, PatternRewriter &rewriter) { + int64_t rank = tdescTy.getRank(); + ArrayRef shape = tdescTy.getShape(); + + auto addi = [&](OpFoldResult a, int64_t b) -> Value { + std::optional maybeInt = getConstantIntValue(a); + if (maybeInt) { + return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b); + } else { + auto aV = llvm::cast(a); + auto bV = arith::ConstantIndexOp::create(rewriter, loc, b); + return rewriter.createOrFold(loc, aV, bV); + } + }; + + SmallVector oldOffsets = llvm::to_vector( + llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank)); + auto validIdxes = + llvm::seq(mixedOffsets.size() - rank, mixedOffsets.size()); + + SmallVector newOps; + for (SmallVector offsets : + StaticTileOffsetRange(shape, targetShape)) { + + for (auto [idx, oldOff, offset] : + llvm::zip(validIdxes, oldOffsets, offsets)) + mixedOffsets[idx] = addi(oldOff, offset); + + auto newOp = createOp(mixedOffsets); + newOps.push_back(newOp); + } + return newOps; +} + struct UnrollCreateNdOp : public UnrollPattern { using UnrollPattern::UnrollPattern; LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getType(); - int64_t rank = tdescTy.getRank(); - ArrayRef shape = tdescTy.getShape(); std::optional> targetShape = getTargetShape(op); if (!targetShape) return failure(); - auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; - - auto addi = [&](OpFoldResult a, int64_t b) -> Value { - std::optional maybeInt = getConstantIntValue(a); - if (maybeInt) { - return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b); - } else { - auto aV = llvm::cast(a); - auto bV = arith::ConstantIndexOp::create(rewriter, loc, b); - return rewriter.createOrFold(loc, aV, bV); - } - }; - - SmallVector 
mixedOffsets = op.getMixedOffsets(); - - // For n-D memrefs where n > rank, we need to handle the last `rank` - // dimensions only, and keep the first `n-rank` dimensions as is. - SmallVector oldOffsets = llvm::to_vector( - llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank)); - auto validIdxes = - llvm::seq(mixedOffsets.size() - rank, mixedOffsets.size()); - SmallVector newOps; - for (SmallVector offsets : - StaticTileOffsetRange(shape, *targetShape)) { - - for (auto [idx, oldOff, offset] : - llvm::zip(validIdxes, oldOffsets, offsets)) - mixedOffsets[idx] = addi(oldOff, offset); + auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; + bool hasOffsets = op.getMixedOffsets().size() != 0; + if (!hasOffsets) { auto newOp = xegpu::CreateNdDescOp::create( - rewriter, loc, newTdescTy, op.getSource(), mixedOffsets, - op.getMixedSizes(), op.getMixedStrides()); + rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(), + op.getMixedStrides()); newOps.push_back(newOp); + } else { + auto createOp = [&](SmallVector offsets) -> Value { + return xegpu::CreateNdDescOp::create( + rewriter, loc, newTdescTy, op.getSource(), offsets, + op.getMixedSizes(), op.getMixedStrides()); + }; + + newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, + *targetShape, createOp, loc, rewriter); } + Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter); rewriter.replaceOp(op, castOp); @@ -216,17 +243,33 @@ struct UnrollPrefetchNdOp : public UnrollPattern { return failure(); int64_t offsetSize = static_cast(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) - return failure(); + bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); + + if (hasOffsets) + convertedTdescTypes.resize(1); + SmallVector convertedTdesc = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); - for (auto t : convertedTdesc) - 
xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t, - op->getAttrs()); + if (!hasOffsets) { + for (auto t : convertedTdesc) + xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t, + op->getAttrs()); + } else { + auto createPrefetch = [&](SmallVector offsets) -> Value { + xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets, + op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + // return dummy Value to satisfy function's signature + return nullptr; + }; + + computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, + createPrefetch, loc, rewriter); + } rewriter.eraseOp(op); return success(); @@ -247,26 +290,39 @@ struct UnrollLoadNdOp : public UnrollPattern { return failure(); int64_t offsetSize = static_cast(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) - return failure(); + bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); Type elemTy = tdescTy.getElementType(); VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); + + if (hasOffsets) + convertedTdescTypes.resize(1); + SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); - SmallVector newOps; - for (auto t : convertedTdescs) { - auto newOp = - xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs()); - newOps.push_back(newOp); + + if (!hasOffsets) { + for (auto t : convertedTdescs) { + auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, + op->getAttrs()); + newOps.push_back(newOp); + } + } else { + auto createLoad = [&](SmallVector offsets) { + return xegpu::LoadNdOp::create( + rewriter, loc, newValueTy, convertedTdescs[0], offsets, + op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + }; + newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, + *targetShape, createLoad, loc, rewriter); } 
Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); - rewriter.replaceOp(op, castOp); return success(); } @@ -285,22 +341,39 @@ struct UnrollStoreNdOp : public UnrollPattern { return failure(); int64_t offsetSize = static_cast(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) - return failure(); + bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); SmallVector convertedValTypes = getUnrolledTypes(valueTy, *targetShape); SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); - SmallVector convertedValues = - pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); + if (hasOffsets) + convertedTdescTypes.resize(1); + SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); - for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) - xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + SmallVector convertedValues = + pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); + if (!hasOffsets) { + for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) + xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + } else { + size_t valueIndex = 0; + auto createStore = [&](SmallVector offsets) { + xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++], + convertedTdescs[0], offsets, + op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + // return dummy Value to satisfy function's signature + return nullptr; + }; + + computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, + createStore, loc, rewriter); + } rewriter.eraseOp(op); return success(); diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir new file mode 100644 index 0000000000000..f28e82a2a4c76 --- 
/dev/null +++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s + +gpu.module @xevm_test { + + // CHECK-LABEL: create_nd_tdesc + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout> {__xegpu_blocking_tile_shape__ = array, __xegpu_blocking_unpack__} + gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { + %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + } + +//----- + // CHECK-LABEL: load_nd + // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32> + // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32> + gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> { + %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> + gpu.return %ld : vector<24x32xf32> + } + +//----- + // CHECK-LABEL: load_nd_store_nd + // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32> + //CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : 
!xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) { + %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> + xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return + } + +//----- + // CHECK-LABEL: prefetch_nd_tdesc + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> + gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { + %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return + } + +//----- + + // CHECK-LABEL: load_nd_offsets_at_both_places + // CHECK-COUNT-2: builtin.unrealized_conversion_cast + gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> { + %tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> + gpu.return %ld : vector<24x32xf32> + } +} \ No newline at end of file From f45f04735e0db6a662154a15b28e82682b2c6d86 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Tue, 23 Sep 2025 14:50:24 +0000 Subject: [PATCH 2/4] fix formatting Signed-off-by: dchigarev --- .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 17 +++++++++++++---- .../xegpu-unroll-patterns-no-desc-offsets.mlir | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git 
a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index cad7436f23762..80d1cb12dff80 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -195,7 +195,6 @@ struct UnrollCreateNdOp : public UnrollPattern { newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape, createOp, loc, rewriter); } - Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter); rewriter.replaceOp(op, castOp); @@ -248,8 +247,11 @@ struct UnrollPrefetchNdOp : public UnrollPattern { SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); - if (hasOffsets) + if (hasOffsets) { + // only need one tdesc, tile offsets will be computed + // at the operation level convertedTdescTypes.resize(1); + } SmallVector convertedTdesc = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); @@ -298,8 +300,11 @@ struct UnrollLoadNdOp : public UnrollPattern { SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); - if (hasOffsets) + if (hasOffsets) { + // only need one tdesc, tile offsets will be computed + // at the operation level convertedTdescTypes.resize(1); + } SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); @@ -323,6 +328,7 @@ struct UnrollLoadNdOp : public UnrollPattern { } Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); + rewriter.replaceOp(op, castOp); return success(); } @@ -348,8 +354,11 @@ struct UnrollStoreNdOp : public UnrollPattern { SmallVector convertedTdescTypes = getUnrolledTypes(tdescTy, *targetShape); - if (hasOffsets) + if (hasOffsets) { + // only need one tdesc, tile offsets will be computed + // at the operation level convertedTdescTypes.resize(1); + } SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); diff --git 
a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir index f28e82a2a4c76..cbfd991b5557e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir @@ -58,4 +58,4 @@ gpu.module @xevm_test { %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> gpu.return %ld : vector<24x32xf32> } -} \ No newline at end of file +} From 932346e77594318552ae60d64473dda041192888 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Wed, 24 Sep 2025 10:09:36 +0000 Subject: [PATCH 3/4] Modify 'unrolledTypesFn' to return one single type Signed-off-by: dchigarev --- .../Dialect/XeGPU/Transforms/Transforms.h | 8 +++-- .../XeGPU/Transforms/XeGPUBlocking.cpp | 5 ++- .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 35 +++++-------------- .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 5 ++- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h index 44b81796b1313..b74c15e5b7ac1 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h @@ -9,9 +9,9 @@ #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H +#include "mlir/IR/Operation.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/LogicalResult.h" -#include "mlir/IR/Operation.h" #include #include @@ -47,9 +47,11 @@ struct UnrollOptions { /// Function that converts a ShapedType (TensorDescType or VectorType) /// into the unrolled type based on the tileShape. It returns a vector of - /// types representing the unrolled types for simplicity.
When + /// `returnSingleType` is true, it returns a vector containing only one single + /// unrolled type. using UnrolledTypeFnType = std::function( - ShapedType type, ArrayRef tileShape)>; + ShapedType type, ArrayRef tileShape, bool returnSingleType)>; UnrolledTypeFnType getUnrolledTypes = nullptr; UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) { getUnrolledTypes = std::move(fn); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 7efa4b9fbd934..36c498e8b849d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -319,7 +319,8 @@ void XeGPUBlockingPass::runOnOperation() { options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); }); - options.setUnrolledTypesFn([&](ShapedType type, ArrayRef tileShape) { + options.setUnrolledTypesFn([&](ShapedType type, ArrayRef tileShape, + bool returnSingleType = false) { Type elemTy = type.getElementType(); Type newTy; @@ -352,6 +353,8 @@ void XeGPUBlockingPass::runOnOperation() { newTy = type.clone(tileShape, elemTy); } + if (returnSingleType) + return SmallVector{newTy}; std::optional> ratio = computeShapeRatio(type.getShape(), tileShape); assert(ratio && "The shape of the type must be a multiple of tileShape."); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index 80d1cb12dff80..f738effe46a72 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern { } SmallVector getUnrolledTypes(ShapedType type, - ArrayRef tileShape) const { - return options.getUnrolledTypes(type, tileShape); + ArrayRef tileShape, + bool returnSingleType = false) const { + return options.getUnrolledTypes(type, tileShape, returnSingleType); } /// Emulate the the unpack behavior using 
insert_strided_slice for VectorType @@ -244,14 +245,8 @@ struct UnrollPrefetchNdOp : public UnrollPattern { int64_t offsetSize = static_cast(op.getOffsets().size()); bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr(); - SmallVector convertedTdescTypes = - getUnrolledTypes(tdescTy, *targetShape); - - if (hasOffsets) { - // only need one tdesc, tile offsets will be computed - // at the operation level - convertedTdescTypes.resize(1); - } + SmallVector convertedTdescTypes = getUnrolledTypes( + tdescTy, *targetShape, /*returnSingleType*/ hasOffsets); SmallVector convertedTdesc = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); @@ -297,14 +292,8 @@ struct UnrollLoadNdOp : public UnrollPattern { Type elemTy = tdescTy.getElementType(); VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); - SmallVector convertedTdescTypes = - getUnrolledTypes(tdescTy, *targetShape); - - if (hasOffsets) { - // only need one tdesc, tile offsets will be computed - // at the operation level - convertedTdescTypes.resize(1); - } + SmallVector convertedTdescTypes = getUnrolledTypes( + tdescTy, *targetShape, /*returnSingleType*/ hasOffsets); SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); @@ -351,14 +340,8 @@ struct UnrollStoreNdOp : public UnrollPattern { SmallVector convertedValTypes = getUnrolledTypes(valueTy, *targetShape); - SmallVector convertedTdescTypes = - getUnrolledTypes(tdescTy, *targetShape); - - if (hasOffsets) { - // only need one tdesc, tile offsets will be computed - // at the operation level - convertedTdescTypes.resize(1); - } + SmallVector convertedTdescTypes = getUnrolledTypes( + tdescTy, *targetShape, /*returnSingleType*/ hasOffsets); SmallVector convertedTdescs = pack( op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 
e1ba45c60ac36..b2bdf3efc65f7 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -95,7 +95,8 @@ struct TestXeGPUUnrollingPatterns }); options.setUnrolledTypesFn( - [&](ShapedType type, ArrayRef tileShape) -> SmallVector { + [&](ShapedType type, ArrayRef tileShape, + bool returnSingleType = false) -> SmallVector { Type elemTy = type.getElementType(); Type newTy; @@ -137,6 +138,8 @@ struct TestXeGPUUnrollingPatterns newTy = type.clone(tileShape, elemTy); } + if (returnSingleType) + return SmallVector{newTy}; std::optional> ratio = computeShapeRatio(type.getShape(), tileShape); assert(ratio && "Expecting the ratio to be valid."); From 04959fec89a11c7e4115664cb26334a3917adf17 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Thu, 25 Sep 2025 16:15:52 +0000 Subject: [PATCH 4/4] fix offsets in lit tests Signed-off-by: dchigarev --- .../XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir index cbfd991b5557e..6eee5a544e3f8 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir @@ -4,7 +4,7 @@ gpu.module @xevm_test { // CHECK-LABEL: create_nd_tdesc // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> - // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast // CHECK-SAME: !xegpu.tensor_desc<8x16xf32> // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout> {__xegpu_blocking_tile_shape__ = array, __xegpu_blocking_unpack__} @@ -16,7 +16,7 @@ gpu.module @xevm_test { //----- 
// CHECK-LABEL: load_nd // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32> - // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32> // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32> gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> { @@ -28,9 +28,9 @@ gpu.module @xevm_test { //----- // CHECK-LABEL: load_nd_store_nd // CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32> - //CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32> - //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> - //CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + // CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) { %tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> %ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> @@ -41,7 +41,7 @@ gpu.module @xevm_test { //----- // CHECK-LABEL: prefetch_nd_tdesc // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> - // CHECK-COUNT-1: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> 
!xegpu.tensor_desc<8x16xf32> // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) { %tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout>