Skip to content

Commit f2efff3

Browse files
address comments
1 parent a957b09 commit f2efff3

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
307307
"AffineMap":$lvlToDim,
308308
"unsigned":$posWidth,
309309
"unsigned":$crdWidth), [{
310+
if (!lvlToDim) {
311+
lvlToDim = ::mlir::sparse_tensor::inferLvlToDim(dimToLvl, $_ctxt);
312+
}
310313
return $_get($_ctxt, lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
311314
ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr>{});
312315
}]>

mlir/lib/CAPI/Dialect/SparseTensor.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,8 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
5454
cppLvlTypes.reserve(lvlRank);
5555
for (intptr_t l = 0; l < lvlRank; ++l)
5656
cppLvlTypes.push_back(static_cast<DimLevelType>(lvlTypes[l]));
57-
auto unwrappedLvlToDim = unwrap(lvlToDim);
58-
if (!unwrappedLvlToDim)
59-
unwrappedLvlToDim = inferLvlToDim(unwrap(dimToLvl), unwrap(ctx));
6057
return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes,
61-
unwrap(dimToLvl), unwrappedLvlToDim,
58+
unwrap(dimToLvl), unwrap(lvlToDim),
6259
posWidth, crdWidth));
6360
}
6461

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
582582
#undef RETURN_ON_FAIL
583583

584584
// Construct struct-like storage for attribute.
585+
// TODO: Fetch lvlToDim if user provides one
585586
AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
586587
return parser.getChecked<SparseTensorEncodingAttr>(
587588
parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
@@ -770,29 +771,28 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
770771
lvlExprs.reserve(numLvls);
771772
// lvlExprComponents stores information of the floordiv and mod operations
772773
// applied to the same dimension, so as to build the lvlToDim map.
773-
// Map key is the position of the dimension in dimToLvl.
774-
// Map value is a SmallVector that contains lvl var for floordiv, multiplier,
775-
// lvl var for mod in dimToLvl.
776-
// For example, for il = i floordiv 2 and ii = i mod 2, the SmalleVector
777-
// would be [il, 2, ii]. It could be used to build the AffineExpr
778-
// i = il * 2 + ii in lvlToDim.
779774
std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
780775
for (unsigned i = 0, n = numLvls; i < n; i++) {
781776
auto result = dimToLvl.getResult(i);
782777
if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
783778
if (result.getKind() == AffineExprKind::FloorDiv) {
779+
// Position of the dimension in dimToLvl.
780+
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
781+
assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
782+
"expected only one floordiv for each dimension");
784783
SmallVector<AffineExpr, 3> components;
785784
// Level variable for floordiv.
786785
components.push_back(getAffineDimExpr(i, context));
787786
// Multiplier.
788787
components.push_back(binOp.getRHS());
789-
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
788+
// Map key is the position of the dimension.
790789
lvlExprComponents[pos] = components;
791790
} else if (result.getKind() == AffineExprKind::Mod) {
792791
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
793792
assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
794793
"expected floordiv before mod");
795-
// Level variable for mod.
794+
// Level variable for mod added to the vector of the corresponding
795+
// floordiv with the same dimension.
796796
lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
797797
} else {
798798
assert(false && "expected floordiv or mod");
@@ -801,6 +801,10 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
801801
lvlExprs.push_back(getAffineDimExpr(i, context));
802802
}
803803
}
804+
// Build lvlExprs from lvlExprComponents.
805+
// For example, for il = i floordiv 2 and ii = i mod 2, the components
806+
// would be [il, 2, ii]. It could be used to build the AffineExpr
807+
// i = il * 2 + ii in lvlToDim.
804808
for (auto &components : lvlExprComponents) {
805809
assert(components.second.size() == 3 &&
806810
"expected 3 components to build lvlExprs");
@@ -875,7 +879,7 @@ RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
875879
// default value.
876880
unsigned posWidth = src.getPosWidth();
877881
unsigned crdWidth = src.getCrdWidth();
878-
auto invPerm = src.getLvlToDim();
882+
AffineMap invPerm = src.getLvlToDim();
879883
auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
880884
invPerm, posWidth, crdWidth);
881885
return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);

0 commit comments

Comments
 (0)