[mlir][tosa] Change Transpose perms operand to attribute #128115

Merged 1 commit on Feb 25, 2025
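This PR changes the tosa.transpose permutation from a constant Tosa_Int32Tensor operand to an inherent DenseI32ArrayAttr named perms, so producers no longer materialize a separate constant op for the permutation. A minimal before/after sketch of the builder-side change, assuming a PatternRewriter rewriter, a Location loc, a result type newTy, and an input Value input are in scope (illustrative names, not taken from the diff):

// Before: perms was a constant i32 tensor operand.
auto permAttr = rewriter.getI32TensorAttr({1, 0, 2});
Value permValue = rewriter.create<arith::ConstantOp>(loc, permAttr);
Value oldStyle =
    rewriter.create<tosa::TransposeOp>(loc, newTy, input, permValue);

// After: perms is an attribute carried by the op itself.
Value newStyle = rewriter.create<tosa::TransposeOp>(
    loc, newTy, input, rewriter.getDenseI32ArrayAttr({1, 0, 2}));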
6 changes: 1 addition & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2023,7 +2023,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",

let arguments = (ins
Tosa_Tensor:$input1,
Tosa_Int32Tensor:$perms
DenseI32ArrayAttr:$perms
);

let results = (
@@ -2035,10 +2035,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];

let extraClassDeclaration = [{
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
}];

let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
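On the read side, the generated accessor getPerms() now returns the permutation directly as an llvm::ArrayRef<int32_t>, which is why the getConstantPerms helper and its failure paths disappear throughout the rest of this PR. A rough sketch of the old and new access patterns inside a rewrite pattern, assuming a tosa::TransposeOp op is in scope:

// Old (removed): the perms operand might not have been constant.
// SmallVector<int32_t> perms;
// if (failed(op.getConstantPerms(perms)))
//   return rewriter.notifyMatchFailure(op, "perms must be constant");

// New: the permutation is always statically known.
const llvm::ArrayRef<int32_t> perms = op.getPerms();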
16 changes: 5 additions & 11 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
weightPermAttr);
}
}

@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
weightPermAttr);
}

// Extract the attributes for convolution.
@@ -970,9 +966,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {

LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
SmallVector<int32_t> constantPerms;
if (failed(op.getConstantPerms(constantPerms)))
return failure();
const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();

Location loc = op.getLoc();
// The verifier should have made sure we have a valid TOSA permutation
31 changes: 7 additions & 24 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");

SmallVector<int32_t> transposePerms, innerTransposePerms;
if (transposeOp.getConstantPerms(transposePerms).failed())
return rewriter.notifyMatchFailure(transposeOp,
"transpose perms must be constant");
if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
return rewriter.notifyMatchFailure(
transposeOp, "inner transpose perms must be constant");
const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
const llvm::ArrayRef<int32_t> innerTransposePerms =
innerTranspose.getPerms();

if (transposePerms.size() != innerTransposePerms.size())
return rewriter.notifyMatchFailure(
transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
for (int i = 0, s = transposePerms.size(); i < s; ++i)
perms[i] = innerTransposePerms[transposePerms[i]];

auto permsTy =
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
permsTy, permsAttr);

rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
innerTranspose.getInput1(), permsValue);
innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));

return success();
}
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {

LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
DenseIntElementsAttr permAttr;
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return rewriter.notifyMatchFailure(op, "Non-constant permutation");

if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
return rewriter.notifyMatchFailure(
op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");

SmallVector<int64_t> permValues = llvm::to_vector<6>(
llvm::map_range(permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); }));
const llvm::ArrayRef<int32_t> permValues = op.getPerms();

SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
@@ -1175,9 +1160,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
}

// Transpose is not the identity transpose.
SmallVector<int32_t> perms;
if (getConstantPerms(perms).failed())
return {};
const llvm::ArrayRef<int32_t> perms = getPerms();

if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
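The ConsolidateTransposeOptimization pattern above fuses two back-to-back transposes by indexing the inner permutation with the outer one. A small self-contained sketch of that composition rule, using illustrative values that are not taken from the tests:

#include <array>
#include <cassert>

// For z = transpose(transpose(x, inner), outer), output dim i of z reads
// input dim inner[outer[i]] of x, so the fused permutation is
//   composed[i] = inner[outer[i]].
int main() {
  std::array<int, 3> inner = {2, 0, 1};
  std::array<int, 3> outer = {1, 0, 2};
  std::array<int, 3> composed;
  for (int i = 0; i < 3; ++i)
    composed[i] = inner[outer[i]];
  assert((composed == std::array<int, 3>{0, 2, 1}));
  return 0;
}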
157 changes: 56 additions & 101 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1374,54 +1374,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}

LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
// Perms must be constants.
DenseIntElementsAttr permsAttr;
if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
return failure();

perms.clear();
for (auto v : permsAttr.getValues<APInt>())
perms.push_back(v.getSExtValue());

return success();
}

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
ShapeAdaptor permsShape(adaptor.getPerms().getType());

// We cannot infer anything from a rank-0 "permutation" tensor.
if (permsShape.hasRank() && permsShape.getRank() == 0)
return failure();

// If input rank and permutation length is unknown, the output rank is
// unknown.
if (!inputShape.hasRank() || !permsShape.hasRank() ||
permsShape.isDynamicDim(0)) {
if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
return success();
}

const auto inputRank = inputShape.getRank();

// This would imply the number of permutations does not match the rank of
// the input which is illegal.
if (permsShape.getDimSize(0) != inputShape.getRank()) {
if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
return failure();
}

SmallVector<int64_t> outputShape;
// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
if (inputRank == 0) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}

// Check whether the input dimensions are all the same.
bool allTheSame = true;
for (int i = 1, s = inputShape.getRank(); i < s; i++) {
for (int i = 1, s = inputRank; i < s; i++) {
if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
allTheSame = false;
break;
@@ -1431,34 +1414,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
// If all of the input dimensions are the same we don't care about the
// permutation.
if (allTheSame) {
outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
outputShape.resize(inputRank, inputShape.getDimSize(0));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}

outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
// If the permuations are a constant we can directly determine the output
// shape.
DenseIntElementsAttr attr;
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
attr.getType().getRank() == 1) {
ShapeAdaptor permShape = attr;
// Constant permutation must be the same length as the input rank.
if (inputShape.getRank() != permShape.getRank())
return emitOptionalError(location,
"constant permutation must be the same length"
" as the input rank");

// Constant permutation values must be within the input rank.
for (int i = 0, e = inputShape.getRank(); i < e; i++) {
if (inputShape.getRank() <= permShape.getDimSize(i))
return failure();
}
outputShape.resize(inputRank, ShapedType::kDynamic);

outputShape.reserve(inputShape.getRank());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
}
// Constant permutation values must be within the input rank.
if (llvm::any_of(adaptor.getPerms(),
[inputRank](const auto i) { return i >= inputRank; }))
return failure();

outputShape.reserve(inputRank);
for (int i = 0, s = inputRank; i < s; i++) {
outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
}

inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1467,75 +1437,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(

LogicalResult tosa::TransposeOp::verify() {
TensorType inputType = getInput1().getType();
TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
const llvm::ArrayRef<int32_t> constantPerms = getPerms();

if (permType.hasRank() && permType.getRank() != 1)
return emitOpError()
<< "expected permutation tensor to be rank 1 but got rank "
<< permType.getRank();
if (inputType.hasRank() && permType.hasRank())
if (!permType.isDynamicDim(0) &&
permType.getDimSize(0) != inputType.getRank())
return emitOpError() << "expected permutation tensor dim 0 to have size "
<< inputType.getRank()
<< " (input rank) but got size "
<< permType.getDimSize(0);
if (inputType.hasRank() &&
constantPerms.size() != static_cast<size_t>(inputType.getRank()))
return emitOpError() << "expected perms attribute to have size "
<< inputType.getRank() << " (input rank) but got size "
<< constantPerms.size();
if (inputType.hasRank() && outputType.hasRank() &&
inputType.getRank() != outputType.getRank())
return emitOpError()
<< "expected input tensor rank to equal result tensor rank";
if (outputType.hasRank() && permType.hasRank())
if (!permType.isDynamicDim(0) &&
permType.getDimSize(0) != outputType.getRank())
return emitOpError() << "expected permutation tensor dim 0 to have size "
<< outputType.getRank()
<< " (output rank) but got size "
<< permType.getDimSize(0);

SmallVector<int32_t> constantPerms;
if (succeeded(getConstantPerms(constantPerms))) {
// Assert that the permutation tensor has a rank, which means that the
// rank has been verified above.
assert(permType.hasRank() &&
"Unexpectedly found permutation tensor without rank");
if (!llvm::all_of(constantPerms,
[&constantPerms](int32_t s) {
return s >= 0 &&
static_cast<size_t>(s) < constantPerms.size();
}) ||
!isPermutationVector(llvm::to_vector(llvm::map_range(
constantPerms, [](int32_t v) -> int64_t { return v; }))))
return emitOpError() << "expected valid permutation tensor";

// Verify that the types of the input and output tensors are properly
// permuted.
if (inputType.hasRank() && outputType.hasRank()) {
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
inputType.getRank() == outputType.getRank());

for (auto i = 0; i < outputType.getRank(); i++) {
if (inputType.isDynamicDim(constantPerms[i]) ||
outputType.isDynamicDim(i))
continue;

if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
return emitOpError()
<< "expected output tensor dim " << i << " to match "
<< "input dim " << constantPerms[i] << " with value of "
<< inputType.getDimSize(constantPerms[i]);
}
if (outputType.hasRank() &&
constantPerms.size() != static_cast<size_t>(outputType.getRank()))
return emitOpError() << "expected perms attribute to have size "
<< outputType.getRank()
<< " (output rank) but got size "
<< constantPerms.size();

if (!llvm::all_of(constantPerms,
[&constantPerms](int32_t s) {
return s >= 0 &&
static_cast<size_t>(s) < constantPerms.size();
}) ||
!isPermutationVector(llvm::to_vector(llvm::map_range(
constantPerms, [](int32_t v) -> int64_t { return v; }))))
return emitOpError() << "expected valid permutation indices";

// Verify that the types of the input and output tensors are properly
// permuted.
if (inputType.hasRank() && outputType.hasRank()) {
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
inputType.getRank() == outputType.getRank());

for (auto i = 0; i < outputType.getRank(); i++) {
if (inputType.isDynamicDim(constantPerms[i]) ||
outputType.isDynamicDim(i))
continue;

if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
return emitOpError()
<< "expected output tensor dim " << i << " to match "
<< "input dim " << constantPerms[i] << " with value of "
<< inputType.getDimSize(constantPerms[i]);
}
}

return success();
}

LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {

SmallVector<int32_t> transposePerms;
if (getConstantPerms(transposePerms).failed())
return failure();
const llvm::ArrayRef<int32_t> transposePerms = getPerms();

Value input = getInput1();
auto inputType = cast<TensorType>(input.getType());
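With perms available as a plain array attribute, inferReturnTypeComponents and the verifier above compute each output dimension i as the size of input dimension perms[i], and reject indices that are out of range or do not form a valid permutation. A quick worked example of the shape rule, assuming a fully static input shape (illustrative values):

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // input: tensor<2x3x4xf32>, perms = [2, 0, 1]
  std::array<int64_t, 3> inputShape = {2, 3, 4};
  std::array<int32_t, 3> perms = {2, 0, 1};
  std::array<int64_t, 3> outputShape;
  for (int i = 0; i < 3; ++i)
    outputShape[i] = inputShape[perms[i]]; // output dim i <- input dim perms[i]
  assert((outputShape == std::array<int64_t, 3>{4, 2, 3}));
  return 0;
}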
12 changes: 2 additions & 10 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
getTosaConstShape(rewriter, loc, weightReshapeDims0));

// Transpose the factored-out stride to the output channels.
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
transposeWeightVal);
rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));

// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
convReshapeDims0Value);

// Transpose the factored-out stride to the output channels.
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
transposeConvVal);
rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));

// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
7 changes: 1 addition & 6 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
return failure();

DenseIntElementsAttr permAttr;
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
return failure();
auto permValues = llvm::map_to_vector(
// TOSA allows both 32- and 64-bit integer tensors here.
permAttr.getValues<APInt>(),
[](const APInt &val) { return val.getSExtValue(); });
op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });

auto inputType = cast<ShapedType>(op.getInput1().getType());
