Skip to content

Commit 48db4e8

Browse files
authored Feb 25, 2025
[mlir][tosa] Change Transpose perms operand to attribute (#128115)
This patch changes the perms operand for the Tosa Transpose operator to an i32 array attribute.

Signed-off-by: Tai Ly <tai.ly@arm.com>
1 parent d2d469e commit 48db4e8

20 files changed

+250
-509
lines changed
 

‎mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+1-5
Original file line numberDiff line numberDiff line change
@@ -2023,7 +2023,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
20232023

20242024
let arguments = (ins
20252025
Tosa_Tensor:$input1,
2026-
Tosa_Int32Tensor:$perms
2026+
DenseI32ArrayAttr:$perms
20272027
);
20282028

20292029
let results = (
@@ -2035,10 +2035,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
20352035
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
20362036
];
20372037

2038-
let extraClassDeclaration = [{
2039-
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
2040-
}];
2041-
20422038
let hasCanonicalizer = 1;
20432039
let hasFolder = 1;
20442040
let hasVerifier = 1;

‎mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

+5-11
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
329329
SmallVector<int64_t> newWeightShape;
330330
for (auto dim : weightPerm)
331331
newWeightShape.push_back(weightShape[dim]);
332-
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
333-
Value weightPermValue =
334-
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
332+
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
335333
Type newWeightTy =
336334
RankedTensorType::get(newWeightShape, weightTy.getElementType());
337335
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
338-
weightPermValue);
336+
weightPermAttr);
339337
}
340338
}
341339

@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
353351
SmallVector<int64_t> newWeightShape;
354352
for (auto dim : weightPerm)
355353
newWeightShape.push_back(weightShape[dim]);
356-
auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
357-
Value weightPermValue =
358-
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
354+
auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
359355
Type newWeightTy =
360356
RankedTensorType::get(newWeightShape, weightTy.getElementType());
361357
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
362-
weightPermValue);
358+
weightPermAttr);
363359
}
364360

365361
// Extract the attributes for convolution.
@@ -1003,9 +999,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1003999

10041000
LogicalResult matchAndRewrite(tosa::TransposeOp op,
10051001
PatternRewriter &rewriter) const final {
1006-
SmallVector<int32_t> constantPerms;
1007-
if (failed(op.getConstantPerms(constantPerms)))
1008-
return failure();
1002+
const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
10091003

10101004
Location loc = op.getLoc();
10111005
// The verifier should have made sure we have a valid TOSA permutation

‎mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

+7-24
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
8888
return rewriter.notifyMatchFailure(transposeOp,
8989
"input must be transpose operation");
9090

91-
SmallVector<int32_t> transposePerms, innerTransposePerms;
92-
if (transposeOp.getConstantPerms(transposePerms).failed())
93-
return rewriter.notifyMatchFailure(transposeOp,
94-
"transpose perms must be constant");
95-
if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
96-
return rewriter.notifyMatchFailure(
97-
transposeOp, "inner transpose perms must be constant");
91+
const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
92+
const llvm::ArrayRef<int32_t> innerTransposePerms =
93+
innerTranspose.getPerms();
94+
9895
if (transposePerms.size() != innerTransposePerms.size())
9996
return rewriter.notifyMatchFailure(
10097
transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
108105
for (int i = 0, s = transposePerms.size(); i < s; ++i)
109106
perms[i] = innerTransposePerms[transposePerms[i]];
110107

111-
auto permsTy =
112-
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
113-
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
114-
Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
115-
permsTy, permsAttr);
116-
117108
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
118109
transposeOp, transposeOp.getResult().getType(),
119-
innerTranspose.getInput1(), permsValue);
110+
innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
120111

121112
return success();
122113
}
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
128119

129120
LogicalResult matchAndRewrite(tosa::TransposeOp op,
130121
PatternRewriter &rewriter) const override {
131-
DenseIntElementsAttr permAttr;
132-
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
133-
return rewriter.notifyMatchFailure(op, "Non-constant permutation");
134-
135122
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
136123
return rewriter.notifyMatchFailure(
137124
op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
156143
if (numDynDims > 1)
157144
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
158145

159-
SmallVector<int64_t> permValues = llvm::to_vector<6>(
160-
llvm::map_range(permAttr.getValues<APInt>(),
161-
[](const APInt &val) { return val.getSExtValue(); }));
146+
const llvm::ArrayRef<int32_t> permValues = op.getPerms();
162147

163148
SmallVector<int64_t> nonZeroPerms;
164149
nonZeroPerms.reserve(permValues.size());
@@ -1176,9 +1161,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
11761161
}
11771162

11781163
// Transpose is not the identity transpose.
1179-
SmallVector<int32_t> perms;
1180-
if (getConstantPerms(perms).failed())
1181-
return {};
1164+
const llvm::ArrayRef<int32_t> perms = getPerms();
11821165

11831166
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
11841167
return {};

‎mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+56-101
Original file line numberDiff line numberDiff line change
@@ -1372,54 +1372,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
13721372
return mlir::success();
13731373
}
13741374

