Skip to content

Commit

Permalink
[mlir][sparse] fix crash when generate rotated convolution kernels. (l…
Browse files Browse the repository at this point in the history
  • Loading branch information
PeimingLiu authored Dec 1, 2023
1 parent 3d89f2a commit 8206b75
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1329,11 +1329,11 @@ void LoopEmitter::enterTensorsAtDenseLvls(
// Update the slice information as we enter the new loop.
info.minCrd = info.offset = MULI(iv, C_IDX(stride));
info.isNonEmpty = constantI1(builder, loc, true);
levelReducedDep[tid][lvl]++;
} else {
posits[tid][lvl] =
genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
}
levelReducedDep[tid][lvl]++;
} else {
// Skips the synthetic tensor
if (isSynTensor(tid))
Expand Down Expand Up @@ -1369,11 +1369,11 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
// moves forward to the next slice.
invalidateSliceIterIdx(rewriter, loc, tid, lvl);
info.minCrd = info.offset = info.isNonEmpty = Value();
levelReducedDep[tid][lvl]--;
} else {
forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
constantIndex(rewriter, loc, 1));
}
levelReducedDep[tid][lvl]--;
}
if (auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop)) {
if (!reduc.empty()) {
Expand Down Expand Up @@ -1460,8 +1460,8 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
// level (but not resolved). Since we forward an iterator at higher level of
// the tree, the subtree need to be pruned.
Level leafLvl = rootLvl + 1;
while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty()) {
assert(depFullyReduced(tid, leafLvl));
while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() &&
depFullyReduced(tid, leafLvl)) {
leafLvl++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
map = (d0, d1) -> (d1 : dense, d0 : compressed)
}>

// Affine maps for the rotated convolution kernel below: the iteration space
// is (d0, d1, d2, d3), the output is indexed at (d0, d3), the filter at
// (d1, d2), and the input is therefore addressed at (d0 + d1, d3 + d2).
#map = affine_map<(d0, d1, d2, d3) -> (d0 + d1, d3 + d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)>

// An example of a 2D convolution with a sparse filter.
module {

Expand All @@ -50,6 +54,21 @@ module {
return %0 : tensor<6x6xi32>
}

// 2D convolution with a "rotated" iteration space: iterator types are
// permuted to (parallel, reduction, reduction, parallel), the CSR sparse
// input is addressed through #map at (d0 + d1, d3 + d2), and the result is
// accumulated into a dense 6x6 output tensor.
func.func @conv2d_CSR_dense_rotated(%arg0: tensor<8x8xi32, #CSR>,
                                    %arg1: tensor<3x3xi32>) -> tensor<6x6xi32> {
  // Dense output buffer for the accumulation.
  %buf = tensor.empty() : tensor<6x6xi32>
  %conv = linalg.generic {indexing_maps = [#map, #map1, #map2],
                          iterator_types = ["parallel", "reduction", "reduction", "parallel"]}
    ins(%arg0, %arg1 : tensor<8x8xi32, #CSR>, tensor<3x3xi32>)
    outs(%buf : tensor<6x6xi32>) attrs = {sorted = true} {
  // Multiply-accumulate body: acc += x * w.
  ^bb0(%x: i32, %w: i32, %acc: i32):
    %prod = arith.muli %x, %w : i32
    %sum = arith.addi %acc, %prod : i32
    linalg.yield %sum : i32
  } -> tensor<6x6xi32>
  return %conv : tensor<6x6xi32>
}

func.func @conv2d_sparse_out(%input: tensor<8x8xi32>,
%filter: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
%s = tensor.empty() : tensor<6x6xi32, #DCSR>
Expand Down Expand Up @@ -146,7 +165,9 @@ module {
%5 = call @conv2d_all_sparse_CSC(%sparse_input_CSC, %filter)
: (tensor<8x8xi32, #CSC>,
tensor<3x3xi32>) -> tensor<6x6xi32, #CSC>

%6 = call @conv2d_CSR_dense_rotated(%sparse_input_CSR, %filter)
: (tensor<8x8xi32, #CSR>,
tensor<3x3xi32>) -> tensor<6x6xi32>

// Verify the output.
//
Expand Down Expand Up @@ -236,6 +257,20 @@ module {
: tensor<6x6xi32>, vector<6x6xi32>
vector.print %v5 : vector<6x6xi32>

//
// Should be the same as dense output
// CHECK: ( ( 0, 0, -1, -6, -1, 6 ),
// CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ),
// CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ),
// CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ),
// CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ),
// CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) )
//
%v6 = vector.transfer_read %6[%c0, %c0], %i0
: tensor<6x6xi32>, vector<6x6xi32>
vector.print %v6 : vector<6x6xi32>


// Release the resources.
bufferization.dealloc_tensor %sparse_input_DCSR : tensor<8x8xi32, #DCSR>
bufferization.dealloc_tensor %sparse_input_CSR : tensor<8x8xi32, #CSR>
Expand Down

0 comments on commit 8206b75

Please sign in to comment.