From 3c0a5bb5dac66a1e105f016008d2ad995938e73c Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 26 Oct 2023 12:31:57 -0700 Subject: [PATCH] [mlir][sparse] merger cleanup Implemented some TODOs and removed unlikely ones. Comment cleanup --- .../mlir/Dialect/SparseTensor/Utils/Merger.h | 78 +++++-------------- .../SparseTensor/Transforms/CodegenEnv.cpp | 12 +-- 2 files changed, 22 insertions(+), 68 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h index 5e75380067572..215920f8b4607 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -122,19 +122,10 @@ struct TensorExp final { /// /// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`. /// That is, its argument is a `LoopId` identifying the loop-variable -/// in question, and its value will be the current iteration's value -/// of that loop-variable. See the `LoopId` documentation for more details. -/// -/// The `kSynZero` leaf kind is for representing a synthetic zero value, which -/// can be introduced when sparsifying operations like `arith::cmp` to generate -/// `arith::cmp %lhs, %syn_zero` when the rhs operand is absent. -// -// TODO: Modify this definition so that the numeric values already encode -// the `ExpArity` (while extending the notion of "arity" to include not -// just the number of `ExprId` children the node has, but also whether the -// node has a `Value` and/or `Operation*`). Doing this will avoid needing -// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor, -// and should help clean up a few other places as well. +/// in question, and its value will be the current iteration's value. 
+/// The `kSynZero` leaf kind is for representing a synthetic zero value, +/// which can be introduced when sparsifying operations like `arith::cmp` +/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent. enum class TensorExp::Kind { // Leaf. kTensor = 0, @@ -253,15 +244,6 @@ class Merger { /// /// The maxLvlRank specifies the max level rank of all inputs/output tensors. /// It is used to pre-allocate sufficient memory for internal storage. - // - // TODO: we want to make the filter loop more efficient in the future, - // e.g., by avoiding scanning the full list of stored coordinates (keeping - // the last position in ordered list) or even apply binary search to find - // the coordinate. - // - // TODO: would be cleaner to understand/document if the first argument - // gave the number of input tensors, instead of the current number of - // input+output tensors. Merger(unsigned numInputOutputTensors, unsigned numNativeLoops, unsigned numFilterLoops, unsigned maxLvlRank); @@ -383,12 +365,15 @@ class Merger { /// Gets the total number of loops (native loops + filter loops). constexpr unsigned getNumLoops() const { return numLoops; } + /// Gets the number of native loops. constexpr unsigned getNumNativeLoops() const { return numNativeLoops; } + /// Gets the number of filter loops. constexpr unsigned getNumFilterLoops() const { return numLoops - numNativeLoops; } + /// Gets the identifier of the first filter-loop. constexpr LoopId getStartingFilterLoopId() const { return getNumNativeLoops(); @@ -473,8 +458,7 @@ class Merger { lvlTypes[t][i] = dlt; loopToLvl[t][i] = lvl; lvlToLoop[t][lvl] = i; - // TODO: Maybe we should favor a constant loop bound when there are multiple - // choices. + // TODO: favor a constant loop bound when there are multiple choices. loopBounds[i] = std::make_pair(t, lvl); } @@ -600,43 +584,19 @@ class Merger { /// Checks whether the given expression has an associated value. 
bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); } - /// Sets the expression to have the associated value. Asserts that - /// the new value is defined, and that the expression does not already - /// have a value. If you want to overwrite a previous associated value, - /// use `updateExprValue` instead. + /// Sets the expression to have the associated value. Asserts that the new + /// value is defined, and that the expression does not already have a value. void setExprValue(ExprId e, Value v) { - assert(isValidExprId(e)); - assert(v && "Got an undefined value"); - auto &val = tensorExps[e].val; - assert(!val && "Expression already has an associated value"); - val = v; + assert(!exp(e).val && "Expression already has an associated value"); + assert(v && "Trying to assign an undefined value"); + tensorExps[e].val = v; } - /// Clears the value associated with the expression. Asserts that the + /// Clears the value associated with the expression. Asserts that the /// expression does indeed have an associated value before clearing it. - /// If you don't want to check for a previous associated value first, - /// then use `updateExprValue` instead. void clearExprValue(ExprId e) { - assert(isValidExprId(e)); - auto &val = tensorExps[e].val; - assert(val && "Expression does not have an associated value to clear"); - val = Value(); - } - - /// Unilaterally updates the expression to have the associated value. - /// That is, unlike `setExprValue` and `clearExprValue`, this method - /// does not perform any checks on whether the expression had a - /// previously associated value nor whether the new value is defined. - // - // TODO: The unilateral update semantics are required by the - // current implementation of `CodegenEnv::genLoopBoundary`; however, - // that implementation seems a bit dubious.
We would much rather have - // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or - // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those - // provide better invariants. - void updateExprValue(ExprId e, Value v) { - assert(isValidExprId(e)); - tensorExps[e].val = v; + assert(exp(e).val && "Expression does not have an associated value"); + tensorExps[e].val = Value(); } #ifndef NDEBUG @@ -706,12 +666,10 @@ class Merger { // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector` // does not. - /// Map that converts pair to the corresponding - /// level-type. + /// Map that converts pair to the corresponding lvl-type. std::vector> lvlTypes; - /// Map that converts pair to the corresponding - /// level. + /// Map that converts pair to the corresponding lvl. std::vector>> loopToLvl; /// Map that converts pair to the corresponding LoopId. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp index 924b0a0dac811..5c7cc93737b7f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -137,9 +137,6 @@ std::optional CodegenEnv::genLoopBoundary( auto r = callback(params); // may update parameters unsigned i = 0; if (isReduc()) { - // FIXME: This requires `updateExprValue` to perform updates without - // checking for a previous value; but it's not clear whether that's - // by design or might be a potential source for bugs. 
updateReduc(params[i++]); if (redValidLexInsert) setValidLexInsert(params[i++]); @@ -283,16 +280,15 @@ void CodegenEnv::endExpand() { void CodegenEnv::startReduc(ExprId exp, Value val) { assert(!isReduc() && exp != detail::kInvalidId); redExp = exp; - updateReduc(val); + redVal = val; + latticeMerger.setExprValue(exp, val); } void CodegenEnv::updateReduc(Value val) { assert(isReduc()); redVal = val; - // NOTE: `genLoopBoundary` requires that this performs a unilateral - // update without checking for a previous value first. (It's not - // clear whether any other callsites also require that.) - latticeMerger.updateExprValue(redExp, val); + latticeMerger.clearExprValue(redExp); + latticeMerger.setExprValue(redExp, val); } Value CodegenEnv::endReduc() {