[mlir][tosa] Change Transpose perms operand to attribute #128115
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg (identical notifications were also sent to @llvm/pr-subscribers-mlir and @llvm/pr-subscribers-mlir-memref)

Author: Tai Ly (Tai78641)

Changes

This patch changes the perms operand of the TOSA Transpose operator to an i32 array attribute.

Patch is 109.31 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/128115.diff

20 Files Affected:
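For orientation before the truncated diff, here is a minimal before/after sketch of what the change looks like at the IR level (illustrative only: the SSA names and tensor shapes are invented, and the printed form assumes the standard `DenseI32ArrayAttr` syntax):

```mlir
// Before this patch: perms is a tensor operand that must be materialized as a
// compile-time constant and recovered via constant matching.
%perms = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
%t0 = tosa.transpose %input, %perms : (tensor<2x3x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32>

// After this patch: perms is an inline i32 array attribute on the op itself,
// so it is constant by construction.
%t1 = tosa.transpose %input {perms = array<i32: 1, 2, 0>} : (tensor<2x3x4xf32>) -> tensor<3x4x2xf32>
```

Making the permutation an attribute removes the `matchPattern`/`m_Constant` plumbing and the fallible `getConstantPerms` helper, which is what most of the hunks below clean up.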
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3de1c21f40b43..a06e03f831985 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2011,7 +2011,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
let arguments = (ins
Tosa_Tensor:$input1,
- Tosa_Int32Tensor:$perms
+ DenseI32ArrayAttr:$perms
);
let results = (
@@ -2023,10 +2023,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
- let extraClassDeclaration = [{
- LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
- }];
-
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..42e88ee9026ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ weightPermAttr);
}
}
@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
- auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
- Value weightPermValue =
- rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+ auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
- weightPermValue);
+ weightPermAttr);
}
// Extract the attributes for convolution.
@@ -970,9 +966,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
- SmallVector<int32_t> constantPerms;
- if (failed(op.getConstantPerms(constantPerms)))
- return failure();
+ const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
Location loc = op.getLoc();
// The verifier should have made sure we have a valid TOSA permutation
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..8e2d8662ece8d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
return rewriter.notifyMatchFailure(transposeOp,
"input must be transpose operation");
- SmallVector<int32_t> transposePerms, innerTransposePerms;
- if (transposeOp.getConstantPerms(transposePerms).failed())
- return rewriter.notifyMatchFailure(transposeOp,
- "transpose perms must be constant");
- if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
- return rewriter.notifyMatchFailure(
- transposeOp, "inner transpose perms must be constant");
+ const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
+ const llvm::ArrayRef<int32_t> innerTransposePerms =
+ innerTranspose.getPerms();
+
if (transposePerms.size() != innerTransposePerms.size())
return rewriter.notifyMatchFailure(
transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
for (int i = 0, s = transposePerms.size(); i < s; ++i)
perms[i] = innerTransposePerms[transposePerms[i]];
- auto permsTy =
- RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
- auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
- Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
- permsTy, permsAttr);
-
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
transposeOp, transposeOp.getResult().getType(),
- innerTranspose.getInput1(), permsValue);
+ innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
return success();
}
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return rewriter.notifyMatchFailure(op, "Non-constant permutation");
-
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
return rewriter.notifyMatchFailure(
op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
if (numDynDims > 1)
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
- SmallVector<int64_t> permValues = llvm::to_vector<6>(
- llvm::map_range(permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ const llvm::ArrayRef<int32_t> permValues = op.getPerms();
SmallVector<int64_t> nonZeroPerms;
nonZeroPerms.reserve(permValues.size());
@@ -1175,9 +1160,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
}
// Transpose is not the identity transpose.
- SmallVector<int32_t> perms;
- if (getConstantPerms(perms).failed())
- return {};
+ const llvm::ArrayRef<int32_t> perms = getPerms();
if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e9c33e1b1bf10..7030dccd693a4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1374,41 +1374,22 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
- // Perms must be constants.
- DenseIntElementsAttr permsAttr;
- if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
- return failure();
-
- perms.clear();
- for (auto v : permsAttr.getValues<APInt>())
- perms.push_back(v.getSExtValue());
-
- return success();
-}
-
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TransposeOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape(adaptor.getInput1().getType());
- ShapeAdaptor permsShape(adaptor.getPerms().getType());
-
- // We cannot infer anything from a rank-0 "permutation" tensor.
- if (permsShape.hasRank() && permsShape.getRank() == 0)
- return failure();
// If input rank and permutation length is unknown, the output rank is
// unknown.
- if (!inputShape.hasRank() || !permsShape.hasRank() ||
- permsShape.isDynamicDim(0)) {
+ if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents());
return success();
}
// This would imply the number of permutations does not match the rank of
// the input which is illegal.
- if (permsShape.getDimSize(0) != inputShape.getRank()) {
+ if (adaptor.getPerms().size() != static_cast<size_t>(inputShape.getRank())) {
return failure();
}
@@ -1437,28 +1418,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
- // If the permuations are a constant we can directly determine the output
- // shape.
- DenseIntElementsAttr attr;
- if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
- attr.getType().getRank() == 1) {
- ShapeAdaptor permShape = attr;
- // Constant permutation must be the same length as the input rank.
- if (inputShape.getRank() != permShape.getRank())
- return emitOptionalError(location,
- "constant permutation must be the same length"
- " as the input rank");
-
- // Constant permutation values must be within the input rank.
- for (int i = 0, e = inputShape.getRank(); i < e; i++) {
- if (inputShape.getRank() <= permShape.getDimSize(i))
- return failure();
- }
- outputShape.reserve(inputShape.getRank());
- for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
- }
+ // Constant permutation values must be within the input rank.
+ for (auto i : adaptor.getPerms()) {
+ if (inputShape.getRank() <= i)
+ return failure();
+ }
+
+ outputShape.reserve(inputShape.getRank());
+ for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+ outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
}
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1467,75 +1436,61 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
LogicalResult tosa::TransposeOp::verify() {
TensorType inputType = getInput1().getType();
- TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
+ const llvm::ArrayRef<int32_t> constantPerms = getPerms();
- if (permType.hasRank() && permType.getRank() != 1)
- return emitOpError()
- << "expected permutation tensor to be rank 1 but got rank "
- << permType.getRank();
- if (inputType.hasRank() && permType.hasRank())
- if (!permType.isDynamicDim(0) &&
- permType.getDimSize(0) != inputType.getRank())
- return emitOpError() << "expected permutation tensor dim 0 to have size "
+ if (inputType.hasRank())
+ if (constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+ return emitOpError() << "expected perms attribute to have size "
<< inputType.getRank()
<< " (input rank) but got size "
- << permType.getDimSize(0);
+ << constantPerms.size();
if (inputType.hasRank() && outputType.hasRank() &&
inputType.getRank() != outputType.getRank())
return emitOpError()
<< "expected input tensor rank to equal result tensor rank";
- if (outputType.hasRank() && permType.hasRank())
- if (!permType.isDynamicDim(0) &&
- permType.getDimSize(0) != outputType.getRank())
- return emitOpError() << "expected permutation tensor dim 0 to have size "
+ if (outputType.hasRank())
+ if (constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+ return emitOpError() << "expected perms attribute to have size "
<< outputType.getRank()
<< " (output rank) but got size "
- << permType.getDimSize(0);
-
- SmallVector<int32_t> constantPerms;
- if (succeeded(getConstantPerms(constantPerms))) {
- // Assert that the permutation tensor has a rank, which means that the
- // rank has been verified above.
- assert(permType.hasRank() &&
- "Unexpectedly found permutation tensor without rank");
- if (!llvm::all_of(constantPerms,
- [&constantPerms](int32_t s) {
- return s >= 0 &&
- static_cast<size_t>(s) < constantPerms.size();
- }) ||
- !isPermutationVector(llvm::to_vector(llvm::map_range(
- constantPerms, [](int32_t v) -> int64_t { return v; }))))
- return emitOpError() << "expected valid permutation tensor";
-
- // Verify that the types of the input and output tensors are properly
- // permuted.
- if (inputType.hasRank() && outputType.hasRank()) {
- assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
- inputType.getRank() == outputType.getRank());
-
- for (auto i = 0; i < outputType.getRank(); i++) {
- if (inputType.isDynamicDim(constantPerms[i]) ||
- outputType.isDynamicDim(i))
- continue;
-
- if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
- return emitOpError()
- << "expected output tensor dim " << i << " to match "
- << "input dim " << constantPerms[i] << " with value of "
- << inputType.getDimSize(constantPerms[i]);
- }
+ << constantPerms.size();
+
+ if (!llvm::all_of(constantPerms,
+ [&constantPerms](int32_t s) {
+ return s >= 0 &&
+ static_cast<size_t>(s) < constantPerms.size();
+ }) ||
+ !isPermutationVector(llvm::to_vector(llvm::map_range(
+ constantPerms, [](int32_t v) -> int64_t { return v; }))))
+ return emitOpError() << "expected valid permutation indices";
+
+ // Verify that the types of the input and output tensors are properly
+ // permuted.
+ if (inputType.hasRank() && outputType.hasRank()) {
+ assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+ inputType.getRank() == outputType.getRank());
+
+ for (auto i = 0; i < outputType.getRank(); i++) {
+ if (inputType.isDynamicDim(constantPerms[i]) ||
+ outputType.isDynamicDim(i))
+ continue;
+
+ if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+ return emitOpError()
+ << "expected output tensor dim " << i << " to match "
+ << "input dim " << constantPerms[i] << " with value of "
+ << inputType.getDimSize(constantPerms[i]);
}
}
+
return success();
}
LogicalResult TransposeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- SmallVector<int32_t> transposePerms;
- if (getConstantPerms(transposePerms).failed())
- return failure();
+ const llvm::ArrayRef<int32_t> transposePerms = getPerms();
Value input = getInput1();
auto inputType = cast<TensorType>(input.getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 26baddcf1dd15..61011b6df4617 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
getTosaConstShape(rewriter, loc, weightReshapeDims0));
// Transpose the factored-out stride to the output channels.
- Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
- loc, RankedTensorType::get({6}, rewriter.getI32Type()),
- rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
-
weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
- transposeWeightVal);
+ rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
convReshapeDims0Value);
// Transpose the factored-out stride to the output channels.
- Value transposeConvVal = rewriter.create<tosa::ConstOp>(
- loc, RankedTensorType::get({6}, rewriter.getI32Type()),
- rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
-
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
- transposeConvVal);
+ rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 403ac48b91559..43e9507b4d95a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
return failure();
- DenseIntElementsAttr permAttr;
- if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
- return failure();
auto permValues = llvm::map_to_vector(
- // TOSA allows both 32- and 64-bit integer tensors here.
- permAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); });
+ op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
auto inputType = cast<ShapedType>(op.getInput1().getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 64e5c31793f84..d4d8aae8b0316 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -367,9 +367,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
- SmallVector<int32_t> perms;
- if (failed(transposeOp.getConstantPerms(perms)) ||
- !areInvolutionTransposes(hoistedPerms, perms))
+ if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
return std::nullopt;
return transposeOp.getInput1();
}
@@ -506,14 +504,11 @@ bool TosaReduceTransposes::dependenciesAreValid(
// replaced.
Operation *user = use.getOwner();
if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
- SmallVector<int32_t> otherPerms;
-
// Can later think about cases where transpose -> transpose
// or reshape -> transpose, where the transposes are not necessarily
// the same perms as the hoisted, if implementing a more general
// transform. These could be permitted.
- if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
- !llvm::equal(perms, otherPerms))
+ if (!llvm::equal(perms, otherTranspose.getPerms()))
return false;
} else if (userNotContainedInValidTransposeDependencies(
user, validTransposes, transposeInfo)) {
@@ -607,9 +602,9 @@ void TosaReduceTransposes::runOnOperation() {
!llvm::isa<RankedTensorType>(output.getType()))
return;
- // No transformation when transpose permutation non-constant.
- if (failed(transposeOp.getConstantPerms(perms)))
- return;
+ for (int32_t v : transposeOp.getPerms()) {
+ perms.push_back(v);
+ }
// We let --canonicalize deal with identity transpose.
if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..f2abb29b4fe66 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -56,15 +56,6 @@ static LogicalResult checkConstantOperandPad(Operation *op) {
return success();
}
-static LogicalResult checkConstantOpe...
[truncated]
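For downstream C++ users the migration is mechanical. Here is a hedged sketch of the builder-side change, mirroring the pattern the patch applies in TosaToLinalgNamed.cpp above (names such as `rewriter`, `loc`, `resultTy`, and `input` are assumed context inside a rewrite pattern, not part of the patch):

```cpp
// Old style (removed by this patch): materialize perms as a constant i32
// tensor operand, then read it back with matchPattern/m_Constant.
//   Value permsVal = rewriter.create<arith::ConstantOp>(
//       loc, rewriter.getI32TensorAttr({0, 2, 1}));
//   rewriter.create<tosa::TransposeOp>(loc, resultTy, input, permsVal);

// New style: pass the permutation inline as a DenseI32ArrayAttr.
auto transposeOp = rewriter.create<tosa::TransposeOp>(
    loc, resultTy, input, rewriter.getDenseI32ArrayAttr({0, 2, 1}));

// Reading it back is now infallible: getPerms() returns an ArrayRef<int32_t>
// directly, so the failed(getConstantPerms(...)) checks above simply go away.
llvm::ArrayRef<int32_t> perms = transposeOp.getPerms();
```

Note this also simplifies shape inference: since perms is always known, inferReturnTypeComponents can compute a concrete output shape whenever the input rank is known, instead of bailing out on a non-constant operand.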
This patch changes the perms operand for Tosa Transpose operator to an i32 array attribute.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I6c54d203ee7eb8d77442ff29fdbff2d2e3e6950b
Force-pushed from 3add291 to 2d54470.
LGTM - just had two minor non-blocking comments
@@ -223,7 +209,7 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
nit: these constants seem no longer used
@@ -298,46 +275,41 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t
%shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
nit: above const no longer used
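Both nits point at the same cleanup: once perms moves into the attribute, the tosa.const ops that used to feed the perms operand in the tests become dead. A sketch of the first case (the transpose line and its shapes are hypothetical; only the const is quoted from the test):

```mlir
// This const previously fed the perms operand and is now unused...
%64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
// ...because the transpose now carries the permutation inline:
%65 = tosa.transpose %63 {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x112x112x64xf32>) -> tensor<1x64x112x112xf32>
```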
There are some attribute changes upstream:
- TOSA: llvm/llvm-project#128115
- NVPTX: llvm/llvm-project#127736

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>