diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index efc3f82d6a307..1b5f0553a3af9 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -201,10 +201,11 @@ class SparseTensorReader final {
                          const uint64_t *lvl2dim) {
     const uint64_t dimRank = getRank();
     MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim);
-    auto *coo = readCOO<V>(map, lvlSizes);
+    auto *lvlCOO = readCOO<V>(map, lvlSizes);
     auto *tensor = SparseTensorStorage<P, C, V>::newFromCOO(
-        dimRank, getDimSizes(), lvlRank, lvlTypes, dim2lvl, lvl2dim, *coo);
-    delete coo;
+        dimRank, getDimSizes(), lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,
+        *lvlCOO);
+    delete lvlCOO;
     return tensor;
   }

diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index bafc9baa7edde..f1aeb12c662fd 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -10,8 +10,6 @@
 //
 // * `SparseTensorStorageBase`
 // * `SparseTensorStorage<P, C, V>`
-// * `SparseTensorEnumeratorBase<V>`
-// * `SparseTensorEnumerator<P, C, V>`
 //
 //===----------------------------------------------------------------------===//

@@ -28,26 +26,15 @@
 namespace mlir {
 namespace sparse_tensor {

-/// The type of callback functions which receive an element.
-template <typename V>
-using ElementConsumer =
-    const std::function<void(const std::vector<uint64_t> &, V)> &;
-
-// Forward references.
-template <typename V>
-class SparseTensorEnumeratorBase;
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator;
-
 //===----------------------------------------------------------------------===//
 //
-// SparseTensorStorage
+// SparseTensorStorage Classes
 //
 //===----------------------------------------------------------------------===//

 /// Abstract base class for `SparseTensorStorage<P, C, V>`. This class
 /// takes responsibility for all the `<P, C, V>`-independent aspects
-/// of the tensor (e.g., shape, sparsity, mapping). In addition,
+/// of the tensor (e.g., sizes, sparsity, mapping). In addition,
 /// we use function overloading to implement "partial" method
 /// specialization, which the C-API relies on to catch type errors
 /// arising from our use of opaque pointers.
@@ -55,7 +42,7 @@ class SparseTensorEnumerator;
 /// Because this class forms a bridge between the denotational semantics
 /// of "tensors" and the operational semantics of how we store and
 /// compute with them, it also distinguishes between two different
-/// coordinate spaces (and their associated rank, shape, sizes, etc).
+/// coordinate spaces (and their associated rank, sizes, etc).
 /// Denotationally, we have the *dimensions* of the tensor represented
 /// by this object. Operationally, we have the *levels* of the storage
 /// representation itself.
@@ -139,10 +126,6 @@ class SparseTensorStorageBase {
   /// Safely checks if the level is unique.
   bool isUniqueLvl(uint64_t l) const { return isUniqueDLT(getLvlType(l)); }

-  /// Gets the level-to-dimension mapping.
-  // TODO: REMOVE THIS
-  const std::vector<uint64_t> &getLvl2Dim() const { return lvl2dimVec; }
-
   /// Gets positions-overhead storage for the given level.
 #define DECL_GETPOSITIONS(PNAME, P)                                            \
   virtual void getPositions(std::vector<P> **, uint64_t);
@@ -154,6 +137,7 @@ class SparseTensorStorageBase {
   virtual void getCoordinates(std::vector<C> **, uint64_t);
   MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
 #undef DECL_GETCOORDINATES
+
   /// Gets the coordinate-value stored at the given level and position.
   virtual uint64_t getCrd(uint64_t lvl, uint64_t pos) const = 0;

@@ -220,8 +204,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
                       const uint64_t *lvl2dim)
       : SparseTensorStorageBase(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                                 dim2lvl, lvl2dim),
-        positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank), lvlCOO() {
-  }
+        positions(lvlRank), coordinates(lvlRank), lvlCursor(lvlRank), coo() {}

 public:
   /// Constructs a sparse tensor with the given encoding, and allocates
@@ -234,24 +217,16 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
                       uint64_t lvlRank, const uint64_t *lvlSizes,
                       const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-                      const uint64_t *lvl2dim, SparseTensorCOO<V> *coo,
+                      const uint64_t *lvl2dim, SparseTensorCOO<V> *lvlCOO,
                       bool initializeValuesIfAllDense);

   /// Constructs a sparse tensor with the given encoding, and initializes
   /// the contents from the COO. This ctor performs the same heuristic
   /// overhead-storage allocation as the ctor above.
   SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
-                      uint64_t lvlRank, const DimLevelType *lvlTypes,
-                      const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-                      SparseTensorCOO<V> &lvlCOO);
-
-  /// Constructs a sparse tensor with the given encoding, and initializes
-  /// the contents from the enumerator. This ctor allocates exactly
-  /// the required amount of overhead storage, not using any heuristics.
-  SparseTensorStorage(uint64_t dimRank, const uint64_t *dimSizes,
-                      uint64_t lvlRank, const DimLevelType *lvlTypes,
-                      const uint64_t *dim2lvl, const uint64_t *lvl2dim,
-                      SparseTensorEnumeratorBase<V> &lvlEnumerator);
+                      uint64_t lvlRank, const uint64_t *lvlSizes,
+                      const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                      const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);

   /// Constructs a sparse tensor with the given encoding, and initializes
   /// the contents from the level buffers. This ctor allocates exactly
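Throughout this header the three template parameters have fixed roles: `P` is the integer type of the positions overhead, `C` the integer type of the coordinates overhead, and `V` the type of the stored values. As an orientation aid, a concrete instantiation of the `newEmpty` factory that appears in the next hunk could look like the following sketch; the sizes, level types, and mappings are invented for illustration and are not part of this patch:

```cpp
// Illustrative only: a 4x8 matrix stored dense-then-compressed (CSR-like),
// with 64-bit positions (P), 32-bit coordinates (C), and double values (V).
const uint64_t dimSizes[] = {4, 8};
const uint64_t lvlSizes[] = {4, 8};
const DimLevelType lvlTypes[] = {DimLevelType::Dense, DimLevelType::Compressed};
const uint64_t dim2lvl[] = {0, 1}; // identity dimension-to-level mapping
const uint64_t lvl2dim[] = {0, 1};
auto *tensor = SparseTensorStorage<uint64_t, uint32_t, double>::newEmpty(
    /*dimRank=*/2, dimSizes, /*lvlRank=*/2, lvlSizes, lvlTypes, dim2lvl,
    lvl2dim, /*forwarding=*/false);
```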
@@ -265,39 +240,27 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
                       const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
                       const uint64_t *lvl2dim, const intptr_t *lvlBufs);

-  /// Allocates a new empty sparse tensor. The preconditions/assertions
-  /// are as per the `SparseTensorStorageBase` ctor; which is to say,
-  /// the `dimSizes` and `lvlSizes` must both be "sizes" not "shapes",
-  /// since there's nowhere to reconstruct dynamic sizes from.
+  /// Allocates a new empty sparse tensor.
   static SparseTensorStorage<P, C, V> *
   newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
            const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
            const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding);

   /// Allocates a new sparse tensor and initializes it from the given COO.
-  /// The preconditions are as per the `SparseTensorStorageBase` ctor
-  /// (where we define `lvlSizes = lvlCOO.getDimSizes().data()`), but
-  /// using the following assertions in lieu of the base ctor's assertions:
-  //
-  // TODO: The ability to reconstruct dynamic dimensions-sizes does not
-  // easily generalize to arbitrary `lvl2dim` mappings.  When compiling
-  // MLIR programs to use this library, we should be able to generate
-  // code for effectively computing the reconstruction, but it's not clear
-  // that there's a feasible way to do so from within the library itself.
-  // Therefore, when we functionalize the `lvl2dim` mapping we'll have
-  // to update the type/preconditions of this factory too.
   static SparseTensorStorage<P, C, V> *
-  newFromCOO(uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-             const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-             const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO);
+  newFromCOO(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
+             const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+             const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+             SparseTensorCOO<V> &lvlCOO);

   /// Allocates a new sparse tensor and initialize it with the data stored level
   /// buffers directly.
-  static SparseTensorStorage<P, C, V> *packFromLvlBuffers(
-      uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-      const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-      const uint64_t *src2lvl, // FIXME: dim2lvl
-      const uint64_t *lvl2dim, uint64_t srcRank, const intptr_t *buffers);
+  static SparseTensorStorage<P, C, V> *
+  packFromLvlBuffers(uint64_t dimRank, const uint64_t *dimSizes,
+                     uint64_t lvlRank, const uint64_t *lvlSizes,
+                     const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
+                     const uint64_t *lvl2dim, uint64_t srcRank,
+                     const intptr_t *buffers);

   ~SparseTensorStorage() final = default;

@@ -326,16 +289,14 @@ class SparseTensorStorage final : public SparseTensorStorageBase {

   /// Partially specialize forwarding insertions based on template types.
   void forwardingInsert(const uint64_t *dimCoords, V val) final {
-    assert(dimCoords && lvlCOO);
+    assert(dimCoords && coo);
     map.pushforward(dimCoords, lvlCursor.data());
-    lvlCOO->add(lvlCursor, val);
+    coo->add(lvlCursor, val);
   }

   /// Partially specialize lexicographical insertions based on template types.
   void lexInsert(const uint64_t *lvlCoords, V val) final {
     assert(lvlCoords);
-    // TODO: get rid of this! canonicalize all-dense "sparse" array into dense
-    // tensors.
     bool allDense = std::all_of(getLvlTypes().begin(), getLvlTypes().end(),
                                 [](DimLevelType lt) { return isDenseDLT(lt); });
     if (allDense) {
@@ -391,16 +352,17 @@ class SparseTensorStorage final : public SparseTensorStorageBase {

   /// Finalizes forwarding insertions.
   void endForwardingInsert() final {
-    // Ensure lvlCOO is sorted.
-    assert(lvlCOO);
-    lvlCOO->sort();
+    // Ensure COO is sorted.
+    assert(coo);
+    coo->sort();
     // Now actually insert the `elements`.
-    const auto &elements = lvlCOO->getElements();
+    const auto &elements = coo->getElements();
     const uint64_t nse = elements.size();
     assert(values.size() == 0);
     values.reserve(nse);
     fromCOO(elements, 0, nse, 0);
-    delete lvlCOO;
+    delete coo;
+    coo = nullptr;
   }

   /// Finalizes lexicographic insertions.
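The forwarding-insert path buffers elements into the member COO and only materializes the actual storage when `endForwardingInsert()` runs. A minimal sketch of the intended call sequence (coordinates and values invented; `tensor` is assumed to have been created with forwarding enabled):

```cpp
const uint64_t a[] = {0, 3};
const uint64_t b[] = {2, 1};
tensor->forwardingInsert(a, 1.0); // pushed through dim2lvl into the COO
tensor->forwardingInsert(b, 2.0); // insertion order does not matter yet
tensor->endForwardingInsert();    // sorts the COO, then builds the
                                  // positions/coordinates/values via fromCOO
```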
@@ -411,23 +373,12 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     endPath(0);
   }

-  /// Allocates a new COO object and initializes it with the contents
-  /// of this tensor under the given mapping from the `getDimSizes()`
-  /// coordinate-space to the `trgSizes` coordinate-space. Callers must
-  /// make sure to delete the COO when they're done with it.
-  SparseTensorCOO<V> *toCOO(uint64_t trgRank, const uint64_t *trgSizes,
-                            uint64_t srcRank,
-                            const uint64_t *src2trg, // FIXME: dim2lvl
-                            const uint64_t *lvl2dim) const {
-    // TODO: use MapRef here too for the translation
-    SparseTensorEnumerator<P, C, V> enumerator(*this, trgRank, trgSizes,
-                                               srcRank, src2trg);
-    auto *coo = new SparseTensorCOO<V>(trgRank, trgSizes, values.size());
-    enumerator.forallElements(
-        [&coo](const auto &trgCoords, V val) { coo->add(trgCoords, val); });
-    // TODO: This assertion assumes there are no stored zeros,
-    // or if there are then that we don't filter them out.
-    //
+  /// Allocates a new COO object and initializes it with the contents.
+  /// Callers must make sure to delete the COO when they're done with it.
+  SparseTensorCOO<V> *toCOO() {
+    std::vector<uint64_t> dimCoords(getDimRank());
+    coo = new SparseTensorCOO<V>(getDimSizes(), values.size());
+    toCOO(0, 0, dimCoords);
     assert(coo->getElements().size() == values.size());
     return coo;
   }
@@ -525,27 +476,11 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     }
   }

-  /// Writes the given coordinate to `coordinates[lvl][pos]`. This method
-  /// checks that `crd` is representable in the `C` type; however, it
-  /// does not check that `crd` is semantically valid (i.e., in bounds
-  /// for `dimSizes[lvl]` and not elsewhere occurring in the same segment).
-  void writeCrd(uint64_t lvl, uint64_t pos, uint64_t crd) {
-    assert(isCompressedDLT(getLvlType(lvl)) || isSingletonDLT(getLvlType(lvl)));
-    // Subscript assignment to `std::vector` requires that the `pos`-th
-    // entry has been initialized; thus we must be sure to check `size()`
-    // here, instead of `capacity()` as would be ideal.
-    assert(pos < coordinates[lvl].size());
-    coordinates[lvl][pos] = detail::checkOverflowCast<C>(crd);
-  }
-
   /// Computes the assembled-size associated with the `l`-th level,
   /// given the assembled-size associated with the `(l-1)`-th level.
   /// "Assembled-sizes" correspond to the (nominal) sizes of overhead
   /// storage, as opposed to "level-sizes" which are the cardinality
   /// of possible coordinates for that level.
-  ///
-  /// Precondition: the `positions[l]` array must be fully initialized
-  /// before calling this method.
   uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
     const auto dlt = getLvlType(l); // Avoid redundant bounds checking.
     if (isCompressedDLT(dlt))
@@ -553,7 +488,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     if (isSingletonDLT(dlt))
       return parentSz; // New size is same as the parent.
     if (isDenseDLT(dlt))
-      return parentSz * getLvlSizes()[l];
+      return parentSz * getLvlSize(l);
     MLIR_SPARSETENSOR_FATAL("unsupported level type: %d\n",
                             static_cast<uint8_t>(dlt));
   }
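To make the `assembledSize` recurrence concrete, consider a two-level dense-then-compressed layout with level sizes {4, 8}. The overhead sizes unfold as in this sketch; the compressed case reads the running total of stored entries recorded at the end of that level's positions array:

```cpp
// Sketch only; mirrors the recurrence above for lvlTypes = {dense, compressed}.
uint64_t parentSz = 1;                 // one implicit root segment
parentSz *= getLvlSize(0);             // dense level: 1 * 4 = 4 segments
parentSz = assembledSize(parentSz, 1); // compressed level: total number of
                                       // stored entries, i.e. positions[1][4]
```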
@@ -561,11 +496,6 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   /// Initializes sparse tensor storage scheme from a memory-resident sparse
   /// tensor in coordinate scheme. This method prepares the positions and
   /// coordinates arrays under the given per-level dense/sparse annotations.
-  ///
-  /// Preconditions:
-  /// * the `lvlElements` must be lexicographically sorted.
-  /// * the coordinates of every element are valid for `getLvlSizes()`
-  ///   (i.e., equal rank and pointwise less-than).
   void fromCOO(const std::vector<Element<V>> &lvlElements, uint64_t lo,
                uint64_t hi, uint64_t l) {
     const uint64_t lvlRank = getLvlRank();
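For a concrete picture of what `fromCOO` produces, take a 3x4 matrix with stored entries (0,1)=1.1, (0,3)=2.2, (2,2)=3.3 and a dense-then-compressed level layout. The lexicographically sorted elements assemble into the familiar CSR-style arrays; the values below are purely illustrative:

```cpp
// Level 0 is dense and needs no overhead arrays; level 1 is compressed.
std::vector<uint64_t> positions1   = {0, 2, 2, 3};     // row i owns [positions1[i], positions1[i+1])
std::vector<uint32_t> coordinates1 = {1, 3, 2};        // column coordinates of stored entries
std::vector<double>   values       = {1.1, 2.2, 3.3};  // lexicographic element order
```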
@@ -669,184 +599,48 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
     return -1u;
   }

-  // Allow `SparseTensorEnumerator` to access the data-members (to avoid
-  // the cost of virtual-function dispatch in inner loops), without
-  // making them public to other client code.
-  friend class SparseTensorEnumerator<P, C, V>;
-
-  std::vector<std::vector<P>> positions;
-  std::vector<std::vector<C>> coordinates;
-  std::vector<V> values;
-  std::vector<uint64_t> lvlCursor; // cursor for lexicographic insertion.
-  SparseTensorCOO<V> *lvlCOO;      // COO used during forwarding
-};
-
-//===----------------------------------------------------------------------===//
-//
-// SparseTensorEnumerator
-//
-//===----------------------------------------------------------------------===//
-
-/// A (higher-order) function object for enumerating the elements of some
-/// `SparseTensorStorage` under a permutation. That is, the `forallElements`
-/// method encapsulates the loop-nest for enumerating the elements of
-/// the source tensor (in whatever order is best for the source tensor),
-/// and applies a permutation to the coordinates before handing
-/// each element to the callback. A single enumerator object can be
-/// freely reused for several calls to `forallElements`, just so long
-/// as each call is sequential with respect to one another.
-///
-/// N.B., this class stores a reference to the `SparseTensorStorageBase`
-/// passed to the constructor; thus, objects of this class must not
-/// outlive the sparse tensor they depend on.
-///
-/// Design Note: The reason we define this class instead of simply using
-/// `SparseTensorEnumerator` is because we need to hide/generalize
-/// the `<P, C>` template parameters from MLIR client code (to simplify the
-/// type parameters used for direct sparse-to-sparse conversion). And the
-/// reason we define the `SparseTensorEnumerator` subclasses rather
-/// than simply using this class, is to avoid the cost of virtual-method
-/// dispatch within the loop-nest.
-template <typename V>
-class SparseTensorEnumeratorBase {
-public:
-  /// Constructs an enumerator which automatically applies the given
-  /// mapping from the source tensor's dimensions to the desired
-  /// target tensor dimensions.
-  ///
-  /// Preconditions:
-  /// * the `src` must have the same `V` value type.
-  /// * `trgSizes` must be valid for `trgRank`.
-  /// * `src2trg` must be valid for `srcRank`, and must map coordinates
-  ///   valid for `src.getDimSizes()` to coordinates valid for `trgSizes`.
-  ///
-  /// Asserts:
-  /// * `trgSizes` must be nonnull and must contain only nonzero sizes.
-  /// * `srcRank == src.getDimRank()`.
-  /// * `src2trg` must be nonnull.
-  SparseTensorEnumeratorBase(const SparseTensorStorageBase &src,
-                             uint64_t trgRank, const uint64_t *trgSizes,
-                             uint64_t srcRank, const uint64_t *src2trg)
-      : src(src), trgSizes(trgSizes, trgSizes + trgRank),
-        lvl2trg(src.getLvlRank()), trgCursor(trgRank) {
-    assert(trgSizes && "Received nullptr for target-sizes");
-    assert(src2trg && "Received nullptr for source-to-target mapping");
-    assert(srcRank == src.getDimRank() && "Source-rank mismatch");
-    for (uint64_t t = 0; t < trgRank; ++t)
-      assert(trgSizes[t] > 0 && "Target-size zero has trivial storage");
-    const auto &lvl2src = src.getLvl2Dim();
-    for (uint64_t lvlRank = src.getLvlRank(), l = 0; l < lvlRank; ++l)
-      lvl2trg[l] = src2trg[lvl2src[l]];
-  }
-
-  virtual ~SparseTensorEnumeratorBase() = default;
-
-  // We disallow copying to help avoid leaking the `src` reference.
-  // (In addition to avoiding the problem of slicing.)
-  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
-  SparseTensorEnumeratorBase &
-  operator=(const SparseTensorEnumeratorBase &) = delete;
-
-  /// Gets the source's dimension-rank.
-  uint64_t getSrcDimRank() const { return src.getDimRank(); }
-
-  /// Gets the target's dimension-/level-rank. (This is usually
-  /// "dimension-rank", though that may coincide with "level-rank"
-  /// depending on usage.)
-  uint64_t getTrgRank() const { return trgSizes.size(); }
-
-  /// Gets the target's dimension-/level-sizes. (These are usually
-  /// "dimensions", though that may coincide with "level-rank" depending
-  /// on usage.)
-  const std::vector<uint64_t> &getTrgSizes() const { return trgSizes; }
-
-  /// Enumerates all elements of the source tensor, permutes their
-  /// coordinates, and passes the permuted element to the callback.
-  /// The callback must not store the cursor reference directly,
-  /// since this function reuses the storage. Instead, the callback
-  /// must copy it if they want to keep it.
-  virtual void forallElements(ElementConsumer<V> yield) = 0;
-
-protected:
-  const SparseTensorStorageBase &src;
-  std::vector<uint64_t> trgSizes;  // in target order.
-  std::vector<uint64_t> lvl2trg;   // source-levels -> target-dims/lvls.
-  std::vector<uint64_t> trgCursor; // in target order.
-};
-
-template <typename P, typename C, typename V>
-class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
-  using Base = SparseTensorEnumeratorBase<V>;
-  using StorageImpl = SparseTensorStorage<P, C, V>;
-
-public:
-  /// Constructs an enumerator which automatically applies the given
-  /// mapping from the source tensor's dimensions to the desired
-  /// target tensor dimensions.
-  ///
-  /// Preconditions/assertions are as per the `SparseTensorEnumeratorBase` ctor.
-  SparseTensorEnumerator(const StorageImpl &src, uint64_t trgRank,
-                         const uint64_t *trgSizes, uint64_t srcRank,
-                         const uint64_t *src2trg)
-      : Base(src, trgRank, trgSizes, srcRank, src2trg) {}
-
-  ~SparseTensorEnumerator() final = default;
-
-  void forallElements(ElementConsumer<V> yield) final {
-    forallElements(yield, 0, 0);
-  }
-
-private:
-  // TODO: Once we functionalize the mappings, then we'll no longer
-  // be able to use the current approach of constructing `lvl2trg` in the
-  // ctor and using it to incrementally fill the `trgCursor` cursor as we
-  // recurse through `forallElements`. Instead we'll want to incrementally
-  // fill a `lvlCursor` as we recurse, and then use `src.getLvl2Dim()`
-  // and `src2trg` to convert it just before yielding to the callback.
-  // It's probably most efficient to just store the `srcCursor` and
-  // `trgCursor` buffers in this object, but we may want to benchmark
-  // that against using `std::calloc` to stack-allocate them instead.
-  //
-  /// The recursive component of the public `forallElements`.
-  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
-                      uint64_t l) {
-    // Recover the `<P, C, V>` type parameters of `src`.
-    const auto &src = static_cast<const StorageImpl &>(this->src);
-    if (l == src.getLvlRank()) {
-      assert(parentPos < src.values.size());
-      // TODO:
-      yield(this->trgCursor, src.values[parentPos]);
+  // Performs forall on level entries and inserts into dim COO.
+  void toCOO(uint64_t parentPos, uint64_t l, std::vector<uint64_t> &dimCoords) {
+    if (l == getLvlRank()) {
+      map.pushbackward(lvlCursor.data(), dimCoords.data());
+      assert(coo);
+      assert(parentPos < values.size());
+      coo->add(dimCoords, values[parentPos]);
       return;
     }
-    uint64_t &cursorL = this->trgCursor[this->lvl2trg[l]];
-    const auto dlt = src.getLvlType(l); // Avoid redundant bounds checking.
-    if (isCompressedDLT(dlt)) {
+    if (isCompressedLvl(l)) {
       // Look up the bounds of the `l`-level segment determined by the
       // `(l - 1)`-level position `parentPos`.
-      const std::vector<P> &positionsL = src.positions[l];
+      const std::vector<P> &positionsL = positions[l];
       assert(parentPos + 1 < positionsL.size());
       const uint64_t pstart = static_cast<uint64_t>(positionsL[parentPos]);
       const uint64_t pstop = static_cast<uint64_t>(positionsL[parentPos + 1]);
       // Loop-invariant code for looking up the `l`-level coordinates.
-      const std::vector<C> &coordinatesL = src.coordinates[l];
+      const std::vector<C> &coordinatesL = coordinates[l];
       assert(pstop <= coordinatesL.size());
       for (uint64_t pos = pstart; pos < pstop; ++pos) {
-        cursorL = static_cast<uint64_t>(coordinatesL[pos]);
-        forallElements(yield, pos, l + 1);
+        lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
+        toCOO(pos, l + 1, dimCoords);
       }
-    } else if (isSingletonDLT(dlt)) {
-      cursorL = src.getCrd(l, parentPos);
-      forallElements(yield, parentPos, l + 1);
+    } else if (isSingletonLvl(l)) {
+      lvlCursor[l] = getCrd(l, parentPos);
+      toCOO(parentPos, l + 1, dimCoords);
     } else { // Dense level.
-      assert(isDenseDLT(dlt));
-      const uint64_t sz = src.getLvlSizes()[l];
+      assert(isDenseLvl(l));
+      const uint64_t sz = getLvlSizes()[l];
       const uint64_t pstart = parentPos * sz;
       for (uint64_t c = 0; c < sz; ++c) {
-        cursorL = c;
-        forallElements(yield, pstart + c, l + 1);
+        lvlCursor[l] = c;
+        toCOO(pstart + c, l + 1, dimCoords);
       }
     }
   }
+
+  std::vector<std::vector<P>> positions;
+  std::vector<std::vector<C>> coordinates;
+  std::vector<V> values;
+  std::vector<uint64_t> lvlCursor;
+  SparseTensorCOO<V> *coo;
 };

 //===----------------------------------------------------------------------===//
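The private `toCOO` recursion above walks the levels in storage order, filling `lvlCursor`, and only translates back to dimension coordinates at the leaves via `map.pushbackward`. Through the public entry point the whole conversion reduces to something like this sketch (with `V = double`; consumption of the elements elided):

```cpp
SparseTensorCOO<double> *dimCOO = tensor->toCOO();
const uint64_t nse = dimCOO->getElements().size(); // one element per stored value
// ... consume the dimension-space elements ...
delete dimCOO; // per the doc comment, the caller owns the returned COO
```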
@@ -868,41 +662,24 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
       !forwarding);
 }

-// TODO: MapRef
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newFromCOO(
-    uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-    const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO) {
-  assert(dimShape && dim2lvl && lvl2dim);
-  const auto &lvlSizes = lvlCOO.getDimSizes();
-  assert(lvlRank == lvlSizes.size() && "Level-rank mismatch");
-  // Must reconstruct `dimSizes` from `lvlSizes`. While this is easy
-  // enough to do when `lvl2dim` is a permutation, this approach will
-  // not work for more general mappings; so we will need to move this
-  // computation off to codegen.
-  std::vector<uint64_t> dimSizes(dimRank);
-  for (uint64_t l = 0; l < lvlRank; ++l) {
-    const uint64_t d = lvl2dim[l];
-    assert((dimShape[d] == 0 || dimShape[d] == lvlSizes[l]) &&
-           "Dimension sizes do not match expected shape");
-    dimSizes[d] = lvlSizes[l];
-  }
-  return new SparseTensorStorage(dimRank, dimSizes.data(), lvlRank,
+    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
+    const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    SparseTensorCOO<V> &lvlCOO) {
+  return new SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes,
                                  lvlTypes, dim2lvl, lvl2dim, lvlCOO);
 }

 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
-    uint64_t dimRank, const uint64_t *dimShape, uint64_t lvlRank,
+    uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *src2lvl, // FIXME: dim2lvl
-    const uint64_t *lvl2dim, uint64_t srcRank, const intptr_t *buffers) {
-  assert(dimShape && "Got nullptr for dimension shape");
-  auto *tensor =
-      new SparseTensorStorage(dimRank, dimShape, lvlRank, lvlSizes,
-                              lvlTypes, src2lvl, lvl2dim, buffers);
-  return tensor;
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim, uint64_t srcRank,
+    const intptr_t *buffers) {
+  return new SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes,
+                                 lvlTypes, dim2lvl, lvl2dim, buffers);
 }

 //===----------------------------------------------------------------------===//
@@ -915,11 +692,12 @@ template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage(
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
     const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
-    const uint64_t *dim2lvl, const uint64_t *lvl2dim, SparseTensorCOO<V> *coo,
-    bool initializeValuesIfAllDense)
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    SparseTensorCOO<V> *lvlCOO, bool initializeValuesIfAllDense)
     : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
                           dim2lvl, lvl2dim) {
-  lvlCOO = coo;
+  assert(!lvlCOO || lvlRank == lvlCOO->getRank());
+  coo = lvlCOO;
   // Provide hints on capacity of positions and coordinates.
   // TODO: needs much fine-tuning based on actual sparsity; currently
   // we reserve position/coordinate space based on all previous dense
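Since the COO constructor now takes `lvlSizes` explicitly, direct use looks roughly like the sketch below; the sizes, level types, and mapping arrays are assumed to be set up as in the earlier sketches, and the COO need not be pre-sorted because the constructor sorts it:

```cpp
std::vector<uint64_t> lvlSizesVec = {4, 8};
SparseTensorCOO<double> lvlCOO(lvlSizesVec, /*capacity=*/3);
lvlCOO.add({0, 7}, 3.0);
lvlCOO.add({0, 1}, 1.0);
lvlCOO.add({2, 5}, 2.0);
SparseTensorStorage<uint64_t, uint32_t, double> tensor(
    /*dimRank=*/2, dimSizes, /*lvlRank=*/2, lvlSizesVec.data(), lvlTypes,
    dim2lvl, lvl2dim, lvlCOO);
```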
@@ -948,17 +726,16 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     values.resize(sz, 0);
 }

-// TODO: share more code with forwarding methods?
 template <typename P, typename C, typename V>
 SparseTensorStorage<P, C, V>::SparseTensorStorage( // NOLINT
     uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
-    const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
-    const uint64_t *lvl2dim, SparseTensorCOO<V> &lvlCOO)
-    : SparseTensorStorage(dimRank, dimSizes, lvlRank,
-                          lvlCOO.getDimSizes().data(), lvlTypes, dim2lvl,
-                          lvl2dim, nullptr, false) {
+    const uint64_t *lvlSizes, const DimLevelType *lvlTypes,
+    const uint64_t *dim2lvl, const uint64_t *lvl2dim,
+    SparseTensorCOO<V> &lvlCOO)
+    : SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
+                          dim2lvl, lvl2dim, nullptr, false) {
   // Ensure lvlCOO is sorted.
-  assert(lvlRank == lvlCOO.getDimSizes().size() && "Level-rank mismatch");
+  assert(lvlRank == lvlCOO.getRank());
   lvlCOO.sort();
   // Now actually insert the `elements`.
   const auto &elements = lvlCOO.getElements();
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index 6a4c0f292c5f8..36d888a08de6d 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -129,7 +129,8 @@ extern "C" {
       assert(ptr && "Received nullptr for SparseTensorCOO object");           \
       auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                    \
       return SparseTensorStorage<P, C, V>::newFromCOO(                        \
-          dimRank, dimSizes, lvlRank, lvlTypes, dim2lvl, lvl2dim, coo);       \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,   \
+          coo);                                                               \
     }                                                                         \
     case Action::kFromReader: {                                               \
       assert(ptr && "Received nullptr for SparseTensorReader object");        \
@@ -140,7 +141,7 @@ extern "C" {
     case Action::kToCOO: {                                                    \
       assert(ptr && "Received nullptr for SparseTensorStorage object");       \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);       \
-      return tensor.toCOO(lvlRank, lvlSizes, dimRank, dim2lvl, lvl2dim);      \
+      return tensor.toCOO();                                                  \
     }                                                                         \
     case Action::kPack: {                                                     \
       assert(ptr && "Received nullptr for SparseTensorStorage object");       \
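For context on the runtime side: with the mapping arguments gone, the `Action::kToCOO` case of the `CASE` macro expands, for one `(P, C, V)` instantiation, to something equivalent to this paraphrase (not code added by the patch; the function name is hypothetical):

```cpp
// Paraphrase of the kToCOO case for P = uint64_t, C = uint64_t, V = double.
void *handleToCOO(void *ptr) {
  assert(ptr && "Received nullptr for SparseTensorStorage object");
  auto &tensor =
      *static_cast<SparseTensorStorage<uint64_t, uint64_t, double> *>(ptr);
  return tensor.toCOO(); // no level/dimension mapping arguments needed anymore
}
```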