| 
11 | 11 | #include "mlir/IR/BuiltinTypes.h"  | 
12 | 12 | #include "mlir/IR/Dialect.h"  | 
13 | 13 | #include "mlir/IR/DialectInterface.h"  | 
 | 14 | +#include "mlir/Support/LLVM.h"  | 
14 | 15 | #include "llvm/ADT/SmallVector.h"  | 
15 | 16 | #include "gtest/gtest.h"  | 
16 | 17 | #include <cstdint>  | 
@@ -226,4 +227,61 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {  | 
226 | 227 |   }  | 
227 | 228 | }  | 
228 | 229 | 
 
  | 
 | 230 | +/// Simple wrapper class to enable "isa querying" and simple accessing of  | 
 | 231 | +/// encoding.  | 
 | 232 | +class TensorWithString : public RankedTensorType {  | 
 | 233 | +public:  | 
 | 234 | +  using RankedTensorType::RankedTensorType;  | 
 | 235 | + | 
 | 236 | +  static TensorWithString get(ArrayRef<int64_t> shape, Type elementType,  | 
 | 237 | +                              StringRef name) {  | 
 | 238 | +    return mlir::cast<TensorWithString>(RankedTensorType::get(  | 
 | 239 | +        shape, elementType, StringAttr::get(elementType.getContext(), name)));  | 
 | 240 | +  }  | 
 | 241 | + | 
 | 242 | +  StringRef getName() const {  | 
 | 243 | +    if (Attribute enc = getEncoding())  | 
 | 244 | +      return mlir::cast<StringAttr>(enc).getValue();  | 
 | 245 | +    return {};  | 
 | 246 | +  }  | 
 | 247 | + | 
 | 248 | +  static bool classof(Type type) {  | 
 | 249 | +    if (auto rt = mlir::dyn_cast_or_null<RankedTensorType>(type))  | 
 | 250 | +      return mlir::isa_and_present<StringAttr>(rt.getEncoding());  | 
 | 251 | +    return false;  | 
 | 252 | +  }  | 
 | 253 | +};  | 
 | 254 | + | 
 | 255 | +TEST(ShapedTypeTest, RankedTensorTypeView) {  | 
 | 256 | +  MLIRContext context;  | 
 | 257 | +  Type f32 = FloatType::getF32(&context);  | 
 | 258 | + | 
 | 259 | +  Type noEncodingRankedTensorType = RankedTensorType::get({10, 20}, f32);  | 
 | 260 | + | 
 | 261 | +  UnitAttr unitAttr = UnitAttr::get(&context);  | 
 | 262 | +  Type unitEncodingRankedTensorType =  | 
 | 263 | +      RankedTensorType::get({10, 20}, f32, unitAttr);  | 
 | 264 | + | 
 | 265 | +  StringAttr stringAttr = StringAttr::get(&context, "app");  | 
 | 266 | +  Type stringEncodingRankedTensorType =  | 
 | 267 | +      RankedTensorType::get({10, 20}, f32, stringAttr);  | 
 | 268 | + | 
 | 269 | +  EXPECT_FALSE(mlir::isa<TensorWithString>(noEncodingRankedTensorType));  | 
 | 270 | +  EXPECT_FALSE(mlir::isa<TensorWithString>(unitEncodingRankedTensorType));  | 
 | 271 | +  ASSERT_TRUE(mlir::isa<TensorWithString>(stringEncodingRankedTensorType));  | 
 | 272 | + | 
 | 273 | +  // Cast to TensorWithString view.  | 
 | 274 | +  auto view = mlir::cast<TensorWithString>(stringEncodingRankedTensorType);  | 
 | 275 | +  ASSERT_TRUE(mlir::isa<TensorWithString>(view));  | 
 | 276 | +  EXPECT_EQ(view.getName(), "app");  | 
 | 277 | +  // Verify one could cast view type back to base type.  | 
 | 278 | +  ASSERT_TRUE(mlir::isa<RankedTensorType>(view));  | 
 | 279 | + | 
 | 280 | +  Type viewCreated = TensorWithString::get({10, 20}, f32, "bob");  | 
 | 281 | +  ASSERT_TRUE(mlir::isa<TensorWithString>(viewCreated));  | 
 | 282 | +  ASSERT_TRUE(mlir::isa<RankedTensorType>(viewCreated));  | 
 | 283 | +  view = mlir::cast<TensorWithString>(viewCreated);  | 
 | 284 | +  EXPECT_EQ(view.getName(), "bob");  | 
 | 285 | +}  | 
 | 286 | + | 
229 | 287 | } // namespace  | 
0 commit comments