[mlir][sparse] remove most bufferization.alloc_tensor ops from sparse (#66847)

The only ones left need actual deprecation in the bufferization module.
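The change swaps the op used to create an uninitialized sparse output: wherever a test built one with bufferization.alloc_tensor, it now uses tensor.empty, and the sparse codegen/conversion passes lower the encoded tensor.empty directly. A minimal before/after sketch (the #CSR alias stands in for any sparse encoding; not itself part of the diff):

// Before: allocation of the sparse output went through the bufferization dialect.
%init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>

// After: tensor.empty is accepted whenever the result type carries a sparse encoding.
%init = tensor.empty() : tensor<8x8xf32, #CSR>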
aartbik authored Sep 20, 2023
1 parent a009fa7 commit 3e4a8c2
Showing 14 changed files with 77 additions and 50 deletions.
45 changes: 43 additions & 2 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -705,6 +705,7 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
};

/// Sparse codegen rule for the alloc operator.
+/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
: public OpConversionPattern<bufferization::AllocTensorOp> {
public:
@@ -764,6 +765,46 @@ class SparseTensorAllocConverter
bool enableBufferInitialization;
};

+/// Sparse codegen rule for the empty tensor operator.
+/// TODO(springerm): remove when bufferization.alloc_tensor is gone
+class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context,
+                             bool enableInit)
+      : OpConversionPattern(typeConverter, context),
+        enableBufferInitialization(enableInit) {}
+
+  LogicalResult
+  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto resType = getSparseTensorType(op);
+    if (!resType.hasEncoding())
+      return failure();
+
+    // Construct allocation for each field.
+    const Location loc = op.getLoc();
+    const Value sizeHint; // none
+    const ValueRange dynSizes = adaptor.getDynamicSizes();
+    const size_t found = dynSizes.size();
+    const int64_t expected = resType.getNumDynamicDims();
+    if (found != static_cast<size_t>(expected))
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv(
+                  "Got wrong number of dynamic sizes: Found={0}, Expected={1}",
+                  found, expected));
+    SmallVector<Value> fields;
+    createAllocFields(rewriter, loc, resType, dynSizes,
+                      enableBufferInitialization, fields, sizeHint);
+    // Replace operation with resulting memrefs.
+    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    return success();
+  }
+
+private:
+  bool enableBufferInitialization;
+};

/// Sparse codegen rule for the dealloc operator.
class SparseTensorDeallocConverter
: public OpConversionPattern<bufferization::DeallocTensorOp> {
@@ -1546,6 +1587,6 @@ void mlir::populateSparseTensorCodegenPatterns(
patterns.getContext());
patterns.add<SparseTensorDeallocConverter>(
typeConverter, patterns.getContext(), createSparseDeallocs);
-  patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
-                                           enableBufferInitialization);
+  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
+      typeConverter, patterns.getContext(), enableBufferInitialization);
}
3 changes: 1 addition & 2 deletions mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -9,7 +9,6 @@
// CHECK: %[[VAL_2:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 9 : index
-// CHECK: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<8x8xf32, #{{.*}}>
// CHECK: %[[VAL_6:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<8x8xf32, #{{.*}}>
// CHECK: %[[VAL_7:.*]] = sparse_tensor.number_of_entries %[[VAL_1]] : tensor<8x8xf32, #{{.*}}>
// CHECK: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf32, #{{.*}}>
@@ -92,7 +91,7 @@
// CHECK: }
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
%B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
-  %init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>
+  %init = tensor.empty() : tensor<8x8xf32, #CSR>
%C = linalg.matmul
ins(%A, %B: tensor<8x8xf32, #CSR>,
tensor<8x8xf32, #CSR>)
23 changes: 5 additions & 18 deletions mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -329,7 +329,7 @@ func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) {
// CHECK: %[[A26:.*]] = sparse_tensor.storage_specifier.set %[[A18]] pos_mem_sz at 1 with %[[A25]] : !sparse_tensor.storage_specifier
// CHECK: return %[[A23]], %[[A6]], %[[A8]], %[[A26]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
-  %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
+  %0 = tensor.empty(%arg0) : tensor<10x?xf64, #CSC>
%1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
return %1 : tensor<10x?xf64, #CSC>
}
@@ -351,24 +351,11 @@ func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
// CHECK: %[[A16:.*]] = sparse_tensor.storage_specifier.set %[[A10]] val_mem_sz with %[[A14]] : !sparse_tensor.storage_specifier
// CHECK: return %[[A15]], %[[A16]] : memref<?xf64>, !sparse_tensor.storage_specifier
func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
-  %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
+  %0 = tensor.empty() : tensor<10x20x30xf64, #Dense3D>
%1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
return %1 : tensor<10x20x30xf64, #Dense3D>
}

-// CHECK-LABEL: func.func @sparse_alloc_coo_with_size_hint(
-// CHECK-SAME: %[[HINT:.*]]: index)
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[M2:.*]] = arith.muli %[[HINT]], %c2 : index
-// CHECK: %[[A1:.*]] = memref.alloc() : memref<2xindex>
-// CHECK: %[[A2:.*]] = memref.alloc(%[[M2]]) : memref<?xindex>
-// CHECK: %[[A3:.*]] = memref.alloc(%[[HINT]]) : memref<?xf64>
-func.func @sparse_alloc_coo_with_size_hint(%arg0: index) -> tensor<10x20xf64, #Coo> {
-  %0 = bufferization.alloc_tensor() size_hint=%arg0 : tensor<10x20xf64, #Coo>
-  %1 = sparse_tensor.load %0 : tensor<10x20xf64, #Coo>
-  return %1 : tensor<10x20xf64, #Coo>
-}
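The size-hint variant deleted above is one use that cannot be expressed with tensor.empty, which accepts only dynamic sizes and has no size_hint clause; presumably this is why SparseTensorAllocConverter stays registered for now. A minimal sketch of the remaining form, adapted from the deleted lines (function name illustrative; the #Coo alias is assumed to be defined earlier in the file):

// Still requires bufferization.alloc_tensor until the op is actually deprecated.
func.func @coo_with_size_hint(%nnz: index) -> tensor<10x20xf64, #Coo> {
  %0 = bufferization.alloc_tensor() size_hint=%nnz : tensor<10x20xf64, #Coo>
  %1 = sparse_tensor.load %0 : tensor<10x20xf64, #Coo>
  return %1 : tensor<10x20xf64, #Coo>
}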

// CHECK-LABEL: func.func @sparse_expansion1()
// CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
// CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
@@ -378,7 +365,7 @@
// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
// CHECK: return %[[D]] : memref<?xindex>
func.func @sparse_expansion1() -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %0 = tensor.empty() : tensor<4x8xf64, #CSR>
%values, %filled, %added, %count = sparse_tensor.expand %0
: tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
@@ -393,7 +380,7 @@ func.func @sparse_expansion1() -> memref<?xindex> {
// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
// CHECK: return %[[D]] : memref<?xindex>
func.func @sparse_expansion2() -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %0 = tensor.empty() : tensor<4x8xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
: tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
@@ -409,7 +396,7 @@ func.func @sparse_expansion2() -> memref<?xindex> {
// CHECK: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
// CHECK: return %[[D]] : memref<?xindex>
func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
+  %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
: tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
@@ -2,7 +2,7 @@

#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

-// CHECK-LABEL: func.func @sparse_alloc_sparse_vector(
+// CHECK-LABEL: func.func @empty_sparse_vector(
// CHECK-SAME: %[[VAL_0:.*]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
@@ -24,8 +24,8 @@
// CHECK: %[[VAL_19:.*]], %[[VAL_21:.*]] = sparse_tensor.push_back %[[VAL_17]], %[[VAL_15]], %[[VAL_3]], %[[VAL_1]] : index, memref<?xindex>, index, index
// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] pos_mem_sz at 0 with %[[VAL_21]] : !sparse_tensor.storage_specifier
// CHECK: return %[[VAL_19]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
-func.func @sparse_alloc_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
-  %0 = bufferization.alloc_tensor(%arg0) : tensor<?xf64, #SV>
+func.func @empty_sparse_vector(%arg0: index) -> tensor<?xf64, #SV> {
+  %0 = tensor.empty(%arg0) : tensor<?xf64, #SV>
%1 = sparse_tensor.load %0 : tensor<?xf64, #SV>
return %1 : tensor<?xf64, #SV>
}
14 changes: 7 additions & 7 deletions mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -150,7 +150,7 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %[[Empty]], %[[NP]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #CSR> {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSR>
+  %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64, #CSR>
%1 = sparse_tensor.load %0 : tensor<?x?xf64, #CSR>
return %1 : tensor<?x?xf64, #CSR>
}
@@ -334,7 +334,7 @@ func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<8xi1>)
// CHECK: return %[[D]] : memref<?xindex>
func.func @sparse_expansion1() -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSR>
+  %0 = tensor.empty() : tensor<4x8xf64, #CSR>
%values, %filled, %added, %count = sparse_tensor.expand %0
: tensor<4x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
@@ -350,7 +350,7 @@ func.func @sparse_expansion1() -> memref<?xindex> {
// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<4xi1>)
// CHECK: return %[[D]] : memref<?xindex>
func.func @sparse_expansion2() -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor() : tensor<4x8xf64, #CSC>
+  %0 = tensor.empty() : tensor<4x8xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
: tensor<4x8xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
@@ -367,7 +367,7 @@ func.func @sparse_expansion2() -> memref<?xindex> {
// CHECK-DAG: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
// CHECK: return %[[C]] : memref<?xindex>
func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSC>
+  %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64, #CSC>
%values, %filled, %added, %count = sparse_tensor.expand %0
: tensor<?x?xf64, #CSC> to memref<?xf64>, memref<?xi1>, memref<?xindex>
return %added : memref<?xindex>
@@ -430,12 +430,12 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr

// CHECK-LABEL: func @sparse_and_dense_init(
// CHECK: %[[S:.*]] = call @newSparseTensor
-// CHECK: %[[D:.*]] = bufferization.alloc_tensor
+// CHECK: %[[D:.*]] = tensor.empty
// CHECK: return %[[S]], %[[D]] : !llvm.ptr<i8>, tensor<?x?xf64>
func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
-> (tensor<?x?xf64, #CSR>, tensor<?x?xf64>) {
-  %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #CSR>
+  %0 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64, #CSR>
   %1 = sparse_tensor.load %0 : tensor<?x?xf64, #CSR>
-  %2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
+  %2 = tensor.empty(%arg0, %arg1) : tensor<?x?xf64>
return %1, %2 : tensor<?x?xf64, #CSR>, tensor<?x?xf64>
}
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -63,7 +63,7 @@ func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK-DAG: %[[VAL_4:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-DAG: %[[VAL_4:.*]] = tensor.empty() : tensor<4x4xf64, #sparse_tensor.encoding<{{.*}}>>
// CHECK-NEXT: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
// CHECK-NEXT: ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
@@ -90,7 +90,7 @@ func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64
func.func @sparse_select(%cond: tensor<4x4xi1>,
%arga: tensor<4x4xf64, #DCSR>,
%argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
-  %xv = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+  %xv = tensor.empty() : tensor<4x4xf64, #DCSR>
%0 = linalg.generic #sel_trait
ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
outs(%xv: tensor<4x4xf64, #DCSR>) {
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1058,7 +1058,7 @@ func.func @cmp_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor() : tensor<2x3xf64, #{{.*}}>>
+// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty() : tensor<2x3xf64, #{{.*}}>>
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<2x3xf64, #{{.*}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<2x3xf64, #{{.*}}>> to memref<?xf64>
@@ -1142,7 +1142,7 @@ func.func @cmp_ss_ss(%arga: tensor<32x16xf32, #Tss>, %argb: tensor<32x16xf32, #T
// CHECK: }
func.func @sub_ss_batched(%0: tensor<2x3xf64, #BatchedVector>, %1: tensor<2x3xf64, #BatchedVector>)
-> tensor<2x3xf64, #BatchedVector> {
-  %2 = bufferization.alloc_tensor() : tensor<2x3xf64, #BatchedVector>
+  %2 = tensor.empty() : tensor<2x3xf64, #BatchedVector>
%3 = linalg.generic #trait2
ins(%0, %1 : tensor<2x3xf64, #BatchedVector>, tensor<2x3xf64, #BatchedVector>)
outs(%2 : tensor<2x3xf64, #BatchedVector>) {
@@ -19,7 +19,7 @@
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[VAL_11:.*]] = arith.constant false
-// CHECK-DAG: %[[VAL_12:.*]] = bufferization.alloc_tensor() : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-DAG: %[[VAL_12:.*]] = tensor.empty() : tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
@@ -261,7 +261,7 @@
// CHECK: }
func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
%arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
-  %0 = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR>
+  %0 = tensor.empty() : tensor<6x6xi32, #DCSR>
%1 = linalg.generic {
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -121,7 +121,7 @@
// CHECK: }
func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
%arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
-  %0 = bufferization.alloc_tensor() : tensor<100x300xf64, #DCSR>
+  %0 = tensor.empty() : tensor<100x300xf64, #DCSR>
%cst = arith.constant 0.000000e+00 : f64
%1 = linalg.fill ins(%cst : f64)
outs(%0 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -23,7 +23,7 @@
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse_tensor.encoding
@@ -52,7 +52,7 @@ func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
%c1 = arith.constant 0 : index
%0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
%1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
-  %init = bufferization.alloc_tensor(%0, %1) : tensor<?x?xi64, #DenseMatrix>
+  %init = tensor.empty(%0, %1) : tensor<?x?xi64, #DenseMatrix>
%r = linalg.generic #trait
ins(%arga: tensor<?x?xi64, #DenseMatrix>)
outs(%init: tensor<?x?xi64, #DenseMatrix>) {
@@ -75,7 +75,7 @@
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse_tensor.encoding
-// CHECK-DAG: %[[VAL_5:.*]] = bufferization.alloc_tensor(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
+// CHECK-DAG: %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse_tensor.encoding
// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xi64, #sparse_tensor.encoding
@@ -109,7 +109,7 @@ func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
%c1 = arith.constant 0 : index
%0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
%1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
-  %init = bufferization.alloc_tensor(%0, %1) : tensor<?x?xi64, #SparseMatrix>
+  %init = tensor.empty(%0, %1) : tensor<?x?xi64, #SparseMatrix>
%r = linalg.generic #trait
ins(%arga: tensor<?x?xi64, #SparseMatrix>)
outs(%init: tensor<?x?xi64, #SparseMatrix>) {
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -144,7 +144,7 @@
// CHECK: return %[[VAL_77]]#0, %[[VAL_77]]#1, %[[VAL_77]]#2, %[[VAL_77]]#3 : memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
func.func @matmul(%A: tensor<4x8xf64, #CSR>,
%B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
-  %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+  %C = tensor.empty() : tensor<4x4xf64, #CSR>
%D = linalg.matmul
ins(%A, %B: tensor<4x8xf64, #CSR>, tensor<8x4xf64, #CSR>)
outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>