From d0d80ed841a43d02b04e2e07d1c21fe1fe4feef3 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 29 Apr 2024 15:33:03 -0700 Subject: [PATCH 1/3] [mlir][sparse] fold explicit value during sparsification This ensures the explicit value is generated (and not a load into the values array). Note that actually not storing values array at all is still TBD, this is just the very first step. --- .../Transforms/Sparsification.cpp | 12 ++- .../Transforms/Utils/CodegenUtils.h | 10 +++ .../SparseTensor/sparse_matmul_one.mlir | 75 +++++++++++++++++++ 3 files changed, 95 insertions(+), 2 deletions(-) create mode 100755 mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 0a9bb40b458d6..b04ca11f714ba 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -498,9 +498,17 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { Value val = env.exp(exp).val; if (val) return val; - // Load during insertion. + // Get tensor operand. linalg::GenericOp op = env.op(); + Location loc = op.getLoc(); OpOperand *t = &op->getOpOperand(env.exp(exp).tensor); + // Fold binary-valued tensor into explicit value. + const auto stt = getSparseTensorType(t->get()); + if (stt.hasEncoding()) { + if (auto explVal = stt.getExplicitVal()) + return genValFromAttr(builder, loc, explVal); + } + // Load during insertion. if (env.isSparseOutput(t)) { if (env.isCustomReduc()) return genInsertionLoadReduce(env, builder, t); @@ -509,7 +517,7 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { // Actual load. SmallVector args; Value ptr = genSubscript(env, builder, t, args); - return builder.create(op.getLoc(), ptr, args); + return builder.create(loc, ptr, args); } /// Generates a store on a dense or sparse tensor. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h index ce5831d999e9a..cf3c35f5fa4c7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h @@ -399,6 +399,16 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc, return constantI64(builder, loc, static_cast(lt)); } +// Generates a constant from a validated value carrying attribute. +inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) { + if (auto arrayAttr = dyn_cast(attr)) { + Type tp = cast(arrayAttr[0]).getType(); + return builder.create(loc, tp, arrayAttr); + } + return builder.create(loc, cast(attr)); +} + +// TODO: is this at the right place? inline bool isZeroRankedTensorOrScalar(Type type) { auto rtp = dyn_cast(type); return !rtp || rtp.getRank() == 0; diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir new file mode 100755 index 0000000000000..09ec43b393d52 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt %s --linalg-generalize-named-ops \ +// RUN: --sparsification-and-bufferization | FileCheck %s + +#CSR_ones_complex = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : dense, d1 : compressed) +// explicitVal = (1.0, 0.0) : complex, +// implicitVal = (1.0, 0.0) : complex +}> + +#CSR_ones_fp = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : dense, d1 : compressed), + explicitVal = 1.0 : f32, + implicitVal = 0.0 : f32 +}> + +#CSR_ones_int = #sparse_tensor.encoding<{ + map = (d0, d1) -> (d0 : dense, d1 : compressed), + explicitVal = 1 : i32, + implicitVal = 0 : i32 +}> + +// CHECK-LABEL: func.func @matmul_complex +// +// TODO: make this work +// +func.func @matmul_complex(%a: tensor<10x20xcomplex>, + %b: tensor<20x30xcomplex, #CSR_ones_complex>, + %c: tensor<10x30xcomplex>) -> tensor<10x30xcomplex> { + %0 = linalg.matmul + ins(%a, %b: tensor<10x20xcomplex>, tensor<20x30xcomplex,#CSR_ones_complex>) + outs(%c: tensor<10x30xcomplex>) -> tensor<10x30xcomplex> + return %0 : tensor<10x30xcomplex> +} + +// CHECK-LABEL: func.func @matmul_fp +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[X:.*]] = memref.load +// CHECK: scf.for +// CHECK: %[[I:.*]] = memref.load +// CHECK: %[[Y:.*]] = memref.load +// CHECK: %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32 +// CHECK: memref.store %[[M]] +// CHECK: } +// CHECK: } +// CHECK: } +func.func @matmul_fp(%a: tensor<10x20xf32>, + %b: tensor<20x30xf32, #CSR_ones_fp>, + %c: tensor<10x30xf32>) -> tensor<10x30xf32> { + %0 = linalg.matmul + ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>) + outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32> + return %0 : tensor<10x30xf32> +} + +// CHECK-LABEL: func.func @matmul_int +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[X:.*]] = memref.load +// CHECK: scf.for +// CHECK: %[[I:.*]] = memref.load +// CHECK: %[[Y:.*]] = memref.load +// CHECK: %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32 +// CHECK: memref.store %[[M]] +// CHECK: } +// CHECK: } +// CHECK: } +func.func @matmul_int(%a: tensor<10x20xi32>, + %b: tensor<20x30xi32, #CSR_ones_int>, + %c: tensor<10x30xi32>) -> tensor<10x30xi32> { + %0 = linalg.matmul + ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>) + outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32> + return %0 : tensor<10x30xi32> +} From 63d3c90546d7c1253eb9a3a5a85e0f7bb745e452 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 29 Apr 2024 15:59:17 -0700 Subject: [PATCH 2/3] reviewer feedback --- mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp | 6 ++---- mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index b04ca11f714ba..0c8e431d8c996 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -504,10 +504,8 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) { OpOperand *t = &op->getOpOperand(env.exp(exp).tensor); // Fold binary-valued tensor into explicit value. const auto stt = getSparseTensorType(t->get()); - if (stt.hasEncoding()) { - if (auto explVal = stt.getExplicitVal()) - return genValFromAttr(builder, loc, explVal); - } + if (auto explVal = stt.getExplicitVal()) + return genValFromAttr(builder, loc, explVal); // Load during insertion. if (env.isSparseOutput(t)) { if (env.isCustomReduc()) diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir index 09ec43b393d52..82f3147d3206b 100755 --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir @@ -4,7 +4,7 @@ #CSR_ones_complex = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) // explicitVal = (1.0, 0.0) : complex, -// implicitVal = (1.0, 0.0) : complex +// implicitVal = (0.0, 0.0) : complex }> #CSR_ones_fp = #sparse_tensor.encoding<{ From da2b9b734fb9ac6660b32452cbfb42bc09b080a1 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 29 Apr 2024 16:11:13 -0700 Subject: [PATCH 3/3] enc guard --- .../mlir/Dialect/SparseTensor/IR/SparseTensorType.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h index 34d99913fbd51..ea3d8013b4567 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h @@ -344,10 +344,14 @@ class SparseTensorType { unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; } /// Returns the explicit value, defaulting to null Attribute for unset. - Attribute getExplicitVal() const { return enc.getExplicitVal(); } + Attribute getExplicitVal() const { + return enc ? enc.getExplicitVal() : nullptr; + } /// Returns the implicit value, defaulting to null Attribute for 0. - Attribute getImplicitVal() const { return enc.getImplicitVal(); } + Attribute getImplicitVal() const { + return enc ? enc.getImplicitVal() : nullptr; + } /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`. Type getCrdType() const { return enc.getCrdElemType(); }