Commit 2d54470

[mlir][tosa] Change Transpose perms operand to attribute
This patch changes the perms operand of the TOSA Transpose operator to an i32 array attribute.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I6c54d203ee7eb8d77442ff29fdbff2d2e3e6950b
1 parent: 23aca2f

20 files changed: +250 -509 lines
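
At the IR level, the change removes the separate constant operand in favor of an inline attribute. A hand-written before/after sketch (not taken from the commit's tests; the exact printed form of tosa.const and of the transpose assembly depends on the MLIR version):

// Before: perms was an i32 tensor operand, usually produced by a tosa.const.
%perms = "tosa.const"() {value = dense<[1, 2, 0]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = tosa.transpose %0, %perms : (tensor<2x3x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32>

// After: perms is a DenseI32ArrayAttr carried by the op itself.
%1 = tosa.transpose %0 {perms = array<i32: 1, 2, 0>} : (tensor<2x3x4xf32>) -> tensor<3x4x2xf32>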

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+1 -5

@@ -2023,7 +2023,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Int32Tensor:$perms
+    DenseI32ArrayAttr:$perms
   );
 
   let results = (
@@ -2035,10 +2035,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
 
-  let extraClassDeclaration = [{
-    LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
-  }];
-
  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
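
For code that builds the op, the TableGen change above shows up as a simpler builder call. A sketch of the before/after (rewriter, loc, input, and resultTy are assumed to be in scope, as in the conversion patterns below):

// Before: materialize perms as a tosa.const of type tensor<3xi32>, then pass
// the resulting Value as a second operand.
Value permsVal = rewriter.create<tosa::ConstOp>(
    loc, RankedTensorType::get({3}, rewriter.getI32Type()),
    rewriter.getI32TensorAttr({1, 2, 0}));
auto oldStyle =
    rewriter.create<tosa::TransposeOp>(loc, resultTy, input, permsVal);

// After: pass the permutation inline as a DenseI32ArrayAttr.
auto newStyle = rewriter.create<tosa::TransposeOp>(
    loc, resultTy, input, rewriter.getDenseI32ArrayAttr({1, 2, 0}));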

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

+5 -11

@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       SmallVector<int64_t> newWeightShape;
       for (auto dim : weightPerm)
         newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
       Type newWeightTy =
           RankedTensorType::get(newWeightShape, weightTy.getElementType());
       weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
+                                                  weightPermAttr);
     }
   }
 
@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       SmallVector<int64_t> newWeightShape;
       for (auto dim : weightPerm)
         newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
       Type newWeightTy =
           RankedTensorType::get(newWeightShape, weightTy.getElementType());
       weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
+                                                  weightPermAttr);
     }
 
     // Extract the attributes for convolution.
@@ -970,9 +966,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const final {
-    SmallVector<int32_t> constantPerms;
-    if (failed(op.getConstantPerms(constantPerms)))
-      return failure();
+    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
 
     Location loc = op.getLoc();
    // The verifier should have made sure we have a valid TOSA permutation
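
The shape bookkeeping is unchanged by the switch: output dimension i of a transpose takes the size of input dimension perms[i], exactly as the newWeightShape loop above computes. A standalone illustration of that rule (plain C++; the values are made up):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Mirrors the converter's loop: newWeightShape.push_back(weightShape[dim]).
  std::vector<int64_t> weightShape = {8, 3, 5, 4};
  std::vector<int32_t> weightPerm = {1, 2, 3, 0}; // rotate dim 0 to the back
  std::vector<int64_t> newWeightShape;
  for (int32_t dim : weightPerm)
    newWeightShape.push_back(weightShape[dim]);
  assert((newWeightShape == std::vector<int64_t>{3, 5, 4, 8}));
  return 0;
}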

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

+7 -24

@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
       return rewriter.notifyMatchFailure(transposeOp,
                                          "input must be transpose operation");
 
-    SmallVector<int32_t> transposePerms, innerTransposePerms;
-    if (transposeOp.getConstantPerms(transposePerms).failed())
-      return rewriter.notifyMatchFailure(transposeOp,
-                                         "transpose perms must be constant");
-    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
-      return rewriter.notifyMatchFailure(
-          transposeOp, "inner transpose perms must be constant");
+    const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
+    const llvm::ArrayRef<int32_t> innerTransposePerms =
+        innerTranspose.getPerms();
+
     if (transposePerms.size() != innerTransposePerms.size())
       return rewriter.notifyMatchFailure(
           transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
     for (int i = 0, s = transposePerms.size(); i < s; ++i)
       perms[i] = innerTransposePerms[transposePerms[i]];
 
-    auto permsTy =
-        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
-    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
-    Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
-                                                      permsTy, permsAttr);
-
     rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
         transposeOp, transposeOp.getResult().getType(),
-        innerTranspose.getInput1(), permsValue);
+        innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
 
     return success();
   }
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return rewriter.notifyMatchFailure(op, "Non-constant permutation");
-
     if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
       return rewriter.notifyMatchFailure(
           op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
     if (numDynDims > 1)
       return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
 
-    SmallVector<int64_t> permValues = llvm::to_vector<6>(
-        llvm::map_range(permAttr.getValues<APInt>(),
-                        [](const APInt &val) { return val.getSExtValue(); }));
+    const llvm::ArrayRef<int32_t> permValues = op.getPerms();
 
     SmallVector<int64_t> nonZeroPerms;
     nonZeroPerms.reserve(permValues.size());
@@ -1175,9 +1160,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   }
 
   // Transpose is not the identity transpose.
-  SmallVector<int32_t> perms;
-  if (getConstantPerms(perms).failed())
-    return {};
+  const llvm::ArrayRef<int32_t> perms = getPerms();
 
  if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
    return {};
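
ConsolidateTransposeOptimization collapses transpose(transpose(x, innerPerms), outerPerms) into a single transpose with perms[i] = innerTransposePerms[transposePerms[i]]. A standalone check of that composition formula (values chosen purely for illustration):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int32_t> innerPerms = {2, 0, 1}; // first transpose
  std::vector<int32_t> outerPerms = {1, 0, 2}; // second transpose
  std::vector<int32_t> perms(outerPerms.size());
  // Dim i of the final result is dim outerPerms[i] of the intermediate,
  // which in turn is dim innerPerms[outerPerms[i]] of the original input.
  for (size_t i = 0; i < outerPerms.size(); ++i)
    perms[i] = innerPerms[outerPerms[i]];
  assert((perms == std::vector<int32_t>{0, 2, 1}));
  return 0;
}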

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+56 -101

@@ -1374,54 +1374,37 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
-  // Perms must be constants.
-  DenseIntElementsAttr permsAttr;
-  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
-    return failure();
-
-  perms.clear();
-  for (auto v : permsAttr.getValues<APInt>())
-    perms.push_back(v.getSExtValue());
-
-  return success();
-}
-
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TransposeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor permsShape(adaptor.getPerms().getType());
-
-  // We cannot infer anything from a rank-0 "permutation" tensor.
-  if (permsShape.hasRank() && permsShape.getRank() == 0)
-    return failure();
 
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
-  if (!inputShape.hasRank() || !permsShape.hasRank() ||
-      permsShape.isDynamicDim(0)) {
+  if (!inputShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
 
+  const auto inputRank = inputShape.getRank();
+
   // This would imply the number of permutations does not match the rank of
   // the input which is illegal.
-  if (permsShape.getDimSize(0) != inputShape.getRank()) {
+  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
     return failure();
   }
 
   SmallVector<int64_t> outputShape;
   // Rank-0 means no permutations matter.
-  if (inputShape.getRank() == 0) {
+  if (inputRank == 0) {
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
   // Check whether the input dimensions are all the same.
   bool allTheSame = true;
-  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
+  for (int i = 1, s = inputRank; i < s; i++) {
     if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
       allTheSame = false;
       break;
@@ -1431,34 +1414,21 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // If all of the input dimensions are the same we don't care about the
   // permutation.
   if (allTheSame) {
-    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
+    outputShape.resize(inputRank, inputShape.getDimSize(0));
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-  // If the permuations are a constant we can directly determine the output
-  // shape.
-  DenseIntElementsAttr attr;
-  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
-      attr.getType().getRank() == 1) {
-    ShapeAdaptor permShape = attr;
-    // Constant permutation must be the same length as the input rank.
-    if (inputShape.getRank() != permShape.getRank())
-      return emitOptionalError(location,
-                               "constant permutation must be the same length"
-                               " as the input rank");
-
-    // Constant permutation values must be within the input rank.
-    for (int i = 0, e = inputShape.getRank(); i < e; i++) {
-      if (inputShape.getRank() <= permShape.getDimSize(i))
-        return failure();
-    }
+  outputShape.resize(inputRank, ShapedType::kDynamic);
 
-    outputShape.reserve(inputShape.getRank());
-    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
-      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
-    }
+  // Constant permutation values must be within the input rank.
+  if (llvm::any_of(adaptor.getPerms(),
+                   [inputRank](const auto i) { return i >= inputRank; }))
+    return failure();
+
+  outputShape.reserve(inputRank);
+  for (int i = 0, s = inputRank; i < s; i++) {
+    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1467,75 +1437,60 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 
 LogicalResult tosa::TransposeOp::verify() {
   TensorType inputType = getInput1().getType();
-  TensorType permType = getPerms().getType();
   TensorType outputType = getOutput().getType();
+  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
 
-  if (permType.hasRank() && permType.getRank() != 1)
-    return emitOpError()
-           << "expected permutation tensor to be rank 1 but got rank "
-           << permType.getRank();
-  if (inputType.hasRank() && permType.hasRank())
-    if (!permType.isDynamicDim(0) &&
-        permType.getDimSize(0) != inputType.getRank())
-      return emitOpError() << "expected permutation tensor dim 0 to have size "
-                           << inputType.getRank()
-                           << " (input rank) but got size "
-                           << permType.getDimSize(0);
+  if (inputType.hasRank() &&
+      constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+    return emitOpError() << "expected perms attribute to have size "
+                         << inputType.getRank()
+                         << " (input rank) but got size "
+                         << constantPerms.size();
   if (inputType.hasRank() && outputType.hasRank() &&
       inputType.getRank() != outputType.getRank())
     return emitOpError()
           << "expected input tensor rank to equal result tensor rank";
-  if (outputType.hasRank() && permType.hasRank())
-    if (!permType.isDynamicDim(0) &&
-        permType.getDimSize(0) != outputType.getRank())
-      return emitOpError() << "expected permutation tensor dim 0 to have size "
-                           << outputType.getRank()
-                           << " (output rank) but got size "
-                           << permType.getDimSize(0);
-
-  SmallVector<int32_t> constantPerms;
-  if (succeeded(getConstantPerms(constantPerms))) {
-    // Assert that the permutation tensor has a rank, which means that the
-    // rank has been verified above.
-    assert(permType.hasRank() &&
-           "Unexpectedly found permutation tensor without rank");
-    if (!llvm::all_of(constantPerms,
-                      [&constantPerms](int32_t s) {
-                        return s >= 0 &&
-                               static_cast<size_t>(s) < constantPerms.size();
-                      }) ||
-        !isPermutationVector(llvm::to_vector(llvm::map_range(
-            constantPerms, [](int32_t v) -> int64_t { return v; }))))
-      return emitOpError() << "expected valid permutation tensor";
-
-    // Verify that the types of the input and output tensors are properly
-    // permuted.
-    if (inputType.hasRank() && outputType.hasRank()) {
-      assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
-             inputType.getRank() == outputType.getRank());
-
-      for (auto i = 0; i < outputType.getRank(); i++) {
-        if (inputType.isDynamicDim(constantPerms[i]) ||
-            outputType.isDynamicDim(i))
-          continue;
-
-        if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
-          return emitOpError()
-                 << "expected output tensor dim " << i << " to match "
-                 << "input dim " << constantPerms[i] << " with value of "
-                 << inputType.getDimSize(constantPerms[i]);
-      }
+  if (outputType.hasRank() &&
+      constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+    return emitOpError() << "expected perms attribute to have size "
+                         << outputType.getRank()
+                         << " (output rank) but got size "
+                         << constantPerms.size();
+
+  if (!llvm::all_of(constantPerms,
+                    [&constantPerms](int32_t s) {
+                      return s >= 0 &&
+                             static_cast<size_t>(s) < constantPerms.size();
+                    }) ||
+      !isPermutationVector(llvm::to_vector(llvm::map_range(
+          constantPerms, [](int32_t v) -> int64_t { return v; }))))
+    return emitOpError() << "expected valid permutation indices";
+
+  // Verify that the types of the input and output tensors are properly
+  // permuted.
+  if (inputType.hasRank() && outputType.hasRank()) {
+    assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+           inputType.getRank() == outputType.getRank());
+
+    for (auto i = 0; i < outputType.getRank(); i++) {
+      if (inputType.isDynamicDim(constantPerms[i]) ||
+          outputType.isDynamicDim(i))
+        continue;
+
+      if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+        return emitOpError()
+               << "expected output tensor dim " << i << " to match "
+               << "input dim " << constantPerms[i] << " with value of "
+               << inputType.getDimSize(constantPerms[i]);
     }
   }
+
   return success();
 }
 
 LogicalResult TransposeOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
 
-  SmallVector<int32_t> transposePerms;
-  if (getConstantPerms(transposePerms).failed())
-    return failure();
+  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
 
  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());
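
With perms always statically known, the verifier now applies the validity check unconditionally: every index must lie in [0, rank) and the indices must form a bijection. A standalone sketch of that predicate (hand-rolled here instead of mlir::isPermutationVector so it compiles on its own):

#include <cassert>
#include <cstdint>
#include <vector>

// True iff perms is a valid permutation of {0, ..., perms.size() - 1}:
// every index in range, each appearing exactly once.
static bool isValidPerms(const std::vector<int32_t> &perms) {
  std::vector<bool> seen(perms.size(), false);
  for (int32_t p : perms) {
    if (p < 0 || static_cast<size_t>(p) >= perms.size() || seen[p])
      return false;
    seen[p] = true;
  }
  return true;
}

int main() {
  assert(isValidPerms({2, 0, 1}));  // accepted by the verifier
  assert(!isValidPerms({0, 0, 2})); // rejected: 0 repeats, 1 is missing
  assert(!isValidPerms({0, 1, 3})); // rejected: 3 is out of range
  return 0;
}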

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

+2 -10

@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
         getTosaConstShape(rewriter, loc, weightReshapeDims0));
 
     // Transpose the factored-out stride to the output channels.
-    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
-        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
-        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
-
     weight = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        transposeWeightVal);
+        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
 
     // Collapse the strides and output channels into a single dimension.
     llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
         convReshapeDims0Value);
 
     // Transpose the factored-out stride to the output channels.
-    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
-        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
-        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
-
     conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
-        transposeConvVal);
+        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
 
    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

+1 -6

@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
       return failure();
 
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return failure();
     auto permValues = llvm::map_to_vector(
-        // TOSA allows both 32- and 64-bit integer tensors here.
-        permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); });
+        op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
 
    auto inputType = cast<ShapedType>(op.getInput1().getType());