107 changes: 41 additions & 66 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
 // A trivial wrapper to help generate different operations for dense/sparse
 // tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
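With the unified `TensorLike` above, a dense destination flows through the same tensor-level ops as a sparse one. A minimal MLIR sketch of the dense path (shapes and value names invented for illustration, not taken from the patch):

```mlir
// Dense path after this change: allocate a tensor, zero-fill it, then update
// it by value. No memref is involved, so every step yields a new SSA value.
func.func @tensorlike_dense_sketch(%v: f64, %i: index, %j: index) -> tensor<2x4xf64> {
  %0 = bufferization.alloc_tensor() : tensor<2x4xf64>
  %z = arith.constant 0.0 : f64
  %1 = linalg.fill ins(%z : f64) outs(%0 : tensor<2x4xf64>) -> tensor<2x4xf64>
  %2 = tensor.insert %v into %1[%i, %j] : tensor<2x4xf64>
  return %2 : tensor<2x4xf64>
}
```

Because each update returns a fresh tensor value, the old `getSSA()`/`updateSSA()` bookkeeping, which had to special-case the non-SSA memref, disappears entirely; the single public `val` field is always a valid SSA value.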
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds a for op for each input tensor to append new values into the
       // output tensor.
       foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
           [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
               ValueRange reduc) {
             SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               // FIXME: `toStoredDim` is deprecated
               dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
             }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
             if (!dstTp.isAllDense()) {
               Value cond = genIsNonzero(builder, loc, v);
               auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                     /*else*/ true);
               builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
               // Exits the ifOp, update the sparse tensor SSA value.
               builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
             } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
             }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       offset = rewriter.create<arith::AddIOp>(
           loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
     }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
     Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
     rewriter.replaceOp(op, ret);
     return success();
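For illustration, this is roughly the IR shape the rewritten concatenate loop now builds for one input: the destination enters `sparse_tensor.foreach` as an init value, is threaded through both `scf.if` branches, and each iteration yields the updated tensor, forming one unbroken SSA chain. A sketch on an invented 1-D encoding and sizes, with the real rewriter's coordinate/offset remapping omitted:

```mlir
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// %dst is the foreach init value; %acc is its per-iteration incarnation.
// Only nonzero elements are inserted; zeros yield the chain unchanged.
func.func @concat_ssa_sketch(%src: tensor<4xf64, #SV>,
                             %dst: tensor<8xf64, #SV>) -> tensor<8xf64, #SV> {
  %ret = sparse_tensor.foreach in %src init(%dst)
       : tensor<4xf64, #SV>, tensor<8xf64, #SV> -> tensor<8xf64, #SV> do {
  ^bb0(%i: index, %v: f64, %acc: tensor<8xf64, #SV>):
    %z = arith.constant 0.0 : f64
    %nz = arith.cmpf une, %v, %z : f64
    %upd = scf.if %nz -> (tensor<8xf64, #SV>) {
      %ins = sparse_tensor.insert %v into %acc[%i] : tensor<8xf64, #SV>
      scf.yield %ins : tensor<8xf64, #SV>
    } else {
      scf.yield %acc : tensor<8xf64, #SV>
    }
    sparse_tensor.yield %upd : tensor<8xf64, #SV>
  }
  return %ret : tensor<8xf64, #SV>
}
```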
Expand Down Expand Up @@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
ValueRange vs;
TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

Value iterArg = dstBuf.getSSA();
auto foreachOp = rewriter.create<ForeachOp>(
loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
loc, src, dstBuf.val, foreachOrder,
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
ValueRange reduc) {
// Enters the loop, update the SSA value for insertion chain.
if (!reduc.empty())
dstBuf.updateSSA(reduc.front());

dstBuf.val = reduc.front();
const Dimension dimRank = dstStt.getDimRank();
const Level lvlRank = dstStt.getLvlRank();
SmallVector<Value> lcvs(lvlRank);
Expand All @@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
}

if (!skipZeroCheck) {
assert(!reduc.empty());
Value cond = genIsNonzero(builder, loc, v);
auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
/*else*/ true);
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
builder.create<scf::YieldOp>(loc, dstBuf.val);

builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
dstBuf.insertOrStore(builder, loc, v, lcvs);
builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
dstBuf.insert(builder, loc, v, lcvs);
builder.create<scf::YieldOp>(loc, dstBuf.val);

// Exits the ifOp, update the sparse tensor SSA value.
builder.setInsertionPointAfter(ifOp);
dstBuf.updateSSA(ifOp.getResult(0));
dstBuf.val = ifOp.getResult(0);
} else {
dstBuf.insertOrStore(builder, loc, v, lcvs);
dstBuf.insert(builder, loc, v, lcvs);
}
if (reduc.empty())
builder.create<sparse_tensor::YieldOp>(loc);
else
builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
});

rewriter.setInsertionPointAfter(foreachOp);

// Exits the for loop, links the SSA chain.
if (!foreachOp.getResults().empty())
dstBuf.updateSSA(foreachOp.getResult(0));
dstBuf.val = foreachOp.getResult(0);

Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
rewriter.replaceOp(op, ret);
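Put together, a sparse-to-dense convert now lowers to roughly the following form, which is what the updated CHECK lines in the test file below expect. This is a sketch, not the verbatim pass output: it assumes `#SparseVector` is a plain compressed 1-D encoding (its definition sits above the excerpted test lines) and invents all SSA names; the nonzero guard is elided here on the assumption that the rewriter can skip it when iterating only the stored entries of a sparse source over a zero-filled buffer:

```mlir
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// Expected lowering shape: alloc_tensor + linalg.fill for the zero-filled
// dense destination, then tensor.insert per stored element inside foreach.
func.func @sparse_convert_1d_sketch(%src: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
  %0 = bufferization.alloc_tensor() : tensor<13xi32>
  %z = arith.constant 0 : i32
  %1 = linalg.fill ins(%z : i32) outs(%0 : tensor<13xi32>) -> tensor<13xi32>
  %2 = sparse_tensor.foreach in %src init(%1)
     : tensor<13xi32, #SparseVector>, tensor<13xi32> -> tensor<13xi32> do {
  ^bb0(%i: index, %v: i32, %acc: tensor<13xi32>):
    %ins = tensor.insert %v into %acc[%i] : tensor<13xi32>
    sparse_tensor.yield %ins : tensor<13xi32>
  }
  return %2 : tensor<13xi32>
}
```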
35 changes: 14 additions & 21 deletions mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
@@ -14,83 +14,76 @@
 
 // CHECK-LABEL: func.func @sparse_convert_1d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_1d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn1
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn2
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_3d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>
 }