1375-
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1376-
// Perms must be constants.
1377-
DenseIntElementsAttr permsAttr;
1378-
if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1379-
return failure();
1380-
1381-
perms.clear();
1382-
for (auto v : permsAttr.getValues<APInt>())
1383-
perms.push_back(v.getSExtValue());
1384-
1385-
return success();
1386-
}
1387-
13881375
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
13891376
MLIRContext *context, ::std::optional<Location> location,
13901377
TransposeOp::Adaptor adaptor,
13911378
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
13921379
ShapeAdaptor inputShape(adaptor.getInput1().getType());
1393-
ShapeAdaptor permsShape(adaptor.getPerms().getType());
1394-
1395-
// We cannot infer anything from a rank-0 "permutation" tensor.
1396-
if (permsShape.hasRank() && permsShape.getRank() == 0)
1397-
return failure();
13981380

13991381
// If input rank and permutation length is unknown, the output rank is
14001382
// unknown.
1401-
if (!inputShape.hasRank() || !permsShape.hasRank() ||
1402-
permsShape.isDynamicDim(0)) {
1383+
if (!inputShape.hasRank()) {
14031384
inferredReturnShapes.push_back(ShapedTypeComponents());
14041385
return success();
14051386
}
14061387

1388+
const auto inputRank = inputShape.getRank();
1389+
14071390
// This would imply the number of permutations does not match the rank of
14081391
// the input which is illegal.
1409-
if (permsShape.getDimSize(0) != inputShape.getRank()) {
1392+
if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
14101393
return failure();
14111394
}
14121395

14131396
SmallVector<int64_t> outputShape;
14141397
// Rank-0 means no permutations matter.
1415-
if (inputShape.getRank() == 0) {
1398+
if (inputRank == 0) {
14161399
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
14171400
return success();
14181401
}
14191402

14201403
// Check whether the input dimensions are all the same.
14211404
bool allTheSame = true;
1422-
for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1405+
for (int i = 1, s = inputRank; i < s; i++) {
14231406
if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
14241407
allTheSame = false;
14251408
break;
@@ -1429,34 +1412,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14291412
// If all of the input dimensions are the same we don't care about the
14301413
// permutation.
14311414
if (allTheSame) {
1432-
outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1415+
outputShape.resize(inputRank, inputShape.getDimSize(0));
14331416
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
14341417
return success();
14351418
}
14361419

1437-
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1438-
// If the permuations are a constant we can directly determine the output
1439-
// shape.
1440-
DenseIntElementsAttr attr;
1441-
if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1442-
attr.getType().getRank() == 1) {
1443-
ShapeAdaptor permShape = attr;
1444-
// Constant permutation must be the same length as the input rank.
1445-
if (inputShape.getRank() != permShape.getRank())
1446-
return emitOptionalError(location,
1447-
"constant permutation must be the same length"
1448-
" as the input rank");
1449-
1450-
// Constant permutation values must be within the input rank.
1451-
for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1452-
if (inputShape.getRank() <= permShape.getDimSize(i))
1453-
return failure();
1454-
}
1420+
outputShape.resize(inputRank, ShapedType::kDynamic);
14551421

1456-
outputShape.reserve(inputShape.getRank());
1457-
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1458-
outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1459-
}
1422+
// Constant permutation values must be within the input rank.
1423+
if (llvm::any_of(adaptor.getPerms(),
1424+
[inputRank](const auto i) { return i >= inputRank; }))
1425+
return failure();
1426+
1427+
outputShape.reserve(inputRank);
1428+
for (int i = 0, s = inputRank; i < s; i++) {
1429+
outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
14601430
}
14611431

14621432
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1465,75 +1435,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14651435

14661436
LogicalResult tosa::TransposeOp::verify() {
14671437
TensorType inputType = getInput1().getType();
1468-
TensorType permType = getPerms().getType();
14691438
TensorType outputType = getOutput().getType();
1439+
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
14701440

1471-
if (permType.hasRank() && permType.getRank() != 1)
1472-
return emitOpError()
1473-
<< "expected permutation tensor to be rank 1 but got rank "
1474-
<< permType.getRank();
1475-
if (inputType.hasRank() && permType.hasRank())
1476-
if (!permType.isDynamicDim(0) &&
1477-
permType.getDimSize(0) != inputType.getRank())
1478-
return emitOpError() << "expected permutation tensor dim 0 to have size "
1479-
<< inputType.getRank()
1480-
<< " (input rank) but got size "
1481-
<< permType.getDimSize(0);
1441+
if (inputType.hasRank() &&
1442+
constantPerms.size() != static_cast<size_t>(inputType.getRank()))
1443+
return emitOpError() << "expected perms attribute to have size "
1444+
<< inputType.getRank() << " (input rank) but got size "
1445+
<< constantPerms.size();
14821446
if (inputType.hasRank() && outputType.hasRank() &&
14831447
inputType.getRank() != outputType.getRank())
14841448
return emitOpError()
14851449
<< "expected input tensor rank to equal result tensor rank";
1486-
if (outputType.hasRank() && permType.hasRank())
1487-
if (!permType.isDynamicDim(0) &&
1488-
permType.getDimSize(0) != outputType.getRank())
1489-
return emitOpError() << "expected permutation tensor dim 0 to have size "
1490-
<< outputType.getRank()
1491-
<< " (output rank) but got size "
1492-
<< permType.getDimSize(0);
1493-
1494-
SmallVector<int32_t> constantPerms;
1495-
if (succeeded(getConstantPerms(constantPerms))) {
1496-
// Assert that the permutation tensor has a rank, which means that the
1497-
// rank has been verified above.
1498-
assert(permType.hasRank() &&
1499-
"Unexpectedly found permutation tensor without rank");
1500-
if (!llvm::all_of(constantPerms,
1501-
[&constantPerms](int32_t s) {
1502-
return s >= 0 &&
1503-
static_cast<size_t>(s) < constantPerms.size();
1504-
}) ||
1505-
!isPermutationVector(llvm::to_vector(llvm::map_range(
1506-
constantPerms, [](int32_t v) -> int64_t { return v; }))))
1507-
return emitOpError() << "expected valid permutation tensor";
1508-
1509-
// Verify that the types of the input and output tensors are properly
1510-
// permuted.
1511-
if (inputType.hasRank() && outputType.hasRank()) {
1512-
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1513-
inputType.getRank() == outputType.getRank());
1514-
1515-
for (auto i = 0; i < outputType.getRank(); i++) {
1516-
if (inputType.isDynamicDim(constantPerms[i]) ||
1517-
outputType.isDynamicDim(i))
1518-
continue;
1519-
1520-
if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1521-
return emitOpError()
1522-
<< "expected output tensor dim " << i << " to match "
1523-
<< "input dim " << constantPerms[i] << " with value of "
1524-
<< inputType.getDimSize(constantPerms[i]);
1525-
}
1450+
if (outputType.hasRank() &&
1451+
constantPerms.size() != static_cast<size_t>(outputType.getRank()))
1452+
return emitOpError() << "expected perms attribute to have size "
1453+
<< outputType.getRank()
1454+
<< " (output rank) but got size "
1455+
<< constantPerms.size();
1456+
1457+
if (!llvm::all_of(constantPerms,
1458+
[&constantPerms](int32_t s) {
1459+
return s >= 0 &&
1460+
static_cast<size_t>(s) < constantPerms.size();
1461+
}) ||
1462+
!isPermutationVector(llvm::to_vector(llvm::map_range(
1463+
constantPerms, [](int32_t v) -> int64_t { return v; }))))
1464+
return emitOpError() << "expected valid permutation indices";
1465+
1466+
// Verify that the types of the input and output tensors are properly
1467+
// permuted.
1468+
if (inputType.hasRank() && outputType.hasRank()) {
1469+
assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1470+
inputType.getRank() == outputType.getRank());
1471+
1472+
for (auto i = 0; i < outputType.getRank(); i++) {
1473+
if (inputType.isDynamicDim(constantPerms[i]) ||
1474+
outputType.isDynamicDim(i))
1475+
continue;
1476+
1477+
if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1478+
return emitOpError()
1479+
<< "expected output tensor dim " << i << " to match "
1480+
<< "input dim " << constantPerms[i] << " with value of "
1481+
<< inputType.getDimSize(constantPerms[i]);
15261482
}
15271483
}
1484+
15281485
return success();
15291486
}
15301487

15311488
LogicalResult TransposeOp::reifyResultShapes(
15321489
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
15331490

1534-
SmallVector<int32_t> transposePerms;
1535-
if (getConstantPerms(transposePerms).failed())
1536-
return failure();
1491+
const llvm::ArrayRef<int32_t> transposePerms = getPerms();
15371492

15381493
Value input = getInput1();
15391494
auto inputType = cast<TensorType>(input.getType());

‎mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

+2-10
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
166166
getTosaConstShape(rewriter, loc, weightReshapeDims0));
167167

168168
// Transpose the factored-out stride to the output channels.
169-
Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
170-
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
171-
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
172-
173169
weight = CreateOpAndInferShape<tosa::TransposeOp>(
174170
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
175-
transposeWeightVal);
171+
rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
176172

177173
// Collapse the strides and output channels into a single dimension.
178174
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
269265
convReshapeDims0Value);
270266

271267
// Transpose the factored-out stride to the output channels.
272-
Value transposeConvVal = rewriter.create<tosa::ConstOp>(
273-
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
274-
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
275-
276268
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
277269
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
278-
transposeConvVal);
270+
rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
279271

280272
// Fuse striding behavior back into width / height.
281273
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {

‎mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

+1-6
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
224224
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
225225
return failure();
226226

227-
DenseIntElementsAttr permAttr;
228-
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
229-
return failure();
230227
auto permValues = llvm::map_to_vector(
231-
// TOSA allows both 32- and 64-bit integer tensors here.
232-
permAttr.getValues<APInt>(),
233-
[](const APInt &val) { return val.getSExtValue(); });
228+
op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
234229

235230
auto inputType = cast<ShapedType>(op.getInput1().getType());
236231

0 commit comments

Comments (0)
Please sign in to comment.