diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a101552e419bc..44b0a4b26588f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1908,6 +1908,10 @@ void TransposeOp::getEffects( LogicalResult TransposeOp::fold(FoldAdaptor adaptor, SmallVectorImpl &result) { + // Only the tensor type is supported. + if (!isa(getInput().getType())) + return failure(); + // Single dimension transpose. if (getPermutation().size() == 0) { result.push_back(getInput()); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index a50fbb0fc3b86..4bc2ed140da91 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1216,3 +1216,19 @@ func.func @concats_of_fill( // CHECK: %[[CONCAT:.+]] = tensor.concat dim(1) %[[EMPTY0]], %[[EMPTY1]] // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[CONCAT]] : // CHECK: return %[[FILL]] + +// ----- + +func.func @transpose_buffer(%input: memref, + %init: memref) { + linalg.transpose ins(%input:memref) + outs(%init:memref) + permutation = [0] + func.return +} + +// CHECK-LABEL: func.func @transpose_buffer( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: linalg.transpose ins(%[[VAL_0]] : memref) +// CHECK-SAME: outs(%[[VAL_1]] : memref) permutation = [0] diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir index b818170a8e797..6ddbd06389f5e 100644 --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -873,3 +873,39 @@ func.func @lower_to_loops_with_rank_reducing_subviews( // CHECKPARALLEL: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]] // CHECKPARALLEL: memref.store %[[VAL]], %{{.+}}[%[[IV]]] // CHECKPARALLEL: } + +// ----- + +func.func @transpose(%input: memref, + %init: memref) { + linalg.transpose ins(%input:memref) + outs(%init:memref) + permutation = [0] + return +} +// CHECK-LABEL: func.func @transpose( +// CHECK-SAME: %[[VAL_0:.*]]: memref, +// CHECK-SAME: %[[VAL_1:.*]]: memref) { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref +// CHECK: scf.for %[[VAL_5:.*]] = %[[VAL_3]] to %[[VAL_4]] step %[[VAL_2]] { +// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref +// CHECK: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref +// CHECK: } +// CHECK: return +// CHECK: } + +// CHECKPARALLEL-LABEL: func.func @transpose( +// CHECKPARALLEL-SAME: %[[VAL_0:.*]]: memref, +// CHECKPARALLEL-SAME: %[[VAL_1:.*]]: memref) { +// CHECKPARALLEL: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECKPARALLEL: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECKPARALLEL: %[[VAL_4:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : memref +// CHECKPARALLEL: scf.parallel (%[[VAL_5:.*]]) = (%[[VAL_3]]) to (%[[VAL_4]]) step (%[[VAL_2]]) { +// CHECKPARALLEL: %[[VAL_6:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_5]]] : memref +// CHECKPARALLEL: memref.store %[[VAL_6]], %[[VAL_1]]{{\[}}%[[VAL_5]]] : memref +// CHECKPARALLEL: scf.reduce +// CHECKPARALLEL: } +// CHECKPARALLEL: return +// CHECKPARALLEL: }