Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 18 additions & 60 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,10 @@ struct TensorExp final {
///
/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
/// That is, its argument is a `LoopId` identifying the loop-variable
/// in question, and its value will be the current iteration's value
/// of that loop-variable. See the `LoopId` documentation for more details.
///
/// The `kSynZero` leaf kind is for representing a synthetic zero value, which
/// can be introduced when sparsifying operations like `arith::cmp` to generate
/// `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
//
// TODO: Modify this definition so that the numeric values already encode
// the `ExpArity` (while extending the notion of "arity" to include not
// just the number of `ExprId` children the node has, but also whether the
// node has a `Value` and/or `Operation*`). Doing this will avoid needing
// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
// and should help clean up a few other places as well.
enum class TensorExp::Kind {
// Leaf.
kTensor = 0,
Expand Down Expand Up @@ -253,15 +244,6 @@ class Merger {
///
/// The maxLvlRank specifies the max level rank of all inputs/output tensors.
/// It is used to pre-allocate sufficient memory for internal storage.
//
// TODO: we want to make the filter loop more efficient in the future,
// e.g., by avoiding scanning the full list of stored coordinates (keeping
// the last position in ordered list) or even apply binary search to find
// the coordinate.
//
// TODO: would be cleaner to understand/document if the first argument
// gave the number of input tensors, instead of the current number of
// input+output tensors.
Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
unsigned numFilterLoops, unsigned maxLvlRank);

Expand Down Expand Up @@ -383,12 +365,15 @@ class Merger {

/// Returns the combined loop count, i.e., native loops plus filter loops.
constexpr unsigned getNumLoops() const {
  return numLoops;
}

/// Returns how many of the loops are native loops.
constexpr unsigned getNumNativeLoops() const {
  return numNativeLoops;
}

/// Returns how many of the loops are filter loops, computed as the total
/// loop count minus the native-loop count.
constexpr unsigned getNumFilterLoops() const {
  const unsigned filterLoopCount = numLoops - numNativeLoops;
  return filterLoopCount;
}

/// Gets the identifier of the first filter-loop.
constexpr LoopId getStartingFilterLoopId() const {
return getNumNativeLoops();
Expand Down Expand Up @@ -473,8 +458,7 @@ class Merger {
lvlTypes[t][i] = dlt;
loopToLvl[t][i] = lvl;
lvlToLoop[t][lvl] = i;
// TODO: Maybe we should favor a constant loop bound when there are multiple
// choices.
loopBounds[i] = std::make_pair(t, lvl);
}

Expand Down Expand Up @@ -600,43 +584,19 @@ class Merger {
/// Reports whether expression `e` currently carries an associated value.
bool hasExprValue(ExprId e) const {
  const auto &associated = exp(e).val;
  return static_cast<bool>(associated);
}

/// Sets the expression to have the associated value. Asserts that
/// the new value is defined, and that the expression does not already
/// have a value. If you want to overwrite a previous associated value,
/// use `updateExprValue` instead.
void setExprValue(ExprId e, Value v) {
  assert(isValidExprId(e));
  assert(v && "Got an undefined value");
  auto &val = tensorExps[e].val;
  // The "no previous value" check must run exactly once, and before the
  // store; repeating it after the assignment could never succeed.
  assert(!val && "Expression already has an associated value");
  val = v;
}

/// Clears the value associated with the expression. Asserts that the
/// expression does indeed have an associated value before clearing it.
/// If you don't want to check for a previous associated value first,
/// then use `updateExprValue` instead.
void clearExprValue(ExprId e) {
  assert(isValidExprId(e));
  auto &slot = tensorExps[e].val;
  assert(slot && "Expression does not have an associated value to clear");
  // Reset the slot to a default-constructed `Value`.
  slot = Value();
}

/// Unilaterally updates the expression to have the associated value.
/// That is, unlike `setExprValue` and `clearExprValue`, this method
/// does not perform any checks on whether the expression had a
/// previously associated value nor whether the new value is defined.
//
// TODO: The unilateral update semantics are required by the
// current implementation of `CodegenEnv::genLoopBoundary`; however,
// that implementation seems a bit dubious. We would much rather have
// the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
// `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
// provide better invariants.
void updateExprValue(ExprId e, Value v) {
  assert(isValidExprId(e));
  // Store `v` as-is: no check on the old value, no check that `v` is
  // defined, and nothing afterwards that would reset the slot again.
  tensorExps[e].val = v;
}

#ifndef NDEBUG
Expand Down Expand Up @@ -706,12 +666,10 @@ class Merger {
// `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
// does not.

/// Map that converts pair<TensorId, LoopId> to the corresponding
/// level-type.
std::vector<std::vector<DimLevelType>> lvlTypes;

/// Map that converts pair<TensorId, LoopId> to the corresponding
/// level.
std::vector<std::vector<std::optional<Level>>> loopToLvl;

/// Map that converts pair<TensorId, Level> to the corresponding LoopId.
Expand Down
12 changes: 4 additions & 8 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
auto r = callback(params); // may update parameters
unsigned i = 0;
if (isReduc()) {
// FIXME: This requires `updateExprValue` to perform updates without
// checking for a previous value; but it's not clear whether that's
// by design or might be a potential source for bugs.
updateReduc(params[i++]);
if (redValidLexInsert)
setValidLexInsert(params[i++]);
Expand Down Expand Up @@ -283,16 +280,15 @@ void CodegenEnv::endExpand() {
/// Starts a reduction on expression `exp` with `val` as its current value.
/// Asserts that `isReduc()` is false and that `exp` is not
/// `detail::kInvalidId`.
void CodegenEnv::startReduc(ExprId exp, Value val) {
  assert(!isReduc() && exp != detail::kInvalidId);
  redExp = exp;
  // `updateReduc` already records `val` in `redVal` and in the merger's
  // expression; repeating those stores here via `setExprValue` would trip
  // its "already has an associated value" assertion.
  updateReduc(val);
}

/// Replaces the current reduction value with `val`, both in `redVal` and in
/// the value associated with `redExp` in the merger. Asserts `isReduc()`.
void CodegenEnv::updateReduc(Value val) {
  assert(isReduc());
  redVal = val;
  // NOTE: `genLoopBoundary` requires that this performs a unilateral
  // update without checking for a previous value first. (It's not
  // clear whether any other callsites also require that.) Hence only
  // `updateExprValue` is used here; a following clear/set pair would
  // reintroduce exactly those checks.
  latticeMerger.updateExprValue(redExp, val);
}

Value CodegenEnv::endReduc() {
Expand Down