
[mlir][tosa] Change Transpose perms operand to attribute #128115

Merged 1 commit on Feb 25, 2025

Conversation

Tai78641
Contributor

This patch changes the perms operand of the TOSA transpose operator to an i32 array attribute.
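For illustration, a minimal before/after sketch of what this means at the IR level; the textual forms below are an assumption based on the default `tosa.const` and `DenseI32ArrayAttr` printing, not lines taken from the patch's tests. In both forms, output dimension `i` takes the size of input dimension `perms[i]`:

```mlir
// Before: perms is an i32 tensor operand, usually fed by a tosa.const,
// so every consumer had to prove it was constant first.
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
%t0 = tosa.transpose %arg0, %perms : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>

// After: perms is a DenseI32ArrayAttr on the op itself, so it is always
// statically known and getPerms() returns an ArrayRef<int32_t> directly.
%t1 = tosa.transpose %arg0 {perms = array<i32: 1, 0>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
```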

@llvmbot
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-memref

Author: Tai Ly (Tai78641)

Changes

This patch changes the perms operand of the TOSA transpose operator to an i32 array attribute.


Patch is 109.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128115.diff

20 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-5)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+5-11)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+7-24)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+48-93)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+2-10)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp (+1-6)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (+5-10)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (-10)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+4-12)
  • (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+2-4)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+1-2)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+16-9)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+18-28)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+21-47)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-2)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+3-6)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+6-12)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+2-21)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+91-171)
  • (modified) mlir/test/Dialect/Tosa/transpose-fold.mlir (+9-18)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3de1c21f40b43..a06e03f831985 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2011,7 +2011,7 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    Tosa_Int32Tensor:$perms
+    DenseI32ArrayAttr:$perms
   );
 
   let results = (
@@ -2023,10 +2023,6 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
 
-  let extraClassDeclaration = [{
-    LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
-  }];
-
   let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..42e88ee9026ac 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -329,13 +329,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         SmallVector<int64_t> newWeightShape;
         for (auto dim : weightPerm)
           newWeightShape.push_back(weightShape[dim]);
-        auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
-        Value weightPermValue =
-            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
         Type newWeightTy =
             RankedTensorType::get(newWeightShape, weightTy.getElementType());
         weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                    weightPermValue);
+                                                    weightPermAttr);
       }
     }
 
@@ -353,13 +351,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       SmallVector<int64_t> newWeightShape;
       for (auto dim : weightPerm)
         newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
       Type newWeightTy =
           RankedTensorType::get(newWeightShape, weightTy.getElementType());
       weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
+                                                  weightPermAttr);
     }
 
     // Extract the attributes for convolution.
@@ -970,9 +966,7 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const final {
-    SmallVector<int32_t> constantPerms;
-    if (failed(op.getConstantPerms(constantPerms)))
-      return failure();
+    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
 
     Location loc = op.getLoc();
     // The verifier should have made sure we have a valid TOSA permutation
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9bfc2aae1d6a5..8e2d8662ece8d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,13 +88,10 @@ struct ConsolidateTransposeOptimization
       return rewriter.notifyMatchFailure(transposeOp,
                                          "input must be transpose operation");
 
-    SmallVector<int32_t> transposePerms, innerTransposePerms;
-    if (transposeOp.getConstantPerms(transposePerms).failed())
-      return rewriter.notifyMatchFailure(transposeOp,
-                                         "transpose perms must be constant");
-    if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
-      return rewriter.notifyMatchFailure(
-          transposeOp, "inner transpose perms must be constant");
+    const llvm::ArrayRef<int32_t> transposePerms = transposeOp.getPerms();
+    const llvm::ArrayRef<int32_t> innerTransposePerms =
+        innerTranspose.getPerms();
+
     if (transposePerms.size() != innerTransposePerms.size())
       return rewriter.notifyMatchFailure(
           transposeOp,
@@ -108,15 +105,9 @@ struct ConsolidateTransposeOptimization
     for (int i = 0, s = transposePerms.size(); i < s; ++i)
       perms[i] = innerTransposePerms[transposePerms[i]];
 
-    auto permsTy =
-        RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
-    auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
-    Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
-                                                      permsTy, permsAttr);
-
     rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
         transposeOp, transposeOp.getResult().getType(),
-        innerTranspose.getInput1(), permsValue);
+        innerTranspose.getInput1(), rewriter.getDenseI32ArrayAttr(perms));
 
     return success();
   }
@@ -128,10 +119,6 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return rewriter.notifyMatchFailure(op, "Non-constant permutation");
-
     if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
       return rewriter.notifyMatchFailure(
           op, "Src is from transpose, can compose transposes");
@@ -156,9 +143,7 @@ struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
     if (numDynDims > 1)
       return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
 
-    SmallVector<int64_t> permValues = llvm::to_vector<6>(
-        llvm::map_range(permAttr.getValues<APInt>(),
-                        [](const APInt &val) { return val.getSExtValue(); }));
+    const llvm::ArrayRef<int32_t> permValues = op.getPerms();
 
     SmallVector<int64_t> nonZeroPerms;
     nonZeroPerms.reserve(permValues.size());
@@ -1175,9 +1160,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   }
 
   // Transpose is not the identity transpose.
-  SmallVector<int32_t> perms;
-  if (getConstantPerms(perms).failed())
-    return {};
+  const llvm::ArrayRef<int32_t> perms = getPerms();
 
   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
     return {};
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e9c33e1b1bf10..7030dccd693a4 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1374,41 +1374,22 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
-LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
-  // Perms must be constants.
-  DenseIntElementsAttr permsAttr;
-  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
-    return failure();
-
-  perms.clear();
-  for (auto v : permsAttr.getValues<APInt>())
-    perms.push_back(v.getSExtValue());
-
-  return success();
-}
-
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TransposeOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor permsShape(adaptor.getPerms().getType());
-
-  // We cannot infer anything from a rank-0 "permutation" tensor.
-  if (permsShape.hasRank() && permsShape.getRank() == 0)
-    return failure();
 
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
-  if (!inputShape.hasRank() || !permsShape.hasRank() ||
-      permsShape.isDynamicDim(0)) {
+  if (!inputShape.hasRank()) {
     inferredReturnShapes.push_back(ShapedTypeComponents());
     return success();
   }
 
   // This would imply the number of permutations does not match the rank of
   // the input which is illegal.
-  if (permsShape.getDimSize(0) != inputShape.getRank()) {
+  if (adaptor.getPerms().size() != static_cast<size_t>(inputShape.getRank())) {
     return failure();
   }
 
@@ -1437,28 +1418,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   }
 
   outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
-  // If the permuations are a constant we can directly determine the output
-  // shape.
-  DenseIntElementsAttr attr;
-  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
-      attr.getType().getRank() == 1) {
-    ShapeAdaptor permShape = attr;
-    // Constant permutation must be the same length as the input rank.
-    if (inputShape.getRank() != permShape.getRank())
-      return emitOptionalError(location,
-                               "constant permutation must be the same length"
-                               " as the input rank");
-
-    // Constant permutation values must be within the input rank.
-    for (int i = 0, e = inputShape.getRank(); i < e; i++) {
-      if (inputShape.getRank() <= permShape.getDimSize(i))
-        return failure();
-    }
 
-    outputShape.reserve(inputShape.getRank());
-    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
-      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
-    }
+  // Constant permutation values must be within the input rank.
+  for (auto i : adaptor.getPerms()) {
+    if (inputShape.getRank() <= i)
+      return failure();
+  }
+
+  outputShape.reserve(inputShape.getRank());
+  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
+    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -1467,75 +1436,61 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 
 LogicalResult tosa::TransposeOp::verify() {
   TensorType inputType = getInput1().getType();
-  TensorType permType = getPerms().getType();
   TensorType outputType = getOutput().getType();
+  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
 
-  if (permType.hasRank() && permType.getRank() != 1)
-    return emitOpError()
-           << "expected permutation tensor to be rank 1 but got rank "
-           << permType.getRank();
-  if (inputType.hasRank() && permType.hasRank())
-    if (!permType.isDynamicDim(0) &&
-        permType.getDimSize(0) != inputType.getRank())
-      return emitOpError() << "expected permutation tensor dim 0 to have size "
+  if (inputType.hasRank())
+    if (constantPerms.size() != static_cast<size_t>(inputType.getRank()))
+      return emitOpError() << "expected perms attribute to have size "
                            << inputType.getRank()
                            << " (input rank) but got size "
-                           << permType.getDimSize(0);
+                           << constantPerms.size();
   if (inputType.hasRank() && outputType.hasRank() &&
       inputType.getRank() != outputType.getRank())
     return emitOpError()
            << "expected input tensor rank to equal result tensor rank";
-  if (outputType.hasRank() && permType.hasRank())
-    if (!permType.isDynamicDim(0) &&
-        permType.getDimSize(0) != outputType.getRank())
-      return emitOpError() << "expected permutation tensor dim 0 to have size "
+  if (outputType.hasRank())
+    if (constantPerms.size() != static_cast<size_t>(outputType.getRank()))
+      return emitOpError() << "expected perms attribute to have size "
                            << outputType.getRank()
                            << " (output rank) but got size "
-                           << permType.getDimSize(0);
-
-  SmallVector<int32_t> constantPerms;
-  if (succeeded(getConstantPerms(constantPerms))) {
-    // Assert that the permutation tensor has a rank, which means that the
-    // rank has been verified above.
-    assert(permType.hasRank() &&
-           "Unexpectedly found permutation tensor without rank");
-    if (!llvm::all_of(constantPerms,
-                      [&constantPerms](int32_t s) {
-                        return s >= 0 &&
-                               static_cast<size_t>(s) < constantPerms.size();
-                      }) ||
-        !isPermutationVector(llvm::to_vector(llvm::map_range(
-            constantPerms, [](int32_t v) -> int64_t { return v; }))))
-      return emitOpError() << "expected valid permutation tensor";
-
-    // Verify that the types of the input and output tensors are properly
-    // permuted.
-    if (inputType.hasRank() && outputType.hasRank()) {
-      assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
-             inputType.getRank() == outputType.getRank());
-
-      for (auto i = 0; i < outputType.getRank(); i++) {
-        if (inputType.isDynamicDim(constantPerms[i]) ||
-            outputType.isDynamicDim(i))
-          continue;
-
-        if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
-          return emitOpError()
-                 << "expected output tensor dim " << i << " to match "
-                 << "input dim " << constantPerms[i] << " with value of "
-                 << inputType.getDimSize(constantPerms[i]);
-      }
+                           << constantPerms.size();
+
+  if (!llvm::all_of(constantPerms,
+                    [&constantPerms](int32_t s) {
+                      return s >= 0 &&
+                             static_cast<size_t>(s) < constantPerms.size();
+                    }) ||
+      !isPermutationVector(llvm::to_vector(llvm::map_range(
+          constantPerms, [](int32_t v) -> int64_t { return v; }))))
+    return emitOpError() << "expected valid permutation indices";
+
+  // Verify that the types of the input and output tensors are properly
+  // permuted.
+  if (inputType.hasRank() && outputType.hasRank()) {
+    assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+           inputType.getRank() == outputType.getRank());
+
+    for (auto i = 0; i < outputType.getRank(); i++) {
+      if (inputType.isDynamicDim(constantPerms[i]) ||
+          outputType.isDynamicDim(i))
+        continue;
+
+      if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+        return emitOpError()
+               << "expected output tensor dim " << i << " to match "
+               << "input dim " << constantPerms[i] << " with value of "
+               << inputType.getDimSize(constantPerms[i]);
     }
   }
+
   return success();
 }
 
 LogicalResult TransposeOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
 
-  SmallVector<int32_t> transposePerms;
-  if (getConstantPerms(transposePerms).failed())
-    return failure();
+  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
 
   Value input = getInput1();
   auto inputType = cast<TensorType>(input.getType());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 26baddcf1dd15..61011b6df4617 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
         getTosaConstShape(rewriter, loc, weightReshapeDims0));
 
     // Transpose the factored-out stride to the output channels.
-    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
-        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
-        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
-
     weight = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        transposeWeightVal);
+        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
 
     // Collapse the strides and output channels into a single dimension.
     llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
         convReshapeDims0Value);
 
     // Transpose the factored-out stride to the output channels.
-    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
-        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
-        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
-
     conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
-        transposeConvVal);
+        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
 
     // Fuse striding behavior back into width / height.
     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 403ac48b91559..43e9507b4d95a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
       return failure();
 
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return failure();
     auto permValues = llvm::map_to_vector(
-        // TOSA allows both 32- and 64-bit integer tensors here.
-        permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); });
+        op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
 
     auto inputType = cast<ShapedType>(op.getInput1().getType());
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 64e5c31793f84..d4d8aae8b0316 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -367,9 +367,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
-  SmallVector<int32_t> perms;
-  if (failed(transposeOp.getConstantPerms(perms)) ||
-      !areInvolutionTransposes(hoistedPerms, perms))
+  if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
     return std::nullopt;
   return transposeOp.getInput1();
 }
@@ -506,14 +504,11 @@ bool TosaReduceTransposes::dependenciesAreValid(
       // replaced.
       Operation *user = use.getOwner();
       if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
-        SmallVector<int32_t> otherPerms;
-
         // Can later think about cases where transpose -> transpose
         // or reshape -> transpose, where the transposes are not necessarily
         // the same perms as the hoisted, if implementing a more general
         // transform. These could be permitted.
-        if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
-            !llvm::equal(perms, otherPerms))
+        if (!llvm::equal(perms, otherTranspose.getPerms()))
           return false;
       } else if (userNotContainedInValidTransposeDependencies(
                      user, validTransposes, transposeInfo)) {
@@ -607,9 +602,9 @@ void TosaReduceTransposes::runOnOperation() {
         !llvm::isa<RankedTensorType>(output.getType()))
       return;
 
-    // No transformation when transpose permutation non-constant.
-    if (failed(transposeOp.getConstantPerms(perms)))
-      return;
+    for (int32_t v : transposeOp.getPerms()) {
+      perms.push_back(v);
+    }
 
     // We let --canonicalize deal with identity transpose.
     if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..f2abb29b4fe66 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -56,15 +56,6 @@ static LogicalResult checkConstantOperandPad(Operation *op) {
   return success();
 }
 
-static LogicalResult checkConstantOpe...
[truncated]
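To make the simplified canonicalization concrete, here is a small hedged worked example of how the migrated ConsolidateTransposeOptimization pattern and the fold interact after this patch; the SSA names and shapes are illustrative, and the assembly syntax is assumed as in the sketch above:

```mlir
// Two chained transposes. The pattern composes the permutations as
// perms[i] = innerPerms[outerPerms[i]], with no constant-matching step,
// since getPerms() now yields the values unconditionally.
%0 = tosa.transpose %arg0 {perms = array<i32: 1, 2, 0>} : (tensor<2x3x4xf32>) -> tensor<3x4x2xf32>
%1 = tosa.transpose %0 {perms = array<i32: 2, 0, 1>} : (tensor<3x4x2xf32>) -> tensor<2x3x4xf32>

// Composed perms: [inner[2], inner[0], inner[1]] = [0, 1, 2], the identity
// permutation, so TransposeOp::fold then replaces the result with %arg0.
```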

+                           << constantPerms.size();
+
+  if (!llvm::all_of(constantPerms,
+                    [&constantPerms](int32_t s) {
+                      return s >= 0 &&
+                             static_cast<size_t>(s) < constantPerms.size();
+                    }) ||
+      !isPermutationVector(llvm::to_vector(llvm::map_range(
+          constantPerms, [](int32_t v) -> int64_t { return v; }))))
+    return emitOpError() << "expected valid permutation indices";
+
+  // Verify that the types of the input and output tensors are properly
+  // permuted.
+  if (inputType.hasRank() && outputType.hasRank()) {
+    assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
+           inputType.getRank() == outputType.getRank());
+
+    for (auto i = 0; i < outputType.getRank(); i++) {
+      if (inputType.isDynamicDim(constantPerms[i]) ||
+          outputType.isDynamicDim(i))
+        continue;
+
+      if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
+        return emitOpError()
+               << "expected output tensor dim " << i << " to match "
+               << "input dim " << constantPerms[i] << " with value of "
+               << inputType.getDimSize(constantPerms[i]);
     }
   }
+
   return success();
 }
 
 LogicalResult TransposeOp::reifyResultShapes(
     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
 
-  SmallVector<int32_t> transposePerms;
-  if (getConstantPerms(transposePerms).failed())
-    return failure();
+  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
 
   Value input = getInput1();
   auto inputType = cast<TensorType>(input.getType());
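
Because the permutation is now an attribute, it is always available to the verifier, so the old "perms must be constants" escape hatch disappears and every transpose gets checked. In the printed IR the op now carries the permutation inline, roughly `tosa.transpose %0 {perms = array<i32: 1, 0>} : (tensor<2x3xf32>) -> tensor<3x2xf32>` (exact syntax follows the generated assembly format). A standalone restatement of the validity test the new verify() performs, for illustration only:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include <cstdint>

    // Each index must lie in [0, rank) and occur exactly once; otherwise
    // verify() emits "expected valid permutation indices".
    static bool isValidPermutation(llvm::ArrayRef<int32_t> perms) {
      llvm::SmallVector<bool> seen(perms.size(), false);
      for (int32_t p : perms) {
        if (p < 0 || static_cast<size_t>(p) >= perms.size() || seen[p])
          return false;
        seen[p] = true;
      }
      return true;
    }
    // isValidPermutation({2, 0, 1}) -> true
    // isValidPermutation({0, 0, 1}) -> false (duplicate index)
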
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 26baddcf1dd15..61011b6df4617 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -166,13 +166,9 @@ class TransposeConvStridedConverter
         getTosaConstShape(rewriter, loc, weightReshapeDims0));
 
     // Transpose the factored-out stride to the output channels.
-    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
-        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
-        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
-
     weight = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
-        transposeWeightVal);
+        rewriter.getDenseI32ArrayAttr({2, 4, 0, 1, 3, 5}));
 
     // Collapse the strides and output channels into a single dimension.
     llvm::SmallVector<int64_t, 4> weightReshapeDims1 = {
@@ -269,13 +265,9 @@ class TransposeConvStridedConverter
         convReshapeDims0Value);
 
     // Transpose the factored-out stride to the output channels.
-    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
-        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
-        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
-
     conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
         rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
-        transposeConvVal);
+        rewriter.getDenseI32ArrayAttr({0, 1, 3, 2, 4, 5}));
 
     // Fuse striding behavior back into width / height.
     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
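
At call sites like these, the patch removes the boilerplate of materializing a tosa.const holding a tensor<Nxi32> just to feed the transpose. A hypothetical call site under the new form (the names `rewriter`, `loc`, and `input` are assumed from surrounding pattern context, not taken from this diff):

    // Build a rank-2 transpose directly from an attribute; previously this
    // required creating a tosa::ConstOp with an i32 tensor and passing it
    // as a second operand.
    Value transposed = rewriter.create<tosa::TransposeOp>(
        loc, RankedTensorType::get({3, 2}, rewriter.getF32Type()), input,
        rewriter.getDenseI32ArrayAttr({1, 0}));

One practical consequence: rewrite patterns no longer have to clean up a now-dead constant op after replacing the transpose.
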
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 403ac48b91559..43e9507b4d95a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -224,13 +224,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
       return failure();
 
-    DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
-      return failure();
     auto permValues = llvm::map_to_vector(
-        // TOSA allows both 32- and 64-bit integer tensors here.
-        permAttr.getValues<APInt>(),
-        [](const APInt &val) { return val.getSExtValue(); });
+        op.getPerms(), [](const int32_t v) { return static_cast<int64_t>(v); });
 
     auto inputType = cast<ShapedType>(op.getInput1().getType());
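
The folder applies the permutation as output[i] = input[perms[i]], the same indexing rule the shape-inference hunk above uses for dimensions. An illustrative helper capturing that rule (assumed for exposition; the patch itself does this inline):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"
    #include <cstdint>

    // Permute a static shape: result[i] = shape[perms[i]].
    static llvm::SmallVector<int64_t>
    permuteShape(llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int32_t> perms) {
      llvm::SmallVector<int64_t> result;
      result.reserve(shape.size());
      for (int32_t p : perms)
        result.push_back(shape[p]);
      return result;
    }
    // permuteShape({2, 3, 4}, {1, 2, 0}) -> {3, 4, 2}
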
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
index 64e5c31793f84..d4d8aae8b0316 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp
@@ -367,9 +367,7 @@ std::optional<Value> TosaReduceTransposes::buildMappedToValue(
 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
     TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
-  SmallVector<int32_t> perms;
-  if (failed(transposeOp.getConstantPerms(perms)) ||
-      !areInvolutionTransposes(hoistedPerms, perms))
+  if (!areInvolutionTransposes(hoistedPerms, transposeOp.getPerms()))
     return std::nullopt;
   return transposeOp.getInput1();
 }
@@ -506,14 +504,11 @@ bool TosaReduceTransposes::dependenciesAreValid(
       // replaced.
       Operation *user = use.getOwner();
       if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
-        SmallVector<int32_t> otherPerms;
-
         // Can later think about cases where transpose -> transpose
         // or reshape -> transpose, where the transposes are not necessarily
         // the same perms as the hoisted, if implementing a more general
         // transform. These could be permitted.
-        if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
-            !llvm::equal(perms, otherPerms))
+        if (!llvm::equal(perms, otherTranspose.getPerms()))
           return false;
       } else if (userNotContainedInValidTransposeDependencies(
                      user, validTransposes, transposeInfo)) {
@@ -607,9 +602,9 @@ void TosaReduceTransposes::runOnOperation() {
         !llvm::isa<RankedTensorType>(output.getType()))
       return;
 
-    // No transformation when transpose permutation non-constant.
-    if (failed(transposeOp.getConstantPerms(perms)))
-      return;
+    for (int32_t v : transposeOp.getPerms()) {
+      perms.push_back(v);
+    }
 
     // We let --canonicalize deal with identity transpose.
     if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
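
areInvolutionTransposes (not shown in this diff) decides whether a pair of transposes cancels out; under the new API it can take the two ArrayRef<int32_t> perms directly. A sketch of the underlying check, assuming the usual definition that the composition must be the identity:

    #include "llvm/ADT/ArrayRef.h"
    #include <cstdint>

    // Two transposes cancel when perms2[perms1[i]] == i for every i, i.e.
    // applying one permutation and then the other restores the original order.
    static bool composeToIdentity(llvm::ArrayRef<int32_t> perms1,
                                  llvm::ArrayRef<int32_t> perms2) {
      if (perms1.size() != perms2.size())
        return false;
      for (size_t i = 0, e = perms1.size(); i < e; ++i)
        if (perms2[perms1[i]] != static_cast<int32_t>(i))
          return false;
      return true;
    }
    // composeToIdentity({1, 2, 0}, {2, 0, 1}) -> true: the pair folds away.
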
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index f74a4b4c58b80..f2abb29b4fe66 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -56,15 +56,6 @@ static LogicalResult checkConstantOperandPad(Operation *op) {
   return success();
 }
 
-static LogicalResult checkConstantOpe...
[truncated]

This patch changes the perms operand for the Tosa Transpose operator to an i32 array attribute.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I6c54d203ee7eb8d77442ff29fdbff2d2e3e6950b
@lhutton1 (Contributor) left a comment:
LGTM - just had two minor non-blocking comments

@@ -223,7 +209,7 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32
%64 = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
nit: these constants seem no longer used

@@ -298,46 +275,41 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t
%shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
nit: above const no longer used

@Jerry-Ge Jerry-Ge merged commit 48db4e8 into llvm:main Feb 25, 2025
11 checks passed
yzhang93 added a commit to iree-org/iree that referenced this pull request Feb 26, 2025

There are some attribute changes upstream:

- TOSA: llvm/llvm-project#128115
- NVPTX: llvm/llvm-project#127736

---------

Signed-off-by: yzhang93 <zhyuhang88@gmail.com>
geomin12 pushed a commit to geomin12/iree that referenced this pull request Mar 5, 2025