@@ -20,40 +20,39 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
2020 mlir::sparse_tensor::SparseTensorDialect)
2121
2222// Ensure the C-API enums are int-castable to C++ equivalents.
23- static_assert(
24- static_cast <int >(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) ==
25- static_cast<int>(DimLevelType::Dense) &&
26- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) ==
27- static_cast<int>(DimLevelType::Compressed) &&
28- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) ==
29- static_cast<int>(DimLevelType::CompressedNu) &&
30- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) ==
31- static_cast<int>(DimLevelType::CompressedNo) &&
32- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) ==
33- static_cast<int>(DimLevelType::CompressedNuNo) &&
34- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) ==
35- static_cast<int>(DimLevelType::Singleton) &&
36- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) ==
37- static_cast<int>(DimLevelType::SingletonNu) &&
38- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) ==
39- static_cast<int>(DimLevelType::SingletonNo) &&
40- static_cast<int>(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) ==
41- static_cast<int>(DimLevelType::SingletonNuNo),
42- "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch");
23+ static_assert(static_cast <int >(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
24+ static_cast<int>(LevelType::Dense) &&
25+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
26+ static_cast<int>(LevelType::Compressed) &&
27+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
28+ static_cast<int>(LevelType::CompressedNu) &&
29+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
30+ static_cast<int>(LevelType::CompressedNo) &&
31+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
32+ static_cast<int>(LevelType::CompressedNuNo) &&
33+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
34+ static_cast<int>(LevelType::Singleton) &&
35+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
36+ static_cast<int>(LevelType::SingletonNu) &&
37+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
38+ static_cast<int>(LevelType::SingletonNo) &&
39+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
40+ static_cast<int>(LevelType::SingletonNuNo),
41+ "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
4342
4443bool mlirAttributeIsASparseTensorEncodingAttr (MlirAttribute attr) {
4544 return isa<SparseTensorEncodingAttr>(unwrap (attr));
4645}
4746
4847MlirAttribute
4948mlirSparseTensorEncodingAttrGet (MlirContext ctx, intptr_t lvlRank,
50- MlirSparseTensorDimLevelType const *lvlTypes,
49+ MlirSparseTensorLevelType const *lvlTypes,
5150 MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
5251 int posWidth, int crdWidth) {
53- SmallVector<DimLevelType > cppLvlTypes;
52+ SmallVector<LevelType > cppLvlTypes;
5453 cppLvlTypes.reserve (lvlRank);
5554 for (intptr_t l = 0 ; l < lvlRank; ++l)
56- cppLvlTypes.push_back (static_cast <DimLevelType >(lvlTypes[l]));
55+ cppLvlTypes.push_back (static_cast <LevelType >(lvlTypes[l]));
5756 return wrap (SparseTensorEncodingAttr::get (unwrap (ctx), cppLvlTypes,
5857 unwrap (dimToLvl), unwrap (lvlToDim),
5958 posWidth, crdWidth));
@@ -71,9 +70,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
7170 return cast<SparseTensorEncodingAttr>(unwrap (attr)).getLvlRank ();
7271}
7372
74- MlirSparseTensorDimLevelType
73+ MlirSparseTensorLevelType
7574mlirSparseTensorEncodingAttrGetLvlType (MlirAttribute attr, intptr_t lvl) {
76- return static_cast <MlirSparseTensorDimLevelType >(
75+ return static_cast <MlirSparseTensorLevelType >(
7776 cast<SparseTensorEncodingAttr>(unwrap (attr)).getLvlType (lvl));
7877}
7978
0 commit comments