diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 14cff4ff893b5..6761cd65c5009 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns, /// n > 1. void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); +/// Populate patterns to rewrite sequences of `vector.to_elements` + +/// `vector.from_elements` operations into a tree of `vector.shuffle` +/// operations. +void populateVectorToFromElementsToShuffleTreePatterns( + RewritePatternSet &patterns, PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir + #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h index 5667f4fa95ace..959c2fbf31f1a 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td index 7436998749791..9431a4d8e240f 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td @@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func ]; } +def LowerVectorToFromElementsToShuffleTree + : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> { + let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations"; +} + #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 8ca5cb6c6dfab..9e287fc109990 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorScan.cpp LowerVectorShapeCast.cpp LowerVectorStep.cpp + LowerVectorToFromElementsToShuffleTree.cpp LowerVectorTransfer.cpp LowerVectorTranspose.cpp SubsetOpInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp new file mode 100644 index 0000000000000..53728d6dbe2a3 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -0,0 +1,692 @@ +//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements pattern rewrites to lower sequences of +// `vector.to_elements` and `vector.from_elements` operations into a tree of +// `vector.shuffle` operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace vector { + +#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE +#include "mlir/Dialect/Vector/Transforms/Passes.h.inc" + +} // namespace vector +} // namespace mlir + +#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree" + +using namespace mlir; +using namespace mlir::vector; + +namespace { + +// Indentation unit for debug output formatting. +constexpr unsigned kIndScale = 2; + +/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements). +using Interval = std::pair; +// Sentinel value for uninitialized intervals. +constexpr unsigned kMaxUnsigned = std::numeric_limits::max(); + +/// The VectorShuffleTreeBuilder builds a balanced binary tree of +/// `vector.shuffle` operations from one or more `vector.to_elements` +/// operations feeding a single `vector.from_elements` operation. +/// +/// The implementation generates hardware-agnostic `vector.shuffle` operations +/// that minimize both the number of shuffle operations and the length of +/// intermediate vectors (to the extent possible). The tree has the +/// following properties: +/// +/// 1. Vectors are shuffled in pairs by order of appearance in +/// the `vector.from_elements` operand list. +/// 2. Each input vector to each level is used only once. +/// 3. The number of levels in the tree is: +/// ceil(log2(# `vector.to_elements` ops)). +/// 4. Vectors at each level of the tree have the same vector length. +/// 5. Vector positions that do not need to be shuffled are represented with +/// poison in the shuffle mask. +/// +/// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>: +/// +/// %0:4 = vector.to_elements %a : vector<4xf32> +/// %1:4 = vector.to_elements %b : vector<4xf32> +/// %2:4 = vector.to_elements %c : vector<4xf32> +/// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, +/// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 +/// : vector<12xf32> +/// => +/// +/// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] +/// : vector<4xf32>, vector<4xf32> +/// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] +/// : vector<4xf32>, vector<4xf32> +/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5, +/// 6, 7, 8, 9, 10, 11] +/// : vector<8xf32>, vector<8xf32> +/// +/// Comments: +/// * The shuffle tree has two levels: +/// - Level 1 = (%shuffle0, %shuffle1) +/// - Level 2 = (%result) +/// * `%a` and `%b` are shuffled first because they appear first in the +/// `vector.from_elements` operand list (`%0#0` and `%1#0`). +/// * `%c` is shuffled with itself because the number of +/// `vector.from_elements` operands is odd. +/// * The vector length for the first and second levels are 8 and 16, +/// respectively. +/// * `%shuffle1` uses poison values to match the vector length of its +/// tree level (8). +/// +/// +/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// => +/// +/// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] +/// : vector<5xf32>, vector<5xf32> +/// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] +/// : vector<5xf32>, vector<5xf32> +/// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14] +/// : vector<8xf32>, vector<8xf32> +/// +/// Comments: +/// * `%c` and `%b` are shuffled first because they appear first in the +/// `vector.from_elements` operand list (`%2#2` and `%1#1`). +/// * `%a` is shuffled with itself because the number of +/// `vector.from_elements` operands is odd. +/// * The vector length for the first and second levels are 8 and 9, +/// respectively. +/// * `%shuffle0` uses poison values to mark unused vector positions and +/// match the vector length of its tree level (8). +/// +/// TODO: Implement mask compression to reduce the number of intermediate poison +/// values. +/// +class VectorShuffleTreeBuilder { +public: + VectorShuffleTreeBuilder() = delete; + VectorShuffleTreeBuilder(FromElementsOp fromElemOp, + ArrayRef toElemDefs); + + /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence + /// and compute the shuffle tree configuration. This method does not generate + /// any IR. + LogicalResult computeShuffleTree(); + + /// Materialize the shuffle tree configuration computed by + /// `computeShuffleTree` in the IR. + Value generateShuffleTree(PatternRewriter &rewriter); + +private: + // IR input information. + FromElementsOp fromElementsOp; + SmallVector toElementsDefs; + + // Shuffle tree configuration. + unsigned numLevels; + SmallVector vectorSizePerLevel; + /// Holds the range of positions in the final output that each vector input + /// in the tree is contributing to. + SmallVector> inputIntervalsPerLevel; + + // Utility methods to compute the shuffle tree configuration. + void computeInputVectorIntervals(); + void computeOutputVectorSizePerLevel(); + + /// Dump the shuffle tree configuration. + void dump(); +}; + +VectorShuffleTreeBuilder::VectorShuffleTreeBuilder( + FromElementsOp fromElemOp, ArrayRef toElemDefs) + : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) { + + assert(fromElementsOp && "from_elements op is required"); + assert(!toElementsDefs.empty() && "At least one to_elements op is required"); + + // Duplicate the last vector if the number of `vector.to_elements` is odd to + // simplify the shuffle tree algorithm. + if (toElementsDefs.size() % 2 != 0) { + toElementsDefs.push_back(toElementsDefs.back()); + } +} + +// ===--------------------------------------------------------------------===// +// Shuffle Tree Analysis Utilities. +// ===--------------------------------------------------------------------===// + +/// Compute the intervals for all the input vectors in the shuffle tree. The +/// interval of an input vector is the range of positions in the final output +/// that the input vector contributes to. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the +/// number of inputs even) so we compute the interval for each input vector: +/// +/// * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6] +/// * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7] +/// * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8] +/// * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8] +/// +/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so +/// we compute the intervals for each input vector to level 1 as: +/// * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7] +/// * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8] +/// +void VectorShuffleTreeBuilder::computeInputVectorIntervals() { + // Map `vector.to_elements` ops to their ordinal position in the + // `vector.from_elements` operand list. Make sure duplicated + // `vector.to_elements` ops are mapped to the its first occurrence. + DenseMap toElementsToInputOrdinal; + for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs)) + toElementsToInputOrdinal.insert({toElementsOp, idx}); + + // Compute intervals for each input vector in the shuffle tree. The first + // level computation is special-cased to keep the implementation simpler. + + SmallVector firstLevelIntervals(toElementsDefs.size(), + {kMaxUnsigned, kMaxUnsigned}); + + for (const auto &[idx, element] : + llvm::enumerate(fromElementsOp.getElements())) { + auto toElementsOp = cast(element.getDefiningOp()); + unsigned inputIdx = toElementsToInputOrdinal[toElementsOp]; + Interval ¤tInterval = firstLevelIntervals[inputIdx]; + + // Set lower bound to the first occurrence of the `vector.to_elements`. + if (currentInterval.first == kMaxUnsigned) + currentInterval.first = idx; + + // Set upper bound to the last occurrence of the `vector.to_elements`. + currentInterval.second = idx; + } + + // If the number of `vector.to_elements` is odd and the last op was + // duplicated, the interval for the duplicated op was not computed in the + // previous step as all the input occurrences were mapped to the original op. + // We copy the interval of the original op to the interval of the duplicated + // op manually. + if (firstLevelIntervals.back().second == kMaxUnsigned) + firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2); + + inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals)); + + // Compute intervals for the remaining levels. + unsigned outputNumElements = + cast(fromElementsOp.getResult().getType()).getNumElements(); + for (unsigned level = 1; level < numLevels; ++level) { + const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1]; + SmallVector currentLevelIntervals( + llvm::divideCeil(prevLevelIntervals.size(), 2), + {kMaxUnsigned, kMaxUnsigned}); + + for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size(); + ++inputIdx) { + auto &interval = currentLevelIntervals[inputIdx]; + const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; + const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; + + // The interval of a vector at the current level is the union of the + // intervals of the two input vectors from the previous level being + // shuffled at this level. + interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first); + interval.second = + std::min(std::max(prevLhsInterval.second, prevRhsInterval.second), + outputNumElements - 1); + } + + inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals)); + } +} + +/// Compute the uniform output vector size for each level of the shuffle tree, +/// given the intervals of the input vectors at that level. The output vector +/// size of a level is the size of the widest interval resulting from shuffling +/// each pair of input vectors. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// Intervals: +/// * Level 0: [0,6], [1,7], [2,8], [2,8] +/// * Level 1: [0,7], [2,8] +/// +/// Vector sizes: +/// * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8, +/// size_of([2,8] U [2,8] = [2,8]) = 7) = 8 +/// +/// * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9 +/// +void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() { + // Compute vector size for each level. + for (unsigned level = 0; level < numLevels; ++level) { + const auto ¤tLevelIntervals = inputIntervalsPerLevel[level]; + unsigned currentVectorSize = 1; + for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) { + const auto &lhsInterval = currentLevelIntervals[i]; + const auto &rhsInterval = currentLevelIntervals[i + 1]; + unsigned combinedIntervalSize = + std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first + + 1; + currentVectorSize = std::max(currentVectorSize, combinedIntervalSize); + } + vectorSizePerLevel[level] = currentVectorSize; + } +} + +void VectorShuffleTreeBuilder::dump() { + LLVM_DEBUG({ + unsigned indLv = 0; + + llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n"; + ++indLv; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n"; + ++indLv; + for (const auto &toElementsOp : toElementsDefs) + llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n"; + --indLv; + + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Total levels: " << numLevels << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Vector sizes per level: ["; + llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs()); + llvm::dbgs() << "]\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Input intervals per level:\n"; + ++indLv; + for (const auto &[level, intervals] : + llvm::enumerate(inputIntervalsPerLevel)) { + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level + << ": "; + llvm::interleaveComma(intervals, llvm::dbgs(), + [](const Interval &interval) { + llvm::dbgs() << "[" << interval.first << "," + << interval.second << "]"; + }); + llvm::dbgs() << "\n"; + } + }); +} + +/// Compute the shuffle tree configuration for the given `vector.to_elements` + +/// `vector.from_elements` input sequence. This method builds a balanced binary +/// shuffle tree that combines pairs of input vectors at each level. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// build a tree that looks like: +/// +/// %2 %1 %0 %0 +/// \ / \ / +/// %2_1 = vector.shuffle %0_0 = vector.shuffle +/// \ / +/// %2_1_0_0 =vector.shuffle +/// +/// The configuration comprises of computing the intervals of the input vectors +/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and +/// %2_1_0_0) and the output vector size for each level. For further details on +/// intervals and output vector size computation, please, take a look at the +/// corresponding utility functions. +LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { + // Initialize shuffle tree information based on its size. + assert(toElementsDefs.size() > 1 && + "At least two 'vector.to_elements' ops are required"); + numLevels = llvm::Log2_64(toElementsDefs.size()); + vectorSizePerLevel.resize(numLevels, 0); + inputIntervalsPerLevel.reserve(numLevels); + + computeInputVectorIntervals(); + computeOutputVectorSizePerLevel(); + dump(); + + return success(); +} + +// ===--------------------------------------------------------------------===// +// Shuffle Tree Code Generation Utilities. +// ===--------------------------------------------------------------------===// + +/// Compute the permutation mask for shuffling two input `vector.to_elements` +/// ops. The permutation mask is the mapping of the input vector elements to +/// their final position in the output vector, relative to the intermediate +/// output vector of the `vector.shuffle` operation combining the two inputs. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// => +/// +/// // Level 0, vector length = 8 +/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] +/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] +/// +/// TODO: Implement mask compression. +static SmallVector computePermutationShuffleMask( + ToElementsOp toElementOp0, const Interval &interval0, + ToElementsOp toElementOp1, const Interval &interval1, + FromElementsOp fromElementsOp, unsigned outputVectorSize) { + SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); + unsigned inputVectorSize = + toElementOp0.getSource().getType().getNumElements(); + + for (const auto &[inputIdx, element] : + llvm::enumerate(fromElementsOp.getElements())) { + auto currentToElemOp = cast(element.getDefiningOp()); + // Match `vector.from_elements` operands to the two input ops. + if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1) + continue; + + // The permutation value for a particular operand is the ordinal position of + // the operand in the `vector.to_elements` list of results. + unsigned permVal = cast(element).getResultNumber(); + unsigned maskIdx = inputIdx; + + // The mask index is the ordinal position of the operand in + // `vector.from_elements` operand list. We make this position relative to + // the interval of the output vector resulting from combining the two + // input vectors. + if (currentToElemOp == toElementOp0) { + maskIdx -= interval0.first; + } else { + // currentToElemOp == toElementOp1 + unsigned intervalOffset = interval1.first - interval0.first; + maskIdx += intervalOffset - interval1.first; + permVal += inputVectorSize; + } + + mask[maskIdx] = permVal; + } + + LLVM_DEBUG({ + unsigned indLv = 1; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: ["; + llvm::interleaveComma(mask, llvm::dbgs()); + llvm::dbgs() << "]\n"; + ++indLv; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Combining: " << toElementOp0 << " and " << toElementOp1 + << "\n"; + }); + + return mask; +} + +/// Compute the propagation shuffle mask for combining two intermediate shuffle +/// operations of the tree. The propagation shuffle mask is the mapping of the +/// intermediate vector elements, which have already been shuffled to their +/// relative output position using the mask generated by +/// `computePermutationShuffleMask`, to their next position in the tree. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// // Level 0, vector length = 8 +/// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] +/// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] +/// +/// => +/// +/// // Level 1, vector length = 9 +/// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14] +/// +/// TODO: Implement mask compression. +/// +static SmallVector computePropagationShuffleMask( + ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp, + const Interval &rhsInterval, unsigned outputVectorSize) { + ArrayRef lhsShuffleMask = lhsShuffleOp.getMask(); + ArrayRef rhsShuffleMask = rhsShuffleOp.getMask(); + unsigned inputVectorSize = lhsShuffleMask.size(); + assert(inputVectorSize == rhsShuffleMask.size() && + "Expected both shuffle masks to have the same size"); + + unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first; + SmallVector mask(outputVectorSize, ShuffleOp::kPoisonIndex); + + // Propagate any element from the input mask that is not poison. For the RHS + // input vector, the mask index is offset by the offset between the two + // intervals of the input vectors. + for (unsigned i = 0; i < inputVectorSize; ++i) { + if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) + mask[i] = i; + + unsigned rhsIdx = i + lhsRhsOffset; + if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) { + assert(rhsIdx < outputVectorSize && "RHS index out of bounds"); + assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set"); + mask[rhsIdx] = i + inputVectorSize; + } + } + + LLVM_DEBUG({ + unsigned indLv = 1; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* Propagation shuffle mask computation:\n"; + ++indLv; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* LHS shuffle op: " << lhsShuffleOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) + << "* RHS shuffle op: " << rhsShuffleOp << "\n"; + llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: ["; + llvm::interleaveComma(mask, llvm::dbgs()); + llvm::dbgs() << "]\n"; + }); + + return mask; +} + +/// Materialize the pre-computed shuffle tree configuration in the IR by +/// generating the corresponding `vector.shuffle` ops. +/// +/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: +/// +/// %0:5 = vector.to_elements %a : vector<5xf32> +/// %1:5 = vector.to_elements %b : vector<5xf32> +/// %2:5 = vector.to_elements %c : vector<5xf32> +/// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, +/// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> +/// +/// with the pre-computed shuffle tree configuration: +/// +/// * Vector sizes per level: [8, 9] +/// * Input intervals per level: +/// * Level 0: [0,6], [1,7], [2,8], [2,8] +/// * Level 1: [0,7], [2,8] +/// +/// => +/// +/// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6] +/// : vector<5xf32>, vector<5xf32> +/// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1] +/// : vector<5xf32>, vector<5xf32> +/// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14] +/// : vector<8xf32>, vector<8xf32> +/// +/// The code generation comprises of combining pairs of input vectors for each +/// level of the tree, using the pre-computed per tree level intervals and +/// vector sizes. The algorithm generates two kinds of shuffle masks: +/// permutation masks and propagation masks. Permutation masks are computed for +/// the first level of the tree and permute the input vector elements to their +/// relative position in the final output. Propagation masks are computed for +/// subsequent levels and propagate the elements to the next level without +/// permutation. For further details on the shuffle mask computation, please, +/// take a look at the corresponding `computePermutationShuffleMask` and +/// `computePropagationShuffleMask` functions. +/// +Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { + LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n"); + + // Initialize work list with the `vector.to_elements` sources. + SmallVector levelInputs; + llvm::transform( + toElementsDefs, std::back_inserter(levelInputs), + [](ToElementsOp toElementsOp) { return toElementsOp.getSource(); }); + + // Build shuffle tree by combining pairs of vectors. + Location loc = fromElementsOp.getLoc(); + unsigned currentLevel = 0; + for (const auto &[levelVectorSize, inputIntervals] : + llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) { + LLVM_DEBUG(llvm::dbgs() + << llvm::indent(1, kIndScale) << "* Processing level " + << currentLevel << " (vector size: " << levelVectorSize + << ", # inputs: " << levelInputs.size() << ")\n"); + + // Process level input vectors in pairs. + SmallVector levelOutputs; + for (size_t i = 0; i < levelInputs.size(); i += 2) { + Value lhsVector = levelInputs[i]; + Value rhsVector = levelInputs[i + 1]; + const Interval &lhsInterval = inputIntervals[i]; + const Interval &rhsInterval = inputIntervals[i + 1]; + + // For the first level of the tree, permute the vector elements to their + // relative position in the final output. For subsequent levels, we + // propagate the elements to the next level without permutation. + SmallVector shuffleMask; + if (currentLevel == 0) { + shuffleMask = computePermutationShuffleMask( + toElementsDefs[i], lhsInterval, toElementsDefs[i + 1], rhsInterval, + fromElementsOp, levelVectorSize); + } else { + auto lhsShuffleOp = cast(lhsVector.getDefiningOp()); + auto rhsShuffleOp = cast(rhsVector.getDefiningOp()); + shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval, + rhsShuffleOp, rhsInterval, + levelVectorSize); + } + + Value shuffleVal = rewriter.create( + loc, lhsVector, rhsVector, shuffleMask); + levelOutputs.push_back(shuffleVal); + } + + levelInputs = std::move(levelOutputs); + ++currentLevel; + } + + assert(levelInputs.size() == 1 && "Should have exactly one result"); + return levelInputs.front(); +} + +/// Gather and unique all the `vector.to_elements` operations that feed the +/// `vector.from_elements` operation. The `vector.to_elements` operations are +/// returned in order of appearance in the `vector.from_elements`'s operand +/// list. +static LogicalResult +getToElementsDefiningOps(FromElementsOp fromElementsOp, + SmallVectorImpl &toElementsDefs) { + SetVector toElementsDefsSet; + for (Value element : fromElementsOp.getElements()) { + auto toElementsOp = element.getDefiningOp(); + if (!toElementsOp) + return failure(); + toElementsDefsSet.insert(toElementsOp); + } + + toElementsDefs.assign(toElementsDefsSet.begin(), toElementsDefsSet.end()); + return success(); +} + +/// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into +/// a tree of `vector.shuffle` operations. +struct ToFromElementsToShuffleTreeRewrite final + : OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp, + PatternRewriter &rewriter) const override { + VectorType resultType = fromElementsOp.getType(); + if (resultType.getRank() != 1 || resultType.isScalable()) + return failure(); + + SmallVector toElementsDefs; + if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs))) + return failure(); + + // Avoid generating a shuffle tree for trivial `vector.to_elements` -> + // `vector.from_elements` forwarding cases that do not require shuffling. + if (toElementsDefs.size() == 1) { + ToElementsOp toElementsOp0 = toElementsDefs.front(); + if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults())) + return failure(); + } + + VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs); + if (failed(shuffleTreeBuilder.computeShuffleTree())) + return failure(); + + Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter); + rewriter.replaceOp(fromElementsOp, finalShuffle); + return success(); + } +}; + +struct LowerVectorToFromElementsToShuffleTreePass + : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase< + LowerVectorToFromElementsToShuffleTreePass> { + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorToFromElementsToShuffleTreePatterns(patterns); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir new file mode 100644 index 0000000000000..3dc579be12f0f --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir @@ -0,0 +1,329 @@ +// RUN: mlir-opt -lower-vector-to-from-elements-to-shuffle-tree -split-input-file %s | FileCheck %s + +// Captured variable names for `vector.shuffle` operations follow the L#SH# convention, +// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to +// the shuffle index within that level. + +func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32> + return %1 : vector<8xf32> +} + +// CHECK-LABEL: func @to_from_elements_single_input_shuffle( +// CHECK-SAME: %[[A:.*]]: vector<8xf32> + // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32> + // CHECK: return %[[L0SH0]] + +// ----- + +func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>, + %b: vector<8xf32>) -> vector<8xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1:8 = vector.to_elements %b : vector<8xf32> + %2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32> + return %2 : vector<8xf32> +} + +// CHECK-LABEL: func @from_elements_to_elements_single_shuffle( +// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [7, 8, 6, 9, 5, 10, 4, 11] : vector<8xf32> +// CHECK: return %[[L0SH0]] + +// ----- + +func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>, + %b: vector<8xf32>, + %c: vector<8xf32>, + %d: vector<8xf32>) -> vector<32xf32> { + %0:8 = vector.to_elements %a : vector<8xf32> + %1:8 = vector.to_elements %b : vector<8xf32> + %2:8 = vector.to_elements %c : vector<8xf32> + %3:8 = vector.to_elements %d : vector<8xf32> + %4 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, + %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7, + %2#0, %2#1, %2#2, %2#3, %2#4, %2#5, %2#6, %2#7, + %3#0, %3#1, %3#2, %3#3, %3#4, %3#5, %3#6, %3#7 : vector<32xf32> + return %4 : vector<32xf32> +} + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32( +// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: return %[[L1SH0]] : vector<32xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>, + %b: vector<4xf32>, + %c: vector<4xf32>) -> vector<12xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32> + return %3 : vector<12xf32> +} + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> +// CHECK: return %[[L1SH0]] : vector<12xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_concat_64x4_256( + %a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>, + %e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>, + %i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>, + %m: vector<4xf32>, %n: vector<4xf32>, %o: vector<4xf32>, %p: vector<4xf32>, + %q: vector<4xf32>, %r: vector<4xf32>, %s: vector<4xf32>, %t: vector<4xf32>, + %u: vector<4xf32>, %v: vector<4xf32>, %w: vector<4xf32>, %x: vector<4xf32>, + %y: vector<4xf32>, %z: vector<4xf32>, %aa: vector<4xf32>, %ab: vector<4xf32>, + %ac: vector<4xf32>, %ad: vector<4xf32>, %ae: vector<4xf32>, %af: vector<4xf32>, + %ag: vector<4xf32>, %ah: vector<4xf32>, %ai: vector<4xf32>, %aj: vector<4xf32>, + %ak: vector<4xf32>, %al: vector<4xf32>, %am: vector<4xf32>, %an: vector<4xf32>, + %ao: vector<4xf32>, %ap: vector<4xf32>, %aq: vector<4xf32>, %ar: vector<4xf32>, + %as: vector<4xf32>, %at: vector<4xf32>, %au: vector<4xf32>, %av: vector<4xf32>, + %aw: vector<4xf32>, %ax: vector<4xf32>, %ay: vector<4xf32>, %az: vector<4xf32>, + %ba: vector<4xf32>, %bb: vector<4xf32>, %bc: vector<4xf32>, %bd: vector<4xf32>, + %be: vector<4xf32>, %bf: vector<4xf32>, %bg: vector<4xf32>, %bh: vector<4xf32>, + %bi: vector<4xf32>, %bj: vector<4xf32>, %bk: vector<4xf32>, %bl: vector<4xf32>) -> vector<256xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3:4 = vector.to_elements %d : vector<4xf32> + %4:4 = vector.to_elements %e : vector<4xf32> + %5:4 = vector.to_elements %f : vector<4xf32> + %6:4 = vector.to_elements %g : vector<4xf32> + %7:4 = vector.to_elements %h : vector<4xf32> + %8:4 = vector.to_elements %i : vector<4xf32> + %9:4 = vector.to_elements %j : vector<4xf32> + %10:4 = vector.to_elements %k : vector<4xf32> + %11:4 = vector.to_elements %l : vector<4xf32> + %12:4 = vector.to_elements %m : vector<4xf32> + %13:4 = vector.to_elements %n : vector<4xf32> + %14:4 = vector.to_elements %o : vector<4xf32> + %15:4 = vector.to_elements %p : vector<4xf32> + %16:4 = vector.to_elements %q : vector<4xf32> + %17:4 = vector.to_elements %r : vector<4xf32> + %18:4 = vector.to_elements %s : vector<4xf32> + %19:4 = vector.to_elements %t : vector<4xf32> + %20:4 = vector.to_elements %u : vector<4xf32> + %21:4 = vector.to_elements %v : vector<4xf32> + %22:4 = vector.to_elements %w : vector<4xf32> + %23:4 = vector.to_elements %x : vector<4xf32> + %24:4 = vector.to_elements %y : vector<4xf32> + %25:4 = vector.to_elements %z : vector<4xf32> + %26:4 = vector.to_elements %aa : vector<4xf32> + %27:4 = vector.to_elements %ab : vector<4xf32> + %28:4 = vector.to_elements %ac : vector<4xf32> + %29:4 = vector.to_elements %ad : vector<4xf32> + %30:4 = vector.to_elements %ae : vector<4xf32> + %31:4 = vector.to_elements %af : vector<4xf32> + %32:4 = vector.to_elements %ag : vector<4xf32> + %33:4 = vector.to_elements %ah : vector<4xf32> + %34:4 = vector.to_elements %ai : vector<4xf32> + %35:4 = vector.to_elements %aj : vector<4xf32> + %36:4 = vector.to_elements %ak : vector<4xf32> + %37:4 = vector.to_elements %al : vector<4xf32> + %38:4 = vector.to_elements %am : vector<4xf32> + %39:4 = vector.to_elements %an : vector<4xf32> + %40:4 = vector.to_elements %ao : vector<4xf32> + %41:4 = vector.to_elements %ap : vector<4xf32> + %42:4 = vector.to_elements %aq : vector<4xf32> + %43:4 = vector.to_elements %ar : vector<4xf32> + %44:4 = vector.to_elements %as : vector<4xf32> + %45:4 = vector.to_elements %at : vector<4xf32> + %46:4 = vector.to_elements %au : vector<4xf32> + %47:4 = vector.to_elements %av : vector<4xf32> + %48:4 = vector.to_elements %aw : vector<4xf32> + %49:4 = vector.to_elements %ax : vector<4xf32> + %50:4 = vector.to_elements %ay : vector<4xf32> + %51:4 = vector.to_elements %az : vector<4xf32> + %52:4 = vector.to_elements %ba : vector<4xf32> + %53:4 = vector.to_elements %bb : vector<4xf32> + %54:4 = vector.to_elements %bc : vector<4xf32> + %55:4 = vector.to_elements %bd : vector<4xf32> + %56:4 = vector.to_elements %be : vector<4xf32> + %57:4 = vector.to_elements %bf : vector<4xf32> + %58:4 = vector.to_elements %bg : vector<4xf32> + %59:4 = vector.to_elements %bh : vector<4xf32> + %60:4 = vector.to_elements %bi : vector<4xf32> + %61:4 = vector.to_elements %bj : vector<4xf32> + %62:4 = vector.to_elements %bk : vector<4xf32> + %63:4 = vector.to_elements %bl : vector<4xf32> + %64 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3, %3#0, %3#1, %3#2, %3#3, %4#0, %4#1, %4#2, %4#3, + %5#0, %5#1, %5#2, %5#3, %6#0, %6#1, %6#2, %6#3, %7#0, %7#1, %7#2, %7#3, %8#0, %8#1, %8#2, %8#3, %9#0, %9#1, %9#2, %9#3, + %10#0, %10#1, %10#2, %10#3, %11#0, %11#1, %11#2, %11#3, %12#0, %12#1, %12#2, %12#3, %13#0, %13#1, %13#2, %13#3, %14#0, %14#1, %14#2, %14#3, + %15#0, %15#1, %15#2, %15#3, %16#0, %16#1, %16#2, %16#3, %17#0, %17#1, %17#2, %17#3, %18#0, %18#1, %18#2, %18#3, %19#0, %19#1, %19#2, %19#3, + %20#0, %20#1, %20#2, %20#3, %21#0, %21#1, %21#2, %21#3, %22#0, %22#1, %22#2, %22#3, %23#0, %23#1, %23#2, %23#3, %24#0, %24#1, %24#2, %24#3, + %25#0, %25#1, %25#2, %25#3, %26#0, %26#1, %26#2, %26#3, %27#0, %27#1, %27#2, %27#3, %28#0, %28#1, %28#2, %28#3, %29#0, %29#1, %29#2, %29#3, + %30#0, %30#1, %30#2, %30#3, %31#0, %31#1, %31#2, %31#3, %32#0, %32#1, %32#2, %32#3, %33#0, %33#1, %33#2, %33#3, %34#0, %34#1, %34#2, %34#3, + %35#0, %35#1, %35#2, %35#3, %36#0, %36#1, %36#2, %36#3, %37#0, %37#1, %37#2, %37#3, %38#0, %38#1, %38#2, %38#3, %39#0, %39#1, %39#2, %39#3, + %40#0, %40#1, %40#2, %40#3, %41#0, %41#1, %41#2, %41#3, %42#0, %42#1, %42#2, %42#3, %43#0, %43#1, %43#2, %43#3, %44#0, %44#1, %44#2, %44#3, + %45#0, %45#1, %45#2, %45#3, %46#0, %46#1, %46#2, %46#3, %47#0, %47#1, %47#2, %47#3, %48#0, %48#1, %48#2, %48#3, %49#0, %49#1, %49#2, %49#3, + %50#0, %50#1, %50#2, %50#3, %51#0, %51#1, %51#2, %51#3, %52#0, %52#1, %52#2, %52#3, %53#0, %53#1, %53#2, %53#3, %54#0, %54#1, %54#2, %54#3, + %55#0, %55#1, %55#2, %55#3, %56#0, %56#1, %56#2, %56#3, %57#0, %57#1, %57#2, %57#3, %58#0, %58#1, %58#2, %58#3, %59#0, %59#1, %59#2, %59#3, + %60#0, %60#1, %60#2, %60#3, %61#0, %61#1, %61#2, %61#3, %62#0, %62#1, %62#2, %62#3, %63#0, %63#1, %63#2, %63#3 : vector<256xf32> + return %64 : vector<256xf32> +} + +// CHECK-LABEL: func.func @to_from_elements_shuffle_tree_concat_64x4_256( +// CHECK-SAME: %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>) +// CHECK: %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH2:.+]] = vector.shuffle %[[E]], %[[F]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH3:.+]] = vector.shuffle %[[G]], %[[H]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH4:.+]] = vector.shuffle %[[I]], %[[J]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH5:.+]] = vector.shuffle %[[K]], %[[L]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH6:.+]] = vector.shuffle %[[M]], %[[N]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH7:.+]] = vector.shuffle %[[O]], %[[P]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH8:.+]] = vector.shuffle %[[Q]], %[[R]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH9:.+]] = vector.shuffle %[[S]], %[[T]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH10:.+]] = vector.shuffle %[[U]], %[[V]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH11:.+]] = vector.shuffle %[[W]], %[[X]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH12:.+]] = vector.shuffle %[[Y]], %[[Z]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH13:.+]] = vector.shuffle %[[AA]], %[[AB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH14:.+]] = vector.shuffle %[[AC]], %[[AD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH15:.+]] = vector.shuffle %[[AE]], %[[AF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH16:.+]] = vector.shuffle %[[AG]], %[[AH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH17:.+]] = vector.shuffle %[[AI]], %[[AJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH18:.+]] = vector.shuffle %[[AK]], %[[AL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH19:.+]] = vector.shuffle %[[AM]], %[[AN]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH20:.+]] = vector.shuffle %[[AO]], %[[AP]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH21:.+]] = vector.shuffle %[[AQ]], %[[AR]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH22:.+]] = vector.shuffle %[[AS]], %[[AT]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH23:.+]] = vector.shuffle %[[AU]], %[[AV]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH24:.+]] = vector.shuffle %[[AW]], %[[AX]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH25:.+]] = vector.shuffle %[[AY]], %[[AZ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH26:.+]] = vector.shuffle %[[BA]], %[[BB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH27:.+]] = vector.shuffle %[[BC]], %[[BD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH28:.+]] = vector.shuffle %[[BE]], %[[BF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH29:.+]] = vector.shuffle %[[BG]], %[[BH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH30:.+]] = vector.shuffle %[[BI]], %[[BJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH31:.+]] = vector.shuffle %[[BK]], %[[BL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.+]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH1:.+]] = vector.shuffle %[[L0SH2]], %[[L0SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH2:.+]] = vector.shuffle %[[L0SH4]], %[[L0SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH3:.+]] = vector.shuffle %[[L0SH6]], %[[L0SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH4:.+]] = vector.shuffle %[[L0SH8]], %[[L0SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH5:.+]] = vector.shuffle %[[L0SH10]], %[[L0SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH6:.+]] = vector.shuffle %[[L0SH12]], %[[L0SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH7:.+]] = vector.shuffle %[[L0SH14]], %[[L0SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH8:.+]] = vector.shuffle %[[L0SH16]], %[[L0SH17]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH9:.+]] = vector.shuffle %[[L0SH18]], %[[L0SH19]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH10:.+]] = vector.shuffle %[[L0SH20]], %[[L0SH21]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH11:.+]] = vector.shuffle %[[L0SH22]], %[[L0SH23]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH12:.+]] = vector.shuffle %[[L0SH24]], %[[L0SH25]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH13:.+]] = vector.shuffle %[[L0SH26]], %[[L0SH27]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH14:.+]] = vector.shuffle %[[L0SH28]], %[[L0SH29]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L1SH15:.+]] = vector.shuffle %[[L0SH30]], %[[L0SH31]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> +// CHECK: %[[L2SH0:.+]] = vector.shuffle %[[L1SH0]], %[[L1SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH1:.+]] = vector.shuffle %[[L1SH2]], %[[L1SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH2:.+]] = vector.shuffle %[[L1SH4]], %[[L1SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH3:.+]] = vector.shuffle %[[L1SH6]], %[[L1SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH4:.+]] = vector.shuffle %[[L1SH8]], %[[L1SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH5:.+]] = vector.shuffle %[[L1SH10]], %[[L1SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH6:.+]] = vector.shuffle %[[L1SH12]], %[[L1SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L2SH7:.+]] = vector.shuffle %[[L1SH14]], %[[L1SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[L3SH0:.+]] = vector.shuffle %[[L2SH0]], %[[L2SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L3SH1:.+]] = vector.shuffle %[[L2SH2]], %[[L2SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L3SH2:.+]] = vector.shuffle %[[L2SH4]], %[[L2SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L3SH3:.+]] = vector.shuffle %[[L2SH6]], %[[L2SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32> +// CHECK: %[[L4SH0:.+]] = vector.shuffle %[[L3SH0]], %[[L3SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32> +// CHECK: %[[L4SH1:.+]] = vector.shuffle %[[L3SH2]], %[[L3SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32> +// CHECK: %[[L5SH0:.+]] = vector.shuffle %[[L4SH0]], %[[L4SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] : vector<128xf32>, vector<128xf32> +// CHECK: return %[[L5SH0]] : vector<256xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>, + %b: vector<4xf32>, + %c: vector<4xf32>, + %d: vector<4xf32>) -> vector<16xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3:4 = vector.to_elements %d : vector<4xf32> + %4 = vector.from_elements %3#3, %0#0, %2#2, %1#1, %3#0, %2#1, %0#3, %1#2, %0#1, %3#2, %1#0, %2#3, %1#3, %0#2, %3#1, %2#0 : vector<16xf32> + return %4 : vector<16xf32> +} + +// TODO: Implement mask compression to reduce the number of intermediate poison values. + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 15, 16, 4, 18, 6, 20, 8, 9, 23, 24, 25, 13, 14, 28] : vector<15xf32>, vector<15xf32> +// CHECK: return %[[L1SH0]] : vector<16xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>, + %b: vector<4xf32>, + %c: vector<4xf32>) -> vector<12xf32> { + %0:4 = vector.to_elements %a : vector<4xf32> + %1:4 = vector.to_elements %b : vector<4xf32> + %2:4 = vector.to_elements %c : vector<4xf32> + %3 = vector.from_elements %0#2, %1#1, %2#0, %0#1, %1#0, %2#2, %0#0, %1#3, %2#3, %0#3, %1#2, %2#1 : vector<12xf32> + return %3 : vector<12xf32> +} + +// TODO: Implement mask compression to reduce the number of intermediate poison values. + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 11, 3, 4, 14, 6, 7, 17, 9, 10, 20] : vector<11xf32>, vector<11xf32> +// CHECK: return %[[L1SH0]] : vector<12xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>, + %b: vector<5xf32>, + %c: vector<5xf32>) -> vector<9xf32> { + %0:5 = vector.to_elements %a : vector<5xf32> + %1:5 = vector.to_elements %b : vector<5xf32> + %2:5 = vector.to_elements %c : vector<5xf32> + %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, %2#2, %2#0, %1#1, %0#4 : vector<9xf32> + return %3 : vector<9xf32> +} + +// TODO: Implement mask compression to reduce the number of intermediate poison values. + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9( +// CHECK-SAME: %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32> +// CHECK: %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32> +// CHECK: %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32> +// CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 8, 9, 4, 5, 6, 7, 14] : vector<8xf32>, vector<8xf32> +// CHECK: return %[[L1SH0]] : vector<9xf32> + +// ----- + +func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>, + %b: vector<2xf32>, + %c: vector<2xf32>, + %d: vector<2xf32>) -> vector<32xf32> { + %0:2 = vector.to_elements %a : vector<2xf32> + %1:2 = vector.to_elements %b : vector<2xf32> + %2:2 = vector.to_elements %c : vector<2xf32> + %3:2 = vector.to_elements %d : vector<2xf32> + %4 = vector.from_elements %0#0, %0#0, %0#0, %0#0, %0#1, %0#1, %0#1, %0#1, + %1#0, %1#0, %1#0, %1#0, %1#1, %1#1, %1#1, %1#1, + %2#0, %2#0, %2#0, %2#0, %2#1, %2#1, %2#1, %2#1, + %3#0, %3#0, %3#0, %3#0, %3#1, %3#1, %3#1, %3#1 : vector<32xf32> + return %4 : vector<32xf32> +} + +// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32( +// CHECK-SAME: %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32> + // CHECK: %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32> + // CHECK: %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32> + // CHECK: %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: return %[[L1SH0]] : vector<32xf32>