8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,9 +9,9 @@
#ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H

#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/IR/Operation.h"

#include <functional>
#include <optional>
@@ -47,9 +47,11 @@ struct UnrollOptions {

/// Function that converts a ShapedType (TensorDescType or VectorType)
/// into the unrolled type based on the tileShape. It returns a vector of
/// types representing the unrolled types for simplicity.
/// types representing the unrolled types for simplicity. When
/// `returnSingleType` is true, it returns a vector containing only a single
/// unrolled type.
using UnrolledTypeFnType = std::function<SmallVector<Type>(
ShapedType type, ArrayRef<int64_t> tileShape)>;
ShapedType type, ArrayRef<int64_t> tileShape, bool returnSingleType)>;
UnrolledTypeFnType getUnrolledTypes = nullptr;
UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
getUnrolledTypes = std::move(fn);
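As a reference for downstream users, here is a minimal sketch of a callback that honors the new `returnSingleType` flag. It is not part of this patch; it simply mirrors the default behavior installed by the blocking pass below, and assumes `computeShapeRatio`/`computeProduct` from mlir/Dialect/Utils/IndexingUtils.h are available:

// Sketch only: an UnrolledTypeFnType that returns either one representative
// unrolled type or one type per tile covering the original shape.
xegpu::UnrollOptions options;
options.setUnrolledTypesFn(
    [](ShapedType type, ArrayRef<int64_t> tileShape,
       bool returnSingleType) -> SmallVector<Type> {
      // Clone the shaped type down to the tile shape.
      Type newTy = type.clone(tileShape, type.getElementType());
      // Callers that materialize offsets themselves only need one type.
      if (returnSingleType)
        return SmallVector<Type>{newTy};
      // Otherwise return one copy per tile covering the original shape.
      std::optional<SmallVector<int64_t>> ratio =
          computeShapeRatio(type.getShape(), tileShape);
      assert(ratio && "shape must be a multiple of tileShape");
      return SmallVector<Type>(computeProduct(*ratio), newTy);
    });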
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -319,7 +319,8 @@ void XeGPUBlockingPass::runOnOperation() {

options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });

options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape,
bool returnSingleType = false) {
Type elemTy = type.getElementType();
Type newTy;

@@ -352,6 +353,8 @@
newTy = type.clone(tileShape, elemTy);
}

if (returnSingleType)
return SmallVector<Type>{newTy};
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
assert(ratio && "The shape of the type must be a multiple of tileShape.");
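A standalone worked instance of the ratio logic above (illustrative only, not part of the patch): a 24x32 type blocked by an 8x16 tile unrolls into 3 x 2 = 6 pieces, so the callback returns six copies of the 8x16 type, while with `returnSingleType` set it returns exactly one, which the offset-based unroll patterns below rely on.

// computeShapeRatio divides dimension-wise: {24/8, 32/16} == {3, 2}.
SmallVector<int64_t> shape = {24, 32}, tile = {8, 16};
std::optional<SmallVector<int64_t>> ratio = computeShapeRatio(shape, tile);
assert(ratio && computeProduct(*ratio) == 6 && "expect 6 unrolled tiles");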
183 changes: 124 additions & 59 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -56,8 +56,9 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
}

SmallVector<Type> getUnrolledTypes(ShapedType type,
ArrayRef<int64_t> tileShape) const {
return options.getUnrolledTypes(type, tileShape);
ArrayRef<int64_t> tileShape,
bool returnSingleType = false) const {
return options.getUnrolledTypes(type, tileShape, returnSingleType);
}

/// Emulate the unpack behavior using insert_strided_slice for VectorType
@@ -121,53 +122,79 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
xegpu::UnrollOptions options;
};

// Generic helper function for unrolling operations with offsets.
//
// Iterates over tile offsets within the tensor descriptor shape and calls
// the provided createOp function for each computed offset. This is used by
// operations like LoadNd, StoreNd, CreateNdDesc, and PrefetchNd when they
// have explicit offsets that need to be adjusted for each unrolled tile.
SmallVector<Value> computeUnrolledOffsets(
SmallVector<OpFoldResult> mixedOffsets, xegpu::TensorDescType tdescTy,
ArrayRef<int64_t> targetShape,
const std::function<Value(SmallVector<OpFoldResult>)> &createOp,
Location loc, PatternRewriter &rewriter) {
int64_t rank = tdescTy.getRank();
ArrayRef<int64_t> shape = tdescTy.getShape();
Review comment from the PR author on lines +136 to +137: the logic of this function was copied from UnrollCreateNdOp; it is now shared by UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp, and UnrollPrefetchNdOp.


auto addi = [&](OpFoldResult a, int64_t b) -> Value {
std::optional<int64_t> maybeInt = getConstantIntValue(a);
if (maybeInt) {
return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
} else {
auto aV = llvm::cast<Value>(a);
auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
}
};

SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
auto validIdxes =
llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

SmallVector<Value> newOps;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(shape, targetShape)) {

for (auto [idx, oldOff, offset] :
llvm::zip(validIdxes, oldOffsets, offsets))
mixedOffsets[idx] = addi(oldOff, offset);

auto newOp = createOp(mixedOffsets);
newOps.push_back(newOp);
}
return newOps;
}
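
To make the effect of this helper concrete, here is a dependency-free sketch of the arithmetic it performs for one case from the new test file: a 24x32 descriptor blocked into 8x16 tiles with base offsets [8, 16]. The loop order assumes the row-major traversal of StaticTileOffsetRange; none of the names below come from MLIR.

#include <cstdint>
#include <cstdio>

int main() {
  // Descriptor shape, tile shape, and the op's original (base) offsets.
  const int64_t shape[2] = {24, 32}, tile[2] = {8, 16}, base[2] = {8, 16};
  // Visit every tile origin and add it to the base offsets, as
  // computeUnrolledOffsets does through its addi lambda.
  for (int64_t r = 0; r < shape[0]; r += tile[0])
    for (int64_t c = 0; c < shape[1]; c += tile[1])
      std::printf("[%lld, %lld] ", static_cast<long long>(base[0] + r),
                  static_cast<long long>(base[1] + c));
  // Prints: [8, 16] [8, 32] [16, 16] [16, 32] [24, 16] [24, 32]
  // i.e. the six per-tile offsets handed to the createOp callback.
  return 0;
}

These six offsets correspond to the CHECK-COUNT-6 patterns in the new test file, which expect six unrolled ops per 24x32 access.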

struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
using UnrollPattern<xegpu::CreateNdDescOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getType();
int64_t rank = tdescTy.getRank();
ArrayRef<int64_t> shape = tdescTy.getShape();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];

auto addi = [&](OpFoldResult a, int64_t b) -> Value {
std::optional<int64_t> maybeInt = getConstantIntValue(a);
if (maybeInt) {
return arith::ConstantIndexOp::create(rewriter, loc, *maybeInt + b);
} else {
auto aV = llvm::cast<Value>(a);
auto bV = arith::ConstantIndexOp::create(rewriter, loc, b);
return rewriter.createOrFold<arith::AddIOp>(loc, aV, bV);
}
};

SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();

// For n-D memrefs where n > rank, we need to handle the last `rank`
// dimensions only, and keep the first `n-rank` dimensions as is.
SmallVector<OpFoldResult> oldOffsets = llvm::to_vector(
llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank));
auto validIdxes =
llvm::seq<int64_t>(mixedOffsets.size() - rank, mixedOffsets.size());

SmallVector<Value> newOps;
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(shape, *targetShape)) {

for (auto [idx, oldOff, offset] :
llvm::zip(validIdxes, oldOffsets, offsets))
mixedOffsets[idx] = addi(oldOff, offset);

auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
bool hasOffsets = op.getMixedOffsets().size() != 0;
if (!hasOffsets) {
auto newOp = xegpu::CreateNdDescOp::create(
rewriter, loc, newTdescTy, op.getSource(), mixedOffsets,
op.getMixedSizes(), op.getMixedStrides());
rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
op.getMixedStrides());
newOps.push_back(newOp);
} else {
auto createOp = [&](SmallVector<OpFoldResult> offsets) -> Value {
return xegpu::CreateNdDescOp::create(
rewriter, loc, newTdescTy, op.getSource(), offsets,
op.getMixedSizes(), op.getMixedStrides());
};

newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
*targetShape, createOp, loc, rewriter);
}
Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
@@ -216,17 +243,30 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
return failure();

int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
if ((offsetSize != 0) || op.getConstOffsetsAttr())
return failure();
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Value> convertedTdesc = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

for (auto t : convertedTdesc)
xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
op->getAttrs());
if (!hasOffsets) {
for (auto t : convertedTdesc)
xegpu::PrefetchNdOp::create(rewriter, loc, TypeRange(), t,
op->getAttrs());
} else {
auto createPrefetch = [&](SmallVector<OpFoldResult> offsets) -> Value {
xegpu::PrefetchNdOp::create(rewriter, loc, convertedTdesc[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
// Return a dummy Value to satisfy the callback's signature.
return nullptr;
};

computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
createPrefetch, loc, rewriter);
}

rewriter.eraseOp(op);
return success();
@@ -247,22 +287,33 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
return failure();

int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
if ((offsetSize != 0) || op.getConstOffsetsAttr())
return failure();
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

Type elemTy = tdescTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

SmallVector<Value> newOps;
for (auto t : convertedTdescs) {
auto newOp =
xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t, op->getAttrs());
newOps.push_back(newOp);

if (!hasOffsets) {
for (auto t : convertedTdescs) {
auto newOp = xegpu::LoadNdOp::create(rewriter, loc, newValueTy, t,
op->getAttrs());
newOps.push_back(newOp);
}
} else {
auto createLoad = [&](SmallVector<OpFoldResult> offsets) {
return xegpu::LoadNdOp::create(
rewriter, loc, newValueTy, convertedTdescs[0], offsets,
op.getPackedAttr(), op.getTransposeAttr(), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
};
newOps = computeUnrolledOffsets(op.getMixedOffsets(), tdescTy,
*targetShape, createLoad, loc, rewriter);
}

Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -285,22 +336,36 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
return failure();

int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
if ((offsetSize != 0) || op.getConstOffsetsAttr())
return failure();
bool hasOffsets = (offsetSize != 0) || op.getConstOffsetsAttr();

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Type> convertedTdescTypes =
getUnrolledTypes(tdescTy, *targetShape);
SmallVector<Type> convertedTdescTypes = getUnrolledTypes(
tdescTy, *targetShape, /*returnSingleType*/ hasOffsets);

SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
SmallVector<Value> convertedTdescs = pack(
op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
if (!hasOffsets) {
for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
xegpu::StoreNdOp::create(rewriter, loc, v, t, op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
} else {
size_t valueIndex = 0;
auto createStore = [&](SmallVector<OpFoldResult> offsets) {
xegpu::StoreNdOp::create(rewriter, loc, convertedValues[valueIndex++],
convertedTdescs[0], offsets,
op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
// Return a dummy Value to satisfy the callback's signature.
return nullptr;
};

computeUnrolledOffsets(op.getMixedOffsets(), tdescTy, *targetShape,
createStore, loc, rewriter);
}

rewriter.eraseOp(op);
return success();
61 changes: 61 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-unroll-patterns-no-desc-offsets.mlir
@@ -0,0 +1,61 @@
// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s

gpu.module @xevm_test {

// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>
// CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
%tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
}

//-----
// CHECK-LABEL: load_nd
// CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
// CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
gpu.func @load_nd(%src: memref<256x318xf32>) -> vector<24x32xf32> {
%tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
}

//-----
// CHECK-LABEL: load_nd_store_nd
// CHECK-SAME: [[arg0:%.+]]: memref<256x318xf32>
// CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<256x318xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// CHECK-COUNT-6: xegpu.store_nd {{.*}}[{{.*}}] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
gpu.func @load_nd_store_nd(%src: memref<256x318xf32>) {
%tdesc = xegpu.create_nd_tdesc %src : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
xegpu.store_nd %ld, %tdesc[0, 0] : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return
}

//-----
// CHECK-LABEL: prefetch_nd_tdesc
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
// CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-COUNT-6: xegpu.prefetch_nd {{.*}}[{{.*}}] : !xegpu.tensor_desc<8x16xf32>
gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
xegpu.prefetch_nd %tdesc[8, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
gpu.return
}

//-----

// CHECK-LABEL: load_nd_offsets_at_both_places
// CHECK-COUNT-2: builtin.unrealized_conversion_cast
gpu.func @load_nd_offsets_at_both_places(%src: memref<256x318xf32>) -> vector<24x32xf32> {
%tdesc = xegpu.create_nd_tdesc %src[16, 8] : memref<256x318xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
%ld = xegpu.load_nd %tdesc[8, 16]: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
gpu.return %ld : vector<24x32xf32>
}
}
5 changes: 4 additions & 1 deletion mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -95,7 +95,8 @@ struct TestXeGPUUnrollingPatterns
});

options.setUnrolledTypesFn(
[&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
[&](ShapedType type, ArrayRef<int64_t> tileShape,
bool returnSingleType = false) -> SmallVector<Type> {
Type elemTy = type.getElementType();
Type newTy;

@@ -137,6 +138,8 @@ struct TestXeGPUUnrollingPatterns
newTy = type.clone(tileShape, elemTy);
}

if (returnSingleType)
return SmallVector<Type>{newTy};
std::optional<SmallVector<int64_t>> ratio =
computeShapeRatio(type.getShape(), tileShape);
assert(ratio && "Expecting the ratio to be valid.");