Skip to content

Commit

Permalink
add svd
Browse files Browse the repository at this point in the history
  • Loading branch information
Whsjrczr committed Aug 22, 2024
1 parent 16a749d commit 40cd98a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2592,12 +2592,65 @@ bool SumOpInferSymbolicShape(pir::Operation *op,
return true;
}

// bool SvdOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool SvdOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shapes = x_shape_or_data.shape();

bool full_matrices =
op->attribute<pir::BoolAttribute>("full_matrices").data();

int x_rank = x_shapes.size();
PADDLE_ENFORCE_GE(
x_rank,
2,
common::errors::InvalidArgument(
"the rank of input must be greater than or equal to 2"));

symbol::DimExpr m = x_shapes[x_rank - 2];
symbol::DimExpr n = x_shapes[x_rank - 1];
symbol::DimExpr k = std::min(m, n);

auto UDDim = [&](const std::vector<symbol::DimExpr> &x_shapes,
const symbol::DimExpr &k) {
std::vector<symbol::DimExpr> x_vec = x_shapes;
x_vec[x_vec.size() - 1] = k;
return x_vec;
};

auto VHDDim = [&](const std::vector<symbol::DimExpr> &x_shapes,
const symbol::DimExpr &k) {
std::vector<symbol::DimExpr> x_vec = x_shapes;
x_vec[x_vec.size() - 2] = k;
return x_vec;
};

auto SDDim = [&](const std::vector<symbol::DimExpr> &x_shapes,
const symbol::DimExpr &k) {
std::vector<symbol::DimExpr> x_vec = x_shapes;
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return x_vec;
};

std::vector<symbol::DimExpr> u_shape =
!full_matrices ? UDDim(x_shapes, k) : UDDim(x_shapes, m);
std::vector<symbol::DimExpr> vh_shape =
!full_matrices ? VHDDim(x_shapes, k) : VHDDim(x_shapes, n);
std::vector<symbol::DimExpr> s_shape = SDDim(x_shapes, k);

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(u_shape)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(s_shape)});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(vh_shape)});
return true;
}

bool SetValueOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(StridedSlice)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sum)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValueWithTensor)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4582,7 +4582,7 @@
kernel :
func : svd
backward : svd_grad
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : swiglu
args : (Tensor x, Tensor y)
Expand Down

0 comments on commit 40cd98a

Please sign in to comment.