[mlir][sparse] replace specialized buffer setup with util code #68461

Merged: 1 commit, Oct 9, 2023
5 changes: 3 additions & 2 deletions mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -57,8 +57,8 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
StridedMemRefType<index_type, 1> *dimSizesRef,
StridedMemRefType<index_type, 1> *lvlSizesRef,
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType posTp,
+ StridedMemRefType<index_type, 1> *dim2lvlRef,
+ StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
OverheadType crdTp, PrimaryType valTp, Action action, void *ptr);
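
For intuition about the reordered pair (an illustrative sketch, not part of the patch): dim2lvl and lvl2dim are mutually inverse maps between dimensions and storage levels. Assuming a 2-D tensor whose dimensions (d0, d1) are stored in transposed level order (l0, l1) = (d1, d0), the index_type payloads behind the two memrefs would be:

const uint64_t dim2lvl[2] = {1, 0}; // dimension d is stored at level dim2lvl[d]
const uint64_t lvl2dim[2] = {1, 0}; // level l stores dimension lvl2dim[l]

For an identity mapping, both degenerate to the iota sequence {0, 1}, which is why the conversion test below passes the same %[[Iota]] buffer for both arguments.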

/// Tensor-storage method to obtain direct access to the values array.
@@ -85,6 +85,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
#undef DECL_SPARSECOORDINATES

/// Coordinate-scheme method for adding a new element.
+ /// TODO: remove dim2lvl
#define DECL_ADDELT(VNAME, V) \
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addElt##VNAME( \
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
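
To make the macro concrete, here is a hypothetical expansion of the visible part of DECL_ADDELT(F64, double); the trailing parameters are elided by this view, and the new TODO suggests a dim2lvl buffer is among them:

MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_addEltF64(
    void *lvlCOO, StridedMemRefType<double, 0> *vref, ...);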
@@ -187,25 +187,38 @@ static Value genLvlPtrsBuffers(OpBuilder &builder, Location loc,

/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
/// the "swiss army knife" method of the sparse runtime support library
- /// for materializing sparse tensors into the computation. This abstraction
- /// reduces the need to make modifications to client code whenever that
- /// API changes.
+ /// for materializing sparse tensors into the computation. This abstraction
+ /// reduces the need for modifications when the API changes.
class NewCallParams final {
public:
- /// Allocates the `ValueRange` for the `func::CallOp` parameters,
- /// but does not initialize them.
+ /// Allocates the `ValueRange` for the `func::CallOp` parameters.
NewCallParams(OpBuilder &builder, Location loc)
: builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}

/// Initializes all static parameters (i.e., those which indicate
/// type-level information such as the encoding and sizes), generating
/// MLIR buffers as needed, and returning `this` for method chaining.
/// This method does not set the action and pointer arguments, since
/// those are handled by `genNewCall` instead.
- NewCallParams &genBuffers(SparseTensorType stt, ValueRange dimSizes);
+ NewCallParams &genBuffers(SparseTensorType stt,
+                           ArrayRef<Value> dimSizesValues) {
+   const Dimension dimRank = stt.getDimRank();
+   assert(dimSizesValues.size() == static_cast<size_t>(dimRank));
+   // Sparsity annotations.
+   params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
+   // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
+   params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
+   params[kParamLvlSizes] = genReaderBuffers(
+       builder, loc, stt, dimSizesValues, params[kParamDimSizes],
+       params[kParamDim2Lvl], params[kParamLvl2Dim]);
+   // Secondary and primary types encoding.
+   setTemplateTypes(stt);
+   // Finally, make note that initialization is complete.
+   assert(isInitialized() && "Initialization failed");
+   // And return `this` for method chaining.
+   return *this;
+ }

/// (Re)sets the C++ template type parameters, and returns `this`
/// for method chaining. This is already done as part of `genBuffers`,
/// but is factored out so that it can also be called independently
/// whenever subsequent `genNewCall` calls want to reuse the same
/// buffers but different type parameters.
@@ -236,7 +249,7 @@ class NewCallParams final {
// this one-off getter, and to avoid potential mixups)?
Value getDimToLvl() const {
assert(isInitialized() && "Must initialize before getDimToLvl");
-   return params[kParamDimToLvl];
+   return params[kParamDim2Lvl];
}

/// Generates a function call, with the current static parameters
@@ -257,8 +270,8 @@ class NewCallParams final {
static constexpr unsigned kParamDimSizes = 0;
static constexpr unsigned kParamLvlSizes = 1;
static constexpr unsigned kParamLvlTypes = 2;
- static constexpr unsigned kParamLvlToDim = 3;
- static constexpr unsigned kParamDimToLvl = 4;
+ static constexpr unsigned kParamDim2Lvl = 3;
+ static constexpr unsigned kParamLvl2Dim = 4;
static constexpr unsigned kParamPosTp = 5;
static constexpr unsigned kParamCrdTp = 6;
static constexpr unsigned kParamValTp = 7;
@@ -271,62 +284,6 @@ class NewCallParams final {
Value params[kNumParams];
};
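
For orientation, a hedged call-site sketch of the chaining this class enables (call sites are outside this diff; the exact genNewCall signature and the Action::kEmpty constant are assumptions based on the comments above):

// Hypothetical sketch: allocate the parameter pack, generate the static
// parameters, then emit the call into the runtime "swiss army knife".
Value handle = NewCallParams(rewriter, loc)
                   .genBuffers(stt, dimSizesValues)
                   .genNewCall(Action::kEmpty); // action and pointer set here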

- // TODO: see the note at `_mlir_ciface_newSparseTensor` about how
- // the meaning of the various arguments (e.g., "sizes" vs "shapes")
- // is inconsistent between the different actions.
- NewCallParams &NewCallParams::genBuffers(SparseTensorType stt,
-                                          ValueRange dimSizes) {
-   const Level lvlRank = stt.getLvlRank();
-   const Dimension dimRank = stt.getDimRank();
-   // Sparsity annotations.
-   params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
-   // Dimension-sizes array of the enveloping tensor. Useful for either
-   // verification of external data, or for construction of internal data.
-   assert(dimSizes.size() == static_cast<size_t>(dimRank) &&
-          "Dimension-rank mismatch");
-   params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
-   // The level-sizes array must be passed as well, since for arbitrary
-   // dimToLvl mappings it cannot be trivially reconstructed at runtime.
-   // For now however, since we're still assuming permutations, we will
-   // initialize this parameter alongside the `dimToLvl` and `lvlToDim`
-   // parameters below. We preinitialize `lvlSizes` for code symmetry.
-   SmallVector<Value> lvlSizes(lvlRank);
-   // The dimension-to-level mapping and its inverse. We must preinitialize
-   // `dimToLvl` so that the true branch below can perform random-access
-   // `operator[]` assignment. We preinitialize `lvlToDim` for code symmetry.
-   SmallVector<Value> dimToLvl(dimRank);
-   SmallVector<Value> lvlToDim(lvlRank);
-   if (!stt.isIdentity()) {
-     const auto dimToLvlMap = stt.getDimToLvl();
-     assert(dimToLvlMap.isPermutation());
-     for (Level l = 0; l < lvlRank; l++) {
-       // The `d`th source variable occurs in the `l`th result position.
-       const Dimension d = dimToLvlMap.getDimPosition(l);
-       dimToLvl[d] = constantIndex(builder, loc, l);
-       lvlToDim[l] = constantIndex(builder, loc, d);
-       lvlSizes[l] = dimSizes[d];
-     }
-   } else {
-     // The `SparseTensorType` ctor already ensures `dimRank == lvlRank`
-     // when `isIdentity`; so no need to re-assert it here.
-     for (Level l = 0; l < lvlRank; l++) {
-       dimToLvl[l] = lvlToDim[l] = constantIndex(builder, loc, l);
-       lvlSizes[l] = dimSizes[l];
-     }
-   }
-   params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
-   params[kParamLvlToDim] = allocaBuffer(builder, loc, lvlToDim);
-   params[kParamDimToLvl] = stt.isIdentity()
-                                ? params[kParamLvlToDim]
-                                : allocaBuffer(builder, loc, dimToLvl);
-   // Secondary and primary types encoding.
-   setTemplateTypes(stt);
-   // Finally, make note that initialization is complete.
-   assert(isInitialized() && "Initialization failed");
-   // And return `this` for method chaining.
-   return *this;
- }

/// Generates a call to obtain the values array.
static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
ValueRange ptr) {
12 changes: 8 additions & 4 deletions mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -231,8 +231,8 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
StridedMemRefType<index_type, 1> *dimSizesRef,
StridedMemRefType<index_type, 1> *lvlSizesRef,
StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType posTp,
+ StridedMemRefType<index_type, 1> *dim2lvlRef,
+ StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
OverheadType crdTp, PrimaryType valTp, Action action, void *ptr) {
ASSERT_NO_STRIDE(dimSizesRef);
ASSERT_NO_STRIDE(lvlSizesRef);
@@ -250,6 +250,9 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);

+ // Prepare map.
+ // TODO: start using MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim) below
+
// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
// This is safe because of the static_assert above.
if (posTp == OverheadType::kIndex)
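
Since the TODO spells out the intended constructor, a minimal sketch of the follow-up it envisions (the translation method and the coordinate variables are assumptions, not code from this patch):

// Hypothetical only: build the map once from the two buffers, then use it
// to translate dimension coordinates into level coordinates on insertion.
MapRef map(dimRank, lvlRank, dim2lvl, lvl2dim); // ctor shape from the TODO
std::vector<index_type> lvlCoords(lvlRank);
map.pushforward(dimCoords, lvlCoords.data());   // assumed dim-to-lvl method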
@@ -400,6 +403,7 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_GETOVERHEAD

// TODO: use MapRef here for translation of coordinates
+ // TODO: remove dim2lvl
#define IMPL_ADDELT(VNAME, V) \
void *_mlir_ciface_addElt##VNAME( \
void *lvlCOO, StridedMemRefType<V, 0> *vref, \
@@ -540,13 +544,13 @@ void *_mlir_ciface_newSparseTensorFromReader(
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
ASSERT_NO_STRIDE(lvlSizesRef);
ASSERT_NO_STRIDE(lvlTypesRef);
- ASSERT_NO_STRIDE(lvl2dimRef);
ASSERT_NO_STRIDE(dim2lvlRef);
+ ASSERT_NO_STRIDE(lvl2dimRef);
const uint64_t dimRank = reader.getRank();
const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
- ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
+ ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
(void)dimRank;
const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
12 changes: 5 additions & 7 deletions mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -136,18 +136,16 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
// CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
- // CHECK-DAG: %[[DimSizes0:.*]] = memref.alloca() : memref<2xindex>
- // CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+ // CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
- // CHECK-DAG: %[[DimSizes:.*]] = memref.cast %[[DimSizes0]] : memref<2xindex> to memref<?xindex>
- // CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+ // CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
- // CHECK-DAG: memref.store %[[I]], %[[DimSizes0]][%[[C0]]] : memref<2xindex>
- // CHECK-DAG: memref.store %[[J]], %[[DimSizes0]][%[[C1]]] : memref<2xindex>
+ // CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
+ // CHECK-DAG: memref.store %[[J]], %[[Sizes0]][%[[C1]]] : memref<2xindex>
// CHECK: %[[NP:.*]] = llvm.mlir.zero : !llvm.ptr<i8>
- // CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
+ // CHECK: %[[T:.*]] = call @newSparseTensor(%[[Sizes]], %[[Sizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #CSR> {
%0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64, #CSR>