1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -146,6 +146,7 @@ enum class Action : uint32_t {
 kEmptyForward = 1,
 kFromCOO = 2,
 kSparseToSparse = 3,
+kFromReader = 4,
 kToCOO = 5,
 kPack = 7,
 kSortCOOInPlace = 8,
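Note: `kFromReader = 4` fills the existing numbering gap between `kSparseToSparse = 3` and `kToCOO = 5` (the value 6 likewise remains unassigned), so no established action encoding changes. For orientation, here is a side-by-side sketch of the producer and consumer of the new action; both halves appear verbatim in the hunks below, and `P`, `C`, `V` stand for the position/coordinate/value C++ types that the runtime's type dispatch has already resolved:

```cpp
// Compiler side (SparseTensorNewConverter below): lower sparse_tensor.new
// to the generic runtime entry point, handing over the opened reader as
// the opaque pointer argument.
Value tensor = NewCallParams(rewriter, loc)
                   .genBuffers(stt, dimShapesValues, dimSizesBuffer)
                   .genNewCall(Action::kFromReader, reader);

// Runtime side (SparseTensorRuntime.cpp below): the reader parses the file
// directly into a typed SparseTensorStorage object.
SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr);
return static_cast<void *>(reader.readSparseTensor<P, C, V>(
    lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
```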
25 changes: 0 additions & 25 deletions mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -115,16 +115,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
 char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
 PrimaryType valTp);
 
-/// Constructs a new sparse-tensor storage object with the given encoding,
-/// initializes it by reading all the elements from the file, and then
-/// closes the file.
-MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
-void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
-StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
-StridedMemRefType<index_type, 1> *dim2lvlRef,
-StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
-OverheadType crdTp, PrimaryType valTp);
-
 /// SparseTensorReader method to obtain direct access to the
 /// dimension-sizes array.
 MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
@@ -197,24 +187,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
 /// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
 MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
 
-/// Helper function to read the header of a file and return the
-/// shape/sizes, without parsing the elements of the file.
-MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
-std::vector<uint64_t> *out);
-
-/// Returns the rank of the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);
-
-/// Returns the is_symmetric bit for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT bool getSparseTensorReaderIsSymmetric(void *p);
-
-/// Returns the number of stored elements for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNSE(void *p);
-
-/// Returns the size of a dimension for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
-index_type d);
-
 /// Releases the SparseTensorReader and closes the associated file.
 MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
 
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -199,12 +199,15 @@ class NewCallParams final {
 /// type-level information such as the encoding and sizes), generating
 /// MLIR buffers as needed, and returning `this` for method chaining.
 NewCallParams &genBuffers(SparseTensorType stt,
-ArrayRef<Value> dimSizesValues) {
+ArrayRef<Value> dimSizesValues,
+Value dimSizesBuffer = Value()) {
 assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
 // Sparsity annotations.
 params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
 // Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
-params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
+params[kParamDimSizes] = dimSizesBuffer
+? dimSizesBuffer
+: allocaBuffer(builder, loc, dimSizesValues);
 params[kParamLvlSizes] =
 genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
 params[kParamDim2Lvl], params[kParamLvl2Dim]);
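In short, `genBuffers` gains an optional, already-materialized dimension-sizes buffer. A minimal sketch of the two call forms this enables, assuming a `NewCallParams` instance named `params` (the `Value` names follow this hunk and the `NewOp` lowering below):

```cpp
// Form 1 (existing callers, unchanged): only SSA size values are at hand,
// so genBuffers materializes the dimension-sizes buffer with an alloca.
params.genBuffers(stt, dimSizesValues);

// Form 2 (the file-reader path): a buffer already filled by the runtime is
// passed through as-is, and the alloca is skipped.
params.genBuffers(stt, dimShapesValues, dimSizesBuffer);
```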
@@ -342,33 +345,15 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
 const auto stt = getSparseTensorType(op);
 if (!stt.hasEncoding())
 return failure();
-// Construct the reader opening method calls.
+// Construct the `reader` opening method calls.
 SmallVector<Value> dimShapesValues;
 Value dimSizesBuffer;
 Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
 dimShapesValues, dimSizesBuffer);
-// Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
-Value dim2lvlBuffer;
-Value lvl2dimBuffer;
-Value lvlSizesBuffer =
-genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
-dim2lvlBuffer, lvl2dimBuffer);
 // Use the `reader` to parse the file.
-Type opaqueTp = getOpaquePointerType(rewriter);
-Type eltTp = stt.getElementType();
-Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
-SmallVector<Value, 8> params{
-reader,
-lvlSizesBuffer,
-genLvlTypesBuffer(rewriter, loc, stt),
-dim2lvlBuffer,
-lvl2dimBuffer,
-constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
-constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
-valTp};
-Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
-opaqueTp, params, EmitCInterface::On)
-.getResult(0);
+Value tensor = NewCallParams(rewriter, loc)
+.genBuffers(stt, dimShapesValues, dimSizesBuffer)
+.genNewCall(Action::kFromReader, reader);
 // Free the memory for `reader`.
 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
 EmitCInterface::Off);
137 changes: 6 additions & 131 deletions mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -138,6 +138,12 @@ extern "C" {
 dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
 dimRank, tensor); \
 } \
+case Action::kFromReader: { \
+assert(ptr && "Received nullptr for SparseTensorReader object"); \
+SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
+return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
+lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
+} \
 case Action::kToCOO: { \
 assert(ptr && "Received nullptr for SparseTensorStorage object"); \
 auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
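This single macro case is what makes the hand-rolled dispatch below deletable: the generic `newSparseTensor` entry point already dispatches over every supported (`posTp`, `crdTp`, `valTp`) combination for the other actions (the trailing backslashes show the new case sits inside that dispatch macro), so the reader path inherits the full type matrix for free. For comparison, here is one instantiation of what the deleted `_mlir_ciface_newSparseTensorFromReader` spelled out by hand, as its `CASE` macro expands for `kU32`/`kU32`/`kF64`:

```cpp
// One of the ~60 hand-enumerated combinations removed by this patch.
if (posTp == OverheadType::kU32 && crdTp == OverheadType::kU32 &&
    valTp == PrimaryType::kF64)
  return static_cast<void *>(
      reader.readSparseTensor<uint32_t, uint32_t, double>(
          lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
```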
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
 MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
 #undef IMPL_GETNEXT
 
-void *_mlir_ciface_newSparseTensorFromReader(
-void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
-StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
-StridedMemRefType<index_type, 1> *dim2lvlRef,
-StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
-OverheadType crdTp, PrimaryType valTp) {
-assert(p);
-SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
-ASSERT_NO_STRIDE(lvlSizesRef);
-ASSERT_NO_STRIDE(lvlTypesRef);
-ASSERT_NO_STRIDE(dim2lvlRef);
-ASSERT_NO_STRIDE(lvl2dimRef);
-const uint64_t dimRank = reader.getRank();
-const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
-ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
-ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
-ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
-(void)dimRank;
-const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
-const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
-const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
-const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
-#define CASE(p, c, v, P, C, V) \
-if (posTp == OverheadType::p && crdTp == OverheadType::c && \
-valTp == PrimaryType::v) \
-return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
-lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
-#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
-// Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
-// This is safe because of the static_assert above.
-if (posTp == OverheadType::kIndex)
-posTp = OverheadType::kU64;
-if (crdTp == OverheadType::kIndex)
-crdTp = OverheadType::kU64;
-// Double matrices with all combinations of overhead storage.
-CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
-CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
-CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
-CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
-CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
-CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
-CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
-CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
-CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
-CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
-CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
-CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
-CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
-CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
-CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
-CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
-// Float matrices with all combinations of overhead storage.
-CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
-CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
-CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
-CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
-CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
-CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
-CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
-CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
-CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
-CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
-CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
-CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
-CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
-CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
-CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
-CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
-// Two-byte floats with both overheads of the same type.
-CASE_SECSAME(kU64, kF16, uint64_t, f16);
-CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
-CASE_SECSAME(kU32, kF16, uint32_t, f16);
-CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
-CASE_SECSAME(kU16, kF16, uint16_t, f16);
-CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
-CASE_SECSAME(kU8, kF16, uint8_t, f16);
-CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
-// Integral matrices with both overheads of the same type.
-CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
-CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
-CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
-CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
-CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
-CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
-CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
-CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
-CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
-CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
-CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
-CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
-CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
-CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
-CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
-CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
-// Complex matrices with wide overhead.
-CASE_SECSAME(kU64, kC64, uint64_t, complex64);
-CASE_SECSAME(kU64, kC32, uint64_t, complex32);
-
-// Unsupported case (add above if needed).
-MLIR_SPARSETENSOR_FATAL(
-"unsupported combination of types: <P=%d, C=%d, V=%d>\n",
-static_cast<int>(posTp), static_cast<int>(crdTp),
-static_cast<int>(valTp));
-#undef CASE_SECSAME
-#undef CASE
-}
-
 void _mlir_ciface_outSparseTensorWriterMetaData(
 void *p, index_type dimRank, index_type nse,
 StridedMemRefType<index_type, 1> *dimSizesRef) {
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
 return env;
 }
 
-void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
-assert(out && "Received nullptr for out-parameter");
-SparseTensorReader reader(filename);
-reader.openFile();
-reader.readHeader();
-reader.closeFile();
-const uint64_t dimRank = reader.getRank();
-const uint64_t *dimSizes = reader.getDimSizes();
-out->reserve(dimRank);
-out->assign(dimSizes, dimSizes + dimRank);
-}
-
-index_type getSparseTensorReaderRank(void *p) {
-return static_cast<SparseTensorReader *>(p)->getRank();
-}
-
-bool getSparseTensorReaderIsSymmetric(void *p) {
-return static_cast<SparseTensorReader *>(p)->isSymmetric();
-}
-
-index_type getSparseTensorReaderNSE(void *p) {
-return static_cast<SparseTensorReader *>(p)->getNSE();
-}
-
-index_type getSparseTensorReaderDimSize(void *p, index_type d) {
-return static_cast<SparseTensorReader *>(p)->getDimSize(d);
-}
-
 void delSparseTensorReader(void *p) {
 delete static_cast<SparseTensorReader *>(p);
 }
30 changes: 15 additions & 15 deletions mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,11 +78,11 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 // CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
 // CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
 // CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
 // CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
 // CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
 // CHECK: call @delSparseTensorReader(%[[Reader]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
@@ -96,11 +96,11 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
 // CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
 // CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 // CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 // CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
 // CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
 // CHECK: call @delSparseTensorReader(%[[Reader]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
 // CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 // CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 // CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
 // CHECK: call @delSparseTensorReader(%[[Reader]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {