From 40cd98a135de092d4c265a8dbe282796c3cd2a7a Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Thu, 22 Aug 2024 14:09:22 +0000 Subject: [PATCH] add svd --- .../infer_symbolic_shape/unary_infer_sym.cc | 65 +++++++++++++++++-- .../infer_symbolic_shape/unary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 2 +- 3 files changed, 61 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 9ef12303a5c0b..5a979739aaa15 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -2592,12 +2592,65 @@ bool SumOpInferSymbolicShape(pir::Operation *op, return true; } -// bool SvdOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool SvdOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shapes = x_shape_or_data.shape(); + + bool full_matrices = + op->attribute("full_matrices").data(); + + int x_rank = x_shapes.size(); + PADDLE_ENFORCE_GE( + x_rank, + 2, + common::errors::InvalidArgument( + "the rank of input must be greater than or equal to 2")); + + symbol::DimExpr m = x_shapes[x_rank - 2]; + symbol::DimExpr n = x_shapes[x_rank - 1]; + symbol::DimExpr k = std::min(m, n); + + auto UDDim = [&](const std::vector &x_shapes, + const symbol::DimExpr &k) { + std::vector x_vec = x_shapes; + x_vec[x_vec.size() - 1] = k; + return x_vec; + }; + + auto VHDDim = [&](const std::vector &x_shapes, + const symbol::DimExpr &k) { + std::vector x_vec = x_shapes; + x_vec[x_vec.size() - 2] = k; + return x_vec; + }; + + auto SDDim = [&](const std::vector &x_shapes, + const symbol::DimExpr &k) { + std::vector x_vec = x_shapes; + x_vec[x_vec.size() - 2] = k; + x_vec.erase(x_vec.end() - 1); // rank - 1 + return x_vec; + }; + + std::vector u_shape = + !full_matrices ? UDDim(x_shapes, k) : UDDim(x_shapes, m); + std::vector vh_shape = + !full_matrices ? VHDDim(x_shapes, k) : VHDDim(x_shapes, n); + std::vector s_shape = SDDim(x_shapes, k); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(u_shape)}); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(s_shape)}); + infer_context->SetShapeOrDataForValue( + op->result(2), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(vh_shape)}); + return true; +} bool SetValueOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 9600245696389..382cf18954f66 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -129,7 +129,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze_) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(StridedSlice) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sum) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValueWithTensor) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index c8fa663264140..9f048c6801bcd 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -4582,7 +4582,7 @@ kernel : func : svd backward : svd_grad - # interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : swiglu args : (Tensor x, Tensor y)