diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp index 61264bc523648..7a5b0722a03ba 100644 --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" #include "gtest/gtest.h" #include @@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) { } } +/// Simple wrapper class to enable "isa querying" and simple accessing of +/// encoding. +class TensorWithString : public RankedTensorType { +public: + using RankedTensorType::RankedTensorType; + + static TensorWithString get(ArrayRef shape, Type elementType, + StringRef name) { + return mlir::cast(RankedTensorType::get( + shape, elementType, StringAttr::get(elementType.getContext(), name))); + } + + StringRef getName() const { + if (Attribute enc = getEncoding()) + return mlir::cast(enc).getValue(); + return {}; + } + + static bool classof(Type type) { + if (auto rt = mlir::dyn_cast_or_null(type)) + return mlir::isa_and_present(rt.getEncoding()); + return false; + } +}; + +TEST(ShapedTypeTest, RankedTensorTypeView) { + MLIRContext context; + Type f32 = FloatType::getF32(&context); + + Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32); + + UnitAttr unitAttr = UnitAttr::get(&context); + Type unitEncodingRankedTensorType = + RankedTensorType::get({10, 20}, f32, unitAttr); + + StringAttr stringAttr = StringAttr::get(&context, "app"); + Type stringEncodingRankedTensorType = + RankedTensorType::get({10, 20}, f32, stringAttr); + + EXPECT_FALSE(mlir::isa(noEncodingRankedTensorType)); + EXPECT_FALSE(mlir::isa(unitEncodingRankedTensorType)); + ASSERT_TRUE(mlir::isa(stringEncodingRankedTensorType)); + + // Cast to TensorWithString view. + auto view = mlir::cast(stringEncodingRankedTensorType); + ASSERT_TRUE(mlir::isa(view)); + EXPECT_EQ(view.getName(), "app"); + // Verify one could cast view type back to base type. + ASSERT_TRUE(mlir::isa(view)); + + Type viewCreated = TensorWithString::get({10, 20}, f32, "bob"); + ASSERT_TRUE(mlir::isa(viewCreated)); + ASSERT_TRUE(mlir::isa(viewCreated)); + view = mlir::cast(viewCreated); + EXPECT_EQ(view.getName(), "bob"); +} + } // namespace