Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape No.232】【BUAA】Add svd, changed 3 files #67664

Merged
merged 64 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
163571f
add 3 files
Whsjrczr Aug 8, 2024
75facec
update .cc
Whsjrczr Aug 9, 2024
5972d24
undo change in .cc file
Whsjrczr Aug 9, 2024
545b90b
update num_samples in .cc
Whsjrczr Aug 9, 2024
0b44b7e
attribute
Whsjrczr Aug 9, 2024
9e0d933
update
Whsjrczr Aug 9, 2024
bf533d3
update
Whsjrczr Aug 9, 2024
2880140
update
Whsjrczr Aug 9, 2024
e2a5d42
update
Whsjrczr Aug 9, 2024
5dc6a76
update
Whsjrczr Aug 9, 2024
b21f08f
update
Whsjrczr Aug 9, 2024
89ff139
update
Whsjrczr Aug 9, 2024
ecc7735
updata
Whsjrczr Aug 9, 2024
b56968e
updata
Whsjrczr Aug 9, 2024
b7a8875
update
Whsjrczr Aug 9, 2024
30c9e3f
update
Whsjrczr Aug 9, 2024
d6b1104
update
Whsjrczr Aug 9, 2024
652573e
update
Whsjrczr Aug 9, 2024
475014d
update
Whsjrczr Aug 9, 2024
97caa5e
update
Whsjrczr Aug 9, 2024
7945ac9
u
Whsjrczr Aug 9, 2024
620525c
batch_function
Whsjrczr Aug 12, 2024
16be43e
bincount
Whsjrczr Aug 12, 2024
47c5d45
update batchfc
Whsjrczr Aug 12, 2024
c4142fe
update batchfc
Whsjrczr Aug 12, 2024
6379805
update EQ
Whsjrczr Aug 12, 2024
e90891f
update {-1}
Whsjrczr Aug 12, 2024
8b3ad1d
update binary with output_size
Whsjrczr Aug 12, 2024
949b95c
undo change
Whsjrczr Aug 12, 2024
f2c974a
{-1}
Whsjrczr Aug 12, 2024
3f1c49f
add Bincount
Whsjrczr Aug 12, 2024
7a5a3ed
update class_center_sample
Whsjrczr Aug 12, 2024
eef3cd6
delete batch_fc
Whsjrczr Aug 12, 2024
28a64d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 13, 2024
5d447ea
update batchnormop
Whsjrczr Aug 13, 2024
7c24bfe
update batchnormop
Whsjrczr Aug 13, 2024
877670c
changed bn
Whsjrczr Aug 13, 2024
4b2cc16
changed bn
Whsjrczr Aug 13, 2024
9685c7d
unduo _
Whsjrczr Aug 14, 2024
a8f1277
Update DimExpr
Whsjrczr Aug 15, 2024
2774f51
Update unary_infer_sym.cc
Whsjrczr Aug 15, 2024
52ef749
out_unknown
Whsjrczr Aug 16, 2024
e53bc52
Expr -> Exprs
Whsjrczr Aug 16, 2024
b09ad2c
{{out_dims}} -> out_dims
Whsjrczr Aug 16, 2024
a010e0f
delete ()
Whsjrczr Aug 16, 2024
88bb994
Merge branch 'api1' of https://github.com/Whsjrczr/Paddle into develop
Whsjrczr Aug 19, 2024
eacc7f9
Merge branch 'api2' of https://github.com/Whsjrczr/Paddle into develop
Whsjrczr Aug 19, 2024
95bb57b
update dimexpr
Whsjrczr Aug 19, 2024
f58bd89
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 20, 2024
de5ed23
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 21, 2024
ad1f042
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 22, 2024
1fa61f6
restore changes
Whsjrczr Aug 22, 2024
f7a4020
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 22, 2024
43a0978
undo
Whsjrczr Aug 22, 2024
16a749d
undo some changes
Whsjrczr Aug 22, 2024
40cd98a
add svd
Whsjrczr Aug 22, 2024
4f8b6ff
add get
Whsjrczr Sep 1, 2024
7449c46
.
Whsjrczr Sep 1, 2024
bd94761
add builder
Whsjrczr Sep 4, 2024
5250d3d
m -> M
Whsjrczr Sep 4, 2024
90120da
removed .get
Whsjrczr Sep 9, 2024
d973ae1
rerun
Whsjrczr Sep 10, 2024
d858afb
rerun
Whsjrczr Sep 11, 2024
799e1b1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Sep 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3212,12 +3212,65 @@ bool SumOpInferSymbolicShape(pir::Operation *op,
return true;
}

// bool SvdOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool SvdOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
bool full_matrices =
op->attribute<pir::BoolAttribute>("full_matrices").data();

const int x_rank = x_shape.size();
PADDLE_ENFORCE_GE(
x_rank,
2,
common::errors::InvalidArgument(
"the rank of input must be greater than or equal to 2"));

const symbol::DimExpr m = x_shape[x_rank - 2];
const symbol::DimExpr n = x_shape[x_rank - 1];
symbol::DimExprBuilder builder;
const symbol::DimExpr k = builder.Min(m, n);

auto UDDim = [&](const std::vector<symbol::DimExpr> &x_shape,
const symbol::DimExpr &k) {
std::vector<symbol::DimExpr> x_vec = x_shape;
x_vec[x_vec.size() - 1] = k;
return x_vec;
};

auto VHDDim = [&](const std::vector<symbol::DimExpr> &x_shape,
const symbol::DimExpr &k) {
std::vector<symbol::DimExpr> x_vec = x_shape;
x_vec[x_vec.size() - 2] = k;
return x_vec;
};

auto SDDim = [&](const std::vector<symbol::DimExpr> &x_shape,
const symbol::DimExpr &k) {
std::vector<symbol::DimExpr> x_vec = x_shape;
x_vec[x_vec.size() - 2] = k;
x_vec.erase(x_vec.end() - 1); // rank - 1
return x_vec;
};

std::vector<symbol::DimExpr> u_shape =
!full_matrices ? UDDim(x_shape, k) : UDDim(x_shape, m);
std::vector<symbol::DimExpr> vh_shape =
!full_matrices ? VHDDim(x_shape, k) : VHDDim(x_shape, n);
std::vector<symbol::DimExpr> s_shape = SDDim(x_shape, k);

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(u_shape)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(s_shape)});
infer_context->SetShapeOrDataForValue(
op->result(2),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(vh_shape)});
return true;
}

bool SetValueOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(StridedSlice)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sum)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Svd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValue_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SetValueWithTensor)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4625,7 +4625,7 @@
kernel :
func : svd
backward : svd_grad
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : swiglu
args : (Tensor x, Tensor y)
Expand Down