@@ -177,9 +177,16 @@ extern "C" {
 #define CASE(p, c, v, P, C, V)                                                \
   if (posTp == (p) && crdTp == (c) && valTp == (v)) {                         \
     switch (action) {                                                         \
-    case Action::kEmpty:                                                      \
+    case Action::kEmpty: {                                                    \
       return SparseTensorStorage<P, C, V>::newEmpty(                          \
-          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim);  \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,   \
+          false);                                                             \
+    }                                                                         \
+    case Action::kEmptyForward: {                                             \
+      return SparseTensorStorage<P, C, V>::newEmpty(                          \
+          dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,   \
+          true);                                                              \
+    }                                                                         \
     case Action::kFromCOO: {                                                  \
       assert(ptr && "Received nullptr for SparseTensorCOO object");           \
       auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr);                    \
@@ -193,8 +200,9 @@ extern "C" {
           dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim,  \
           dimRank, tensor);                                                  \
     }                                                                        \
-    case Action::kEmptyCOO:                                                  \
-      return new SparseTensorCOO<V>(lvlRank, lvlSizes);                      \
+    case Action::kFuture: {                                                  \
+      break;                                                                 \
+    }                                                                        \
     case Action::kToCOO: {                                                   \
       assert(ptr && "Received nullptr for SparseTensorStorage object");      \
       auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr);      \
@@ -405,29 +413,20 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
 #undef IMPL_SPARSECOORDINATES
 #undef IMPL_GETOVERHEAD
 
-// TODO: use MapRef here for translation of coordinates
-// TODO: remove dim2lvl
-#define IMPL_ADDELT(VNAME, V)                                                 \
-  void *_mlir_ciface_addElt##VNAME(                                           \
-      void *lvlCOO, StridedMemRefType<V, 0> *vref,                            \
-      StridedMemRefType<index_type, 1> *dimCoordsRef,                         \
-      StridedMemRefType<index_type, 1> *dim2lvlRef) {                         \
-    assert(lvlCOO && vref);                                                   \
+#define IMPL_FORWARDINGINSERT(VNAME, V)                                       \
+  void _mlir_ciface_forwardingInsert##VNAME(                                  \
+      void *t, StridedMemRefType<V, 0> *vref,                                 \
+      StridedMemRefType<index_type, 1> *dimCoordsRef) {                       \
+    assert(t && vref);                                                        \
     ASSERT_NO_STRIDE(dimCoordsRef);                                          \
-    ASSERT_NO_STRIDE(dim2lvlRef);                                             \
-    const uint64_t rank = MEMREF_GET_USIZE(dimCoordsRef);                     \
-    ASSERT_USIZE_EQ(dim2lvlRef, rank);                                        \
     const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef);          \
-    const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);               \
-    std::vector<index_type> lvlCoords(rank);                                  \
-    for (uint64_t d = 0; d < rank; ++d)                                       \
-      lvlCoords[dim2lvl[d]] = dimCoords[d];                                   \
-    V *value = MEMREF_GET_PAYLOAD(vref);                                      \
-    static_cast<SparseTensorCOO<V> *>(lvlCOO)->add(lvlCoords, *value);        \
-    return lvlCOO;                                                            \
+    assert(dimCoords);                                                        \
+    const V *value = MEMREF_GET_PAYLOAD(vref);                                \
+    static_cast<SparseTensorStorageBase *>(t)->forwardingInsert(dimCoords,    \
+                                                                *value);      \
   }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
-#undef IMPL_ADDELT
+MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
+#undef IMPL_FORWARDINGINSERT
 
 // NOTE: the `cref` argument uses the same coordinate-space as the `iter`
 // (which can be either dim- or lvl-coords, depending on context).
@@ -692,8 +691,12 @@ index_type sparseDimSize(void *tensor, index_type d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
-void endInsert(void *tensor) {
-  return static_cast<SparseTensorStorageBase *>(tensor)->endInsert();
+void endForwardingInsert(void *tensor) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->endForwardingInsert();
+}
+
+void endLexInsert(void *tensor) {
+  return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
 }
 
 #define IMPL_OUTSPARSETENSOR(VNAME, V)                                        \