[mlir][Vector] Add vector.shuffle tree transformation #145740
base: main
This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. (Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779).

Example:

```
%0:4 = vector.to_elements %a : vector<4xf32>
%1:4 = vector.to_elements %b : vector<4xf32>
%2:4 = vector.to_elements %c : vector<4xf32>
%3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32>

==>

%0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
%1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
%2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
```

The algorithm leverages the structured extraction/insertion information of `vector.to_elements` and `vector.from_elements` operations and builds a set of intervals to determine the vector length that should be used at each level of the tree.

There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.
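To make the interval idea concrete, here is a minimal standalone sketch (plain C++, no MLIR) of how the per-level intervals and vector lengths could be derived. It is not the code in this patch: `TreeConfig`, `computeTreeConfig`, and the `sources` encoding (for each output position, the ordinal of the input vector it reads, numbered by first appearance) are hypothetical names chosen for illustration. It also pads with a duplicate at every odd level, which is a simplifying assumption of the sketch.

```cpp
#include <algorithm>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Closed interval of output positions, e.g. [0, 7] covers 8 elements.
using Interval = std::pair<unsigned, unsigned>;

// Hypothetical container for the result of the analysis.
struct TreeConfig {
  std::vector<std::vector<Interval>> intervalsPerLevel;
  std::vector<unsigned> sizePerLevel;
};

// `sources[pos]` is the ordinal of the input vector that provides output
// element `pos`. Assumes numInputs >= 1 and that every input is used at
// least once (illustrative encoding, not the patch's data structures).
inline TreeConfig computeTreeConfig(const std::vector<unsigned> &sources,
                                    unsigned numInputs) {
  constexpr unsigned kUnset = std::numeric_limits<unsigned>::max();

  // Level 0: each input covers [first use, last use] in the output.
  std::vector<Interval> level(numInputs, Interval{kUnset, kUnset});
  for (unsigned pos = 0; pos < sources.size(); ++pos) {
    Interval &iv = level[sources[pos]];
    if (iv.first == kUnset)
      iv.first = pos;
    iv.second = pos;
  }

  TreeConfig config;
  while (true) {
    // Duplicate the last vector when the count is odd so inputs pair up.
    if (level.size() % 2 != 0)
      level.push_back(level.back());

    // Merge each pair; the level's vector length is the widest merged span.
    std::vector<Interval> next;
    unsigned width = 1;
    for (std::size_t i = 0; i < level.size(); i += 2) {
      Interval merged{std::min(level[i].first, level[i + 1].first),
                      std::max(level[i].second, level[i + 1].second)};
      width = std::max(width, merged.second - merged.first + 1);
      next.push_back(merged);
    }
    config.intervalsPerLevel.push_back(level);
    config.sizePerLevel.push_back(width);
    if (next.size() <= 1)
      break;
    level = std::move(next);
  }
  return config;
}
```

For the concatenation example above, `sources` would be `{0,0,0,0,1,1,1,1,2,2,2,2}` with three inputs, and the sketch yields one vector length per tree level, matching the `vector<8xf32>` intermediates and the `vector<12xf32>` result.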
@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. The example and the algorithm overview are the same as in the PR description above.

Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.to_elements` ops is odd.
+/// * The vector lengths for the first and second levels are 8 and 12,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.to_elements` ops is odd.
+/// * The vector lengths for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ VectorShuffleTreeBuilder() = delete;
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ FromElementsOp fromElementsOp;
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ unsigned numLevels;
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
+ toElementsDefs.push_back(toElementsDefs.back());
+ }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+ // Map `vector.to_elements` ops to their ordinal position in the
+ // `vector.from_elements` operand list. Make sure duplicated
+ // `vector.to_elements` ops are mapped to their first occurrence.
+ DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+ for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+ toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+ // Compute intervals for each input vector in the shuffle tree. The first
+ // level computation is special-cased to keep the implementation simpler.
+
+ SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (const auto &[idx, element] :
+ llvm::enumerate(fromElementsOp.getElements())) {
+ auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ Interval &currentInterval = firstLevelIntervals[inputIdx];
+
+ // Set lower bound to the first occurrence of the `vector.to_elements`.
+ if (currentInterval.first == kMaxUnsigned)
+ currentInterval.first = idx;
+
+ // Set upper bound to the last occurrence of the `vector.to_elements`.
+ currentInterval.second = idx;
+ }
+
+ // If the number of `vector.to_elements` is odd and the last op was
+ // duplicated, the interval for the duplicated op was not computed in the
+ // previous step as all the input occurrences were mapped to the original op.
+ // We copy the interval of the original op to the interval of the duplicated
+ // op manually.
+ if (firstLevelIntervals.back().second == kMaxUnsigned)
+ firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+ inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+ // Compute intervals for the remaining levels.
+ unsigned outputNumElements =
+ cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ for (unsigned level = 1; level < numLevels; ++level) {
+ const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+ SmallVector<Interval> currentLevelIntervals(
+ llvm::divideCeil(prevLevelIntervals.size(), 2),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+ ++inputIdx) {
+ auto &interval = currentLevelIntervals[inputIdx];
+ const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+ const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+ // The interval of a vector at the current level is the union of the
+ // intervals of the two input vectors from the previous level being
+ // shuffled at this level.
+ interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+ interval.second =
+ std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+ outputNumElements - 1);
+ }
+
+ inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+ }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// Intervals:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+///
+/// Vector sizes:
+/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+ // Compute vector size for each level.
+ for (unsigned level = 0; level < numLevels; ++level) {
+ const auto &currentLevelIntervals = inputIntervalsPerLevel[level];
+ unsigned currentVectorSize = 1;
+ for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+ const auto &lhsInterval = currentLevelIntervals[i];
+ const auto &rhsInterval = currentLevelIntervals[i + 1];
+ unsigned combinedIntervalSize =
+ std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+ 1;
+ currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+ }
+ vectorSizePerLevel[level] = currentVectorSize;
+ }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+ LLVM_DEBUG({
+ unsigned indLv = 0;
+
+ llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+ ++indLv;
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+ ++indLv;
+ for (const auto &toElementsOp : toElementsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ --indLv;
+
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Total levels: " << numLevels << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Vector sizes per level: [";
+ llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Input intervals per level:\n";
+ ++indLv;
+ for (const auto &[level, intervals] :
+ llvm::enumerate(inputIntervalsPerLevel)) {
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+ << ": ";
+ llvm::interleaveComma(intervals, llvm::dbgs(),
+ [](const Interval &interval) {
+ llvm::dbgs() << "[" << interval.first << ","
+ << interval.second << "]";
+ });
+ llvm::dbgs() << "\n";
+ }
+ });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// build a tree that looks like:
+///
+/// %2 %1 %0 %0
+/// \ / \ /
+/// %2_1 = vector.shuffle %0_0 = vector.shuffle
+/// \ /
+/// %2_1_0_0 =vector.shuffle
+///
+/// The configuration consists of the intervals of the input vectors at each
+/// level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details
+/// on the interval and output vector size computation, see the corresponding
+/// utility functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+ // Initialize shuffle tree information based on its size.
+ assert(toElementsDefs.size() > 1 &&
+ "At least two 'vector.to_elements' ops are required");
+ numLevels = llvm::Log2_64_Ceil(toElementsDefs.size());
+ vectorSizePerLevel.resize(numLevels, 0);
+ inputIntervalsPerLevel.reserve(numLevels);
+
+ computeInputVectorIntervals();
+ computeOutputVectorSizePerLevel();
+ dump();
+
+ return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// =>
+///
+/// // Level 0, vector length = 8
+/// %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]
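The diff is cut off before the mask computation is shown, so the following is only a sketch reconstructed from the documented Example #2, not the code in this patch. It illustrates one way a first-level shuffle mask could be derived: each output position that reads from one of the two paired inputs is placed at its offset from the start of the pair's interval, and every other slot stays poison (`-1`). `ElementRef`, `computeLeafMask`, and all parameter names are hypothetical.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// One operand of `vector.from_elements`: the ordinal of the input vector it
// reads and the element index within that vector (illustrative encoding).
struct ElementRef {
  unsigned source;
  unsigned element;
};

// Mask value that `vector.shuffle` treats as poison.
constexpr int64_t kPoison = -1;

// Builds the mask for shuffling inputs `lhs` and `rhs` at the first level.
// `start` is the lower bound of the pair's combined interval and `width` is
// the vector length chosen for the level. Elements taken from `rhs` are
// offset by `lhsSize`, following `vector.shuffle` indexing of its operands.
inline std::vector<int64_t>
computeLeafMask(const std::vector<ElementRef> &outputs, unsigned lhs,
                unsigned rhs, unsigned lhsSize, unsigned start,
                unsigned width) {
  std::vector<int64_t> mask(width, kPoison);
  for (std::size_t pos = 0; pos < outputs.size(); ++pos) {
    const ElementRef &ref = outputs[pos];
    if (ref.source != lhs && ref.source != rhs)
      continue; // Produced by another pair of inputs.
    if (pos < start || pos - start >= width)
      continue; // Outside the slots this shuffle is responsible for.
    // When lhs == rhs (duplicated input), elements are read from lhs.
    mask[pos - start] =
        ref.source == lhs ? ref.element : lhsSize + ref.element;
  }
  return mask;
}
```

With the operands of Example #2 encoded as `{%c = 0, %b = 1, %a = 2}`, calling the sketch with `(lhs = 0, rhs = 1, lhsSize = 5, start = 0, width = 8)` and with `(lhs = 2, rhs = 2, lhsSize = 5, start = 2, width = 8)` reproduces the `%shuffle0` and `%shuffle1` masks listed in the class comment.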
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds a new transformation that turns sequences of Example:
The algorithm leverages the structured extraction/insertion information of There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along. Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
/// n > 1.
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
];
}
+def LowerVectorToFromElementsToShuffleTree
+ : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+ let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorScan.cpp
LowerVectorShapeCast.cpp
LowerVectorStep.cpp
+ LowerVectorToFromElementsToShuffleTree.cpp
LowerVectorTransfer.cpp
LowerVectorTranspose.cpp
SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+/// 1. Vectors are shuffled in pairs by order of appearance in
+/// the `vector.from_elements` operand list.
+/// 2. Each input vector to each level is used only once.
+/// 3. The number of levels in the tree is:
+/// ceil(log2(# `vector.to_elements` ops)).
+/// 4. Vectors at each level of the tree have the same vector length.
+/// 5. Vector positions that do not need to be shuffled are represented with
+/// poison in the shuffle mask.
+///
+/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+/// %0:4 = vector.to_elements %a : vector<4xf32>
+/// %1:4 = vector.to_elements %b : vector<4xf32>
+/// %2:4 = vector.to_elements %c : vector<4xf32>
+/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+/// : vector<12xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+/// : vector<4xf32>, vector<4xf32>
+/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+/// : vector<4xf32>, vector<4xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+/// 6, 7, 8, 9, 10, 11]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * The shuffle tree has two levels:
+/// - Level 1 = (%shuffle0, %shuffle1)
+/// - Level 2 = (%result)
+/// * `%a` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%0#0` and `%1#0`).
+/// * `%c` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 16,
+/// respectively.
+/// * `%shuffle1` uses poison values to match the vector length of its
+/// tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+/// =>
+///
+/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6]
+/// : vector<5xf32>, vector<5xf32>
+/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1]
+/// : vector<5xf32>, vector<5xf32>
+/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+/// : vector<8xf32>, vector<8xf32>
+///
+/// Comments:
+/// * `%c` and `%b` are shuffled first because they appear first in the
+/// `vector.from_elements` operand list (`%2#2` and `%1#1`).
+/// * `%a` is shuffled with itself because the number of
+/// `vector.from_elements` operands is odd.
+/// * The vector length for the first and second levels are 8 and 9,
+/// respectively.
+/// * `%shuffle0` uses poison values to mark unused vector positions and
+/// match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+ VectorShuffleTreeBuilder() = delete;
+ VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+ ArrayRef<ToElementsOp> toElemDefs);
+
+ /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+ /// and compute the shuffle tree configuration. This method does not generate
+ /// any IR.
+ LogicalResult computeShuffleTree();
+
+ /// Materialize the shuffle tree configuration computed by
+ /// `computeShuffleTree` in the IR.
+ Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+ // IR input information.
+ FromElementsOp fromElementsOp;
+ SmallVector<ToElementsOp> toElementsDefs;
+
+ // Shuffle tree configuration.
+ unsigned numLevels;
+ SmallVector<unsigned> vectorSizePerLevel;
+ /// Holds the range of positions in the final output that each vector input
+ /// in the tree is contributing to.
+ SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+ // Utility methods to compute the shuffle tree configuration.
+ void computeInputVectorIntervals();
+ void computeOutputVectorSizePerLevel();
+
+ /// Dump the shuffle tree configuration.
+ void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+ FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+ : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+ assert(fromElementsOp && "from_elements op is required");
+ assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+ // Duplicate the last vector if the number of `vector.to_elements` is odd to
+ // simplify the shuffle tree algorithm.
+ if (toElementsDefs.size() % 2 != 0) {
+ toElementsDefs.push_back(toElementsDefs.back());
+ }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+ // Map `vector.to_elements` ops to their ordinal position in the
+ // `vector.from_elements` operand list. Make sure duplicated
+ // `vector.to_elements` ops are mapped to their first occurrence.
+ DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+ for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+ toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+ // Compute intervals for each input vector in the shuffle tree. The first
+ // level computation is special-cased to keep the implementation simpler.
+
+ SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (const auto &[idx, element] :
+ llvm::enumerate(fromElementsOp.getElements())) {
+ auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+ unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+ Interval &currentInterval = firstLevelIntervals[inputIdx];
+
+ // Set lower bound to the first occurrence of the `vector.to_elements`.
+ if (currentInterval.first == kMaxUnsigned)
+ currentInterval.first = idx;
+
+ // Set upper bound to the last occurrence of the `vector.to_elements`.
+ currentInterval.second = idx;
+ }
+
+ // If the number of `vector.to_elements` is odd and the last op was
+ // duplicated, the interval for the duplicated op was not computed in the
+ // previous step as all the input occurrences were mapped to the original op.
+ // We copy the interval of the original op to the interval of the duplicated
+ // op manually.
+ if (firstLevelIntervals.back().second == kMaxUnsigned)
+ firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+ inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+ // Compute intervals for the remaining levels.
+ unsigned outputNumElements =
+ cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+ for (unsigned level = 1; level < numLevels; ++level) {
+ const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+ SmallVector<Interval> currentLevelIntervals(
+ llvm::divideCeil(prevLevelIntervals.size(), 2),
+ {kMaxUnsigned, kMaxUnsigned});
+
+ for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+ ++inputIdx) {
+ auto &interval = currentLevelIntervals[inputIdx];
+ const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+ const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+ // The interval of a vector at the current level is the union of the
+ // intervals of the two input vectors from the previous level being
+ // shuffled at this level.
+ interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+ interval.second =
+ std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+ outputNumElements - 1);
+ }
+
+ inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+ }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// Intervals:
+/// * Level 0: [0,6], [1,7], [2,8], [2,8]
+/// * Level 1: [0,7], [2,8]
+///
+/// Vector sizes:
+/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+ // Compute vector size for each level.
+ for (unsigned level = 0; level < numLevels; ++level) {
+ const auto &currentLevelIntervals = inputIntervalsPerLevel[level];
+ unsigned currentVectorSize = 1;
+ for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+ const auto &lhsInterval = currentLevelIntervals[i];
+ const auto &rhsInterval = currentLevelIntervals[i + 1];
+ unsigned combinedIntervalSize =
+ std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+ 1;
+ currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+ }
+ vectorSizePerLevel[level] = currentVectorSize;
+ }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+ LLVM_DEBUG({
+ unsigned indLv = 0;
+
+ llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+ ++indLv;
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+ ++indLv;
+ for (const auto &toElementsOp : toElementsDefs)
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+ --indLv;
+
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Total levels: " << numLevels << "\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Vector sizes per level: [";
+ llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ llvm::dbgs() << llvm::indent(indLv, kIndScale)
+ << "* Input intervals per level:\n";
+ ++indLv;
+ for (const auto &[level, intervals] :
+ llvm::enumerate(inputIntervalsPerLevel)) {
+ llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+ << ": ";
+ llvm::interleaveComma(intervals, llvm::dbgs(),
+ [](const Interval &interval) {
+ llvm::dbgs() << "[" << interval.first << ","
+ << interval.second << "]";
+ });
+ llvm::dbgs() << "\n";
+ }
+ });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// build a tree that looks like:
+///
+/// %2 %1 %0 %0
+/// \ / \ /
+/// %2_1 = vector.shuffle %0_0 = vector.shuffle
+/// \ /
+/// %2_1_0_0 = vector.shuffle
+///
+/// The configuration consists of the intervals of the input vectors at each
+/// level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and %2_1_0_0)
+/// and the output vector size for each level. For further details on the
+/// interval and output vector size computation, see the corresponding utility
+/// functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+ // Initialize shuffle tree information based on its size.
+ assert(toElementsDefs.size() > 1 &&
+ "At least two 'vector.to_elements' ops are required");
+ numLevels = llvm::Log2_64(toElementsDefs.size());
+ vectorSizePerLevel.resize(numLevels, 0);
+ inputIntervalsPerLevel.reserve(numLevels);
+
+ computeInputVectorIntervals();
+ computeOutputVectorSizePerLevel();
+ dump();
+
+ return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+/// %0:5 = vector.to_elements %a : vector<5xf32>
+/// %1:5 = vector.to_elements %b : vector<5xf32>
+/// %2:5 = vector.to_elements %c : vector<5xf32>
+/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// =>
+///
+/// // Level 0, vector length = 8
+/// %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]
Thanks for the nice documentation! I think I get the basic idea, but I need to spend some more time getting into the details. Possible edge case to test out:
func.func @foo(
%a : vector<2xf32>,
%b : vector<1xf32>,
%c : vector<f32>,
%d : vector<f32>,
%e : vector<f32>) -> vector<6xf32> {
%0:2 = vector.to_elements %a : vector<2xf32>
%1:1 = vector.to_elements %b : vector<1xf32>
%2:1 = vector.to_elements %c : vector<f32>
%3:1 = vector.to_elements %d : vector<f32>
%4:1 = vector.to_elements %e : vector<f32>
%5 = vector.from_elements %0#0, %1#0, %2#0, %3#0, %4#0, %0#1 : vector<6xf32>
return %5 : vector<6xf32>
}
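To see what the interval analysis would produce for this edge case, here is a minimal standalone sketch in plain C++ with no MLIR dependencies. It assumes each `vector.from_elements` operand can be summarized by the index of the input vector that defines it; `computeFirstLevelIntervals` and `kUnset` are hypothetical names for illustration, not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Closed interval [first, second] of output positions.
using Interval = std::pair<unsigned, unsigned>;

// Hypothetical sentinel marking an interval bound that has not been set yet.
constexpr unsigned kUnset = std::numeric_limits<unsigned>::max();

// For every input vector, record the first and last output position it
// contributes to. `operandSources[i]` is the index of the input vector that
// defines the i-th `vector.from_elements` operand (an assumption of this
// sketch; the PR derives this from the defining `vector.to_elements` op).
std::vector<Interval>
computeFirstLevelIntervals(const std::vector<unsigned> &operandSources,
                           std::size_t numInputs) {
  std::vector<Interval> intervals(numInputs, Interval{kUnset, kUnset});
  for (std::size_t pos = 0; pos < operandSources.size(); ++pos) {
    Interval &interval = intervals[operandSources[pos]];
    // The lower bound is the first occurrence of the input.
    if (interval.first == kUnset)
      interval.first = static_cast<unsigned>(pos);
    // The upper bound is the last occurrence of the input.
    interval.second = static_cast<unsigned>(pos);
  }
  return intervals;
}
```

For the function above, the operand sources would be `{0, 1, 2, 3, 4, 0}`: `%a` gets the interval [0,5] and every other input gets a single-position interval. Under the PR's size rule, the first pair (`%a`, `%b`) would then already span all 6 output positions.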
  LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = fromElementsOp.getType();
    if (resultType.getRank() != 1 || resultType.isScalable())
Not related to this PR, but this rank check got me wondering. I would like to propose removing the implicit ability to do a shape_cast out of `vector.to_elements` and `vector.from_elements` operations, so that they must act on rank-1 vectors. Actually, I've thought this before for other Vector ops that do reshape-like things.
The check here comes from the limitations of `vector.shuffle` to represent n-D shuffles, not really from `vector.to_/from_elements`. That limitation is actually more like a TODO that we should address at some point. `vector.to_/from_elements` semantics naturally extend to n-D vectors given the extraction/insertion order they define but, yes, I guess we could see it as an "implicit shape cast"...
I think, though, we've been moving in the opposite direction. To have a cohesive multi-dimensional vector layer we need these "implicit shape casts" so that ops work nicely across the board without having to special-case 1-D from n-D... This supports even more the idea that shape casts are really no-ops...
  // Duplicate the last vector if the number of `vector.to_elements` is odd to
  // simplify the shuffle tree algorithm.
  if (toElementsDefs.size() % 2 != 0) {
Should this be a check that it is a power of 2?
This check refers to the number of `vector.to_elements` inputs to combine, so we want to be able to combine an arbitrary number of inputs. If that number is not even, we duplicate the last input to simplify the algorithm (the shuffle for that input would have the same input vector twice). Does that make sense?
         ++inputIdx) {
      auto &interval = currentLevelIntervals[inputIdx];
      const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
      const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
Related to power-of-2 comment: If previous level here had 3 intervals, current level has 2. If inputIdx = 1 here, you're accessing index 3 of previous intervals -- problem? That's why I think it might be necessary to ensure the number of starting intervals is a power of 2 (stricter than just being even).
Good catch! I thought I had a check to duplicate the last input, similar to the one in the constructor, but I must have removed it at some point. Let me fix that.
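One possible shape for such a fix, sketched as standalone C++ rather than the author's actual change: pad each level to an even number of intervals before pairing them up, and count levels by iterating instead of taking a logarithm. `mergeLevel` and `countLevels` are hypothetical helper names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Closed interval [first, second] of output positions.
using Interval = std::pair<unsigned, unsigned>;

// Combine the intervals of one level pairwise into the next level. If the
// level has an odd number of intervals, the last one is duplicated first so
// that index i + 1 is always in bounds.
std::vector<Interval> mergeLevel(std::vector<Interval> level) {
  if (level.size() % 2 != 0)
    level.push_back(level.back());
  std::vector<Interval> next;
  next.reserve(level.size() / 2);
  for (std::size_t i = 0; i < level.size(); i += 2)
    next.push_back({std::min(level[i].first, level[i + 1].first),
                    std::max(level[i].second, level[i + 1].second)});
  return next;
}

// Number of shuffle levels needed to reduce `numInputs` vectors to one when
// every level is padded to an even size before pairing.
unsigned countLevels(std::size_t numInputs) {
  unsigned levels = 0;
  for (; numInputs > 1; numInputs = (numInputs + 1) / 2)
    ++levels;
  return levels;
}
```

With five inputs, the levels hold 5 (padded to 6), 3 (padded to 4) and 2 intervals, so `countLevels(5)` is 3. A floor-based logarithm of the padded input count 6 would give 2, which is one way the level count and the per-level padding could fall out of sync.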
Thanks a lot! Happy to clarify any questions you may have!