Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions mlir/unittests/IR/ShapedTypeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "gtest/gtest.h"
#include <cstdint>
Expand Down Expand Up @@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}

/// Simple wrapper class to enable "isa querying" and simple accessing of
/// encoding.
class TensorWithString : public RankedTensorType {
public:
using RankedTensorType::RankedTensorType;

static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,
StringRef name) {
return mlir::cast<TensorWithString>(RankedTensorType::get(
shape, elementType, StringAttr::get(elementType.getContext(), name)));
}

StringRef getName() const {
if (Attribute enc = getEncoding())
return mlir::cast<StringAttr>(enc).getValue();
Comment on lines +243 to +244
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're going to be adding these encoding wrappers, should we consider a templatized get encoding that does the cast/dyn_cast to the wrapped type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea yes. Like "getEncodingAs" ?

return {};
}

static bool classof(Type type) {
if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))
return mlir::isa_and_present<StringAttr>(rt.getEncoding());
return false;
}
};

TEST(ShapedTypeTest, RankedTensorTypeView) {
MLIRContext context;
Type f32 = FloatType::getF32(&context);

Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);

UnitAttr unitAttr = UnitAttr::get(&context);
Type unitEncodingRankedTensorType =
RankedTensorType::get({10, 20}, f32, unitAttr);

StringAttr stringAttr = StringAttr::get(&context, "app");
Type stringEncodingRankedTensorType =
RankedTensorType::get({10, 20}, f32, stringAttr);

EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));
EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));
ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));

// Cast to TensorWithString view.
auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);
ASSERT_TRUE(mlir::isa<TensorWithString>(view));
EXPECT_EQ(view.getName(), "app");
// Verify one could cast view type back to base type.
ASSERT_TRUE(mlir::isa<RankedTensorType>(view));

Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");
ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));
ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));
view = mlir::cast<TensorWithString>(viewCreated);
EXPECT_EQ(view.getName(), "bob");
}

} // namespace