2929#include " mlir/IR/OpDefinition.h"
3030#include " mlir/IR/PatternMatch.h"
3131#include " mlir/IR/TypeUtilities.h"
32+ #include " mlir/Support/ScalableVectorType.h"
3233#include " mlir/Transforms/DialectConversion.h"
3334#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
3435#include " llvm/ADT/ArrayRef.h"
@@ -39,24 +40,14 @@ using namespace mlir;
3940using namespace mlir ::math;
4041using namespace mlir ::vector;
4142
42- // Helper to encapsulate a vector's shape (including scalable dims).
43- struct VectorShape {
44- ArrayRef<int64_t > sizes;
45- ArrayRef<bool > scalableFlags;
46-
47- bool empty () const { return sizes.empty (); }
48- };
49-
5043// Returns vector shape if the type is a vector. Returns an empty shape if it is
5144// not a vector.
52- static VectorShape vectorShape (Type type) {
45+ static VectorDimList vectorShape (Type type) {
5346 auto vectorType = dyn_cast<VectorType>(type);
54- return vectorType
55- ? VectorShape{vectorType.getShape (), vectorType.getScalableDims ()}
56- : VectorShape{};
47+ return VectorDimList::from (vectorType);
5748}
5849
59- static VectorShape vectorShape (Value value) {
50+ static VectorDimList vectorShape (Value value) {
6051 return vectorShape (value.getType ());
6152}
6253
@@ -65,16 +56,14 @@ static VectorShape vectorShape(Value value) {
6556// ----------------------------------------------------------------------------//
6657
6758// Broadcasts scalar type into vector type (iff shape is non-scalar).
68- static Type broadcast (Type type, VectorShape shape) {
59+ static Type broadcast (Type type, VectorDimList shape) {
6960 assert (!isa<VectorType>(type) && " must be scalar type" );
70- return !shape.empty ()
71- ? VectorType::get (shape.sizes , type, shape.scalableFlags )
72- : type;
61+ return !shape.empty () ? ScalableVectorType::get (shape, type) : type;
7362}
7463
7564// Broadcasts scalar value into vector (iff shape is non-scalar).
7665static Value broadcast (ImplicitLocOpBuilder &builder, Value value,
77- VectorShape shape) {
66+ VectorDimList shape) {
7867 assert (!isa<VectorType>(value.getType ()) && " must be scalar value" );
7968 auto type = broadcast (value.getType (), shape);
8069 return !shape.empty () ? builder.create <BroadcastOp>(type, value) : value;
@@ -227,7 +216,7 @@ static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
227216static std::pair<Value, Value> frexp (ImplicitLocOpBuilder &builder, Value arg,
228217 bool isPositive = false ) {
229218 assert (getElementTypeOrSelf (arg).isF32 () && " arg must be f32 type" );
230- VectorShape shape = vectorShape (arg);
219+ VectorDimList shape = vectorShape (arg);
231220
232221 auto bcast = [&](Value value) -> Value {
233222 return broadcast (builder, value, shape);
@@ -267,7 +256,7 @@ static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
267256// Computes exp2 for an i32 argument.
268257static Value exp2I32 (ImplicitLocOpBuilder &builder, Value arg) {
269258 assert (getElementTypeOrSelf (arg).isInteger (32 ) && " arg must be i32 type" );
270- VectorShape shape = vectorShape (arg);
259+ VectorDimList shape = vectorShape (arg);
271260
272261 auto bcast = [&](Value value) -> Value {
273262 return broadcast (builder, value, shape);
@@ -293,7 +282,7 @@ Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
293282 Type elementType = getElementTypeOrSelf (x);
294283 assert ((elementType.isF32 () || elementType.isF16 ()) &&
295284 " x must be f32 or f16 type" );
296- VectorShape shape = vectorShape (x);
285+ VectorDimList shape = vectorShape (x);
297286
298287 if (coeffs.empty ())
299288 return broadcast (builder, floatCst (builder, 0 .0f , elementType), shape);
@@ -391,7 +380,7 @@ AtanApproximation::matchAndRewrite(math::AtanOp op,
391380 if (!getElementTypeOrSelf (operand).isF32 ())
392381 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
393382
394- VectorShape shape = vectorShape (op.getOperand ());
383+ VectorDimList shape = vectorShape (op.getOperand ());
395384
396385 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
397386 Value abs = builder.create <math::AbsFOp>(operand);
@@ -490,7 +479,7 @@ Atan2Approximation::matchAndRewrite(math::Atan2Op op,
490479 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
491480
492481 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
493- VectorShape shape = vectorShape (op.getResult ());
482+ VectorDimList shape = vectorShape (op.getResult ());
494483
495484 // Compute atan in the valid range.
496485 auto div = builder.create <arith::DivFOp>(y, x);
@@ -556,7 +545,7 @@ TanhApproximation::matchAndRewrite(math::TanhOp op,
556545 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
557546 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
558547
559- VectorShape shape = vectorShape (op.getOperand ());
548+ VectorDimList shape = vectorShape (op.getOperand ());
560549
561550 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
562551 auto bcast = [&](Value value) -> Value {
@@ -644,7 +633,7 @@ LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
644633 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
645634 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
646635
647- VectorShape shape = vectorShape (op.getOperand ());
636+ VectorDimList shape = vectorShape (op.getOperand ());
648637
649638 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
650639 auto bcast = [&](Value value) -> Value {
@@ -791,7 +780,7 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
791780 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
792781 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
793782
794- VectorShape shape = vectorShape (op.getOperand ());
783+ VectorDimList shape = vectorShape (op.getOperand ());
795784
796785 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
797786 auto bcast = [&](Value value) -> Value {
@@ -846,7 +835,7 @@ AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846835 if (!(elementType.isF32 () || elementType.isF16 ()))
847836 return rewriter.notifyMatchFailure (op,
848837 " only f32 and f16 type is supported." );
849- VectorShape shape = vectorShape (operand);
838+ VectorDimList shape = vectorShape (operand);
850839
851840 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
852841 auto bcast = [&](Value value) -> Value {
@@ -910,7 +899,7 @@ AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
910899 if (!(elementType.isF32 () || elementType.isF16 ()))
911900 return rewriter.notifyMatchFailure (op,
912901 " only f32 and f16 type is supported." );
913- VectorShape shape = vectorShape (operand);
902+ VectorDimList shape = vectorShape (operand);
914903
915904 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
916905 auto bcast = [&](Value value) -> Value {
@@ -988,7 +977,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
988977 if (!(elementType.isF32 () || elementType.isF16 ()))
989978 return rewriter.notifyMatchFailure (op,
990979 " only f32 and f16 type is supported." );
991- VectorShape shape = vectorShape (operand);
980+ VectorDimList shape = vectorShape (operand);
992981
993982 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
994983 auto bcast = [&](Value value) -> Value {
@@ -1097,7 +1086,7 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
10971086
10981087namespace {
10991088
1100- Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorShape shape,
1089+ Value clampWithNormals (ImplicitLocOpBuilder &builder, const VectorDimList shape,
11011090 Value value, float lowerBound, float upperBound) {
11021091 assert (!std::isnan (lowerBound));
11031092 assert (!std::isnan (upperBound));
@@ -1289,7 +1278,7 @@ ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
12891278 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
12901279 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
12911280
1292- VectorShape shape = vectorShape (op.getOperand ());
1281+ VectorDimList shape = vectorShape (op.getOperand ());
12931282
12941283 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
12951284 auto bcast = [&](Value value) -> Value {
@@ -1359,7 +1348,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
13591348 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
13601349 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
13611350
1362- VectorShape shape = vectorShape (op.getOperand ());
1351+ VectorDimList shape = vectorShape (op.getOperand ());
13631352
13641353 ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
13651354 auto bcast = [&](Value value) -> Value {
@@ -1486,7 +1475,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
14861475 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
14871476
14881477 ImplicitLocOpBuilder b (op->getLoc (), rewriter);
1489- VectorShape shape = vectorShape (operand);
1478+ VectorDimList shape = vectorShape (operand);
14901479
14911480 Type floatTy = getElementTypeOrSelf (operand.getType ());
14921481 Type intTy = b.getIntegerType (floatTy.getIntOrFloatBitWidth ());
@@ -1575,7 +1564,7 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
15751564 if (!getElementTypeOrSelf (op.getOperand ()).isF32 ())
15761565 return rewriter.notifyMatchFailure (op, " unsupported operand type" );
15771566
1578- VectorShape shape = vectorShape (op.getOperand ());
1567+ VectorDimList shape = vectorShape (op.getOperand ());
15791568
15801569 // Only support already-vectorized rsqrt's.
15811570 if (shape.empty () || shape.sizes .back () % 8 != 0 )
0 commit comments