|  | 
| 11 | 11 | #include "mlir/IR/BuiltinTypes.h" | 
| 12 | 12 | #include "mlir/IR/Dialect.h" | 
| 13 | 13 | #include "mlir/IR/DialectInterface.h" | 
|  | 14 | +#include "mlir/Support/LLVM.h" | 
| 14 | 15 | #include "llvm/ADT/SmallVector.h" | 
| 15 | 16 | #include "gtest/gtest.h" | 
| 16 | 17 | #include <cstdint> | 
| @@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) { | 
| 226 | 227 |   } | 
| 227 | 228 | } | 
| 228 | 229 | 
 | 
|  | 230 | +/// Simple wrapper class to enable "isa querying" and simple accessing of | 
|  | 231 | +/// encoding. | 
|  | 232 | +class TensorWithString : public RankedTensorType { | 
|  | 233 | +public: | 
|  | 234 | +  using RankedTensorType::RankedTensorType; | 
|  | 235 | + | 
|  | 236 | +  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType, | 
|  | 237 | +                              StringRef name) { | 
|  | 238 | +    return mlir::cast<TensorWithString>(RankedTensorType::get( | 
|  | 239 | +        shape, elementType, StringAttr::get(elementType.getContext(), name))); | 
|  | 240 | +  } | 
|  | 241 | + | 
|  | 242 | +  StringRef getName() const { | 
|  | 243 | +    if (Attribute enc = getEncoding()) | 
|  | 244 | +      return mlir::cast<StringAttr>(enc).getValue(); | 
|  | 245 | +    return {}; | 
|  | 246 | +  } | 
|  | 247 | + | 
|  | 248 | +  static bool classof(Type type) { | 
|  | 249 | +    if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type)) | 
|  | 250 | +      return mlir::isa_and_present<StringAttr>(rt.getEncoding()); | 
|  | 251 | +    return false; | 
|  | 252 | +  } | 
|  | 253 | +}; | 
|  | 254 | + | 
|  | 255 | +TEST(ShapedTypeTest, RankedTensorTypeView) { | 
|  | 256 | +  MLIRContext context; | 
|  | 257 | +  Type f32 = FloatType::getF32(&context); | 
|  | 258 | + | 
|  | 259 | +  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32); | 
|  | 260 | + | 
|  | 261 | +  UnitAttr unitAttr = UnitAttr::get(&context); | 
|  | 262 | +  Type unitEncodingRankedTensorType = | 
|  | 263 | +      RankedTensorType::get({10, 20}, f32, unitAttr); | 
|  | 264 | + | 
|  | 265 | +  StringAttr stringAttr = StringAttr::get(&context, "app"); | 
|  | 266 | +  Type stringEncodingRankedTensorType = | 
|  | 267 | +      RankedTensorType::get({10, 20}, f32, stringAttr); | 
|  | 268 | + | 
|  | 269 | +  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType)); | 
|  | 270 | +  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType)); | 
|  | 271 | +  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType)); | 
|  | 272 | + | 
|  | 273 | +  // Cast to TensorWithString view. | 
|  | 274 | +  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType); | 
|  | 275 | +  ASSERT_TRUE(mlir::isa<TensorWithString>(view)); | 
|  | 276 | +  EXPECT_EQ(view.getName(), "app"); | 
|  | 277 | +  // Verify one could cast view type back to base type. | 
|  | 278 | +  ASSERT_TRUE(mlir::isa<RankedTensorType>(view)); | 
|  | 279 | + | 
|  | 280 | +  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob"); | 
|  | 281 | +  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated)); | 
|  | 282 | +  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated)); | 
|  | 283 | +  view = mlir::cast<TensorWithString>(viewCreated); | 
|  | 284 | +  EXPECT_EQ(view.getName(), "bob"); | 
|  | 285 | +} | 
|  | 286 | + | 
| 229 | 287 | } // namespace | 
0 commit comments