Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape No.230】【BUAA】Add stft, changed 3 files #67663

Merged
merged 69 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
163571f
add 3 files
Whsjrczr Aug 8, 2024
75facec
update .cc
Whsjrczr Aug 9, 2024
5972d24
undo change in .cc file
Whsjrczr Aug 9, 2024
545b90b
update num_samples in .cc
Whsjrczr Aug 9, 2024
0b44b7e
attribute
Whsjrczr Aug 9, 2024
9e0d933
update
Whsjrczr Aug 9, 2024
bf533d3
update
Whsjrczr Aug 9, 2024
2880140
update
Whsjrczr Aug 9, 2024
e2a5d42
update
Whsjrczr Aug 9, 2024
5dc6a76
update
Whsjrczr Aug 9, 2024
b21f08f
update
Whsjrczr Aug 9, 2024
89ff139
update
Whsjrczr Aug 9, 2024
ecc7735
updata
Whsjrczr Aug 9, 2024
b56968e
updata
Whsjrczr Aug 9, 2024
b7a8875
update
Whsjrczr Aug 9, 2024
30c9e3f
update
Whsjrczr Aug 9, 2024
d6b1104
update
Whsjrczr Aug 9, 2024
652573e
update
Whsjrczr Aug 9, 2024
475014d
update
Whsjrczr Aug 9, 2024
97caa5e
update
Whsjrczr Aug 9, 2024
7945ac9
u
Whsjrczr Aug 9, 2024
620525c
batch_function
Whsjrczr Aug 12, 2024
16be43e
bincount
Whsjrczr Aug 12, 2024
47c5d45
update batchfc
Whsjrczr Aug 12, 2024
c4142fe
update batchfc
Whsjrczr Aug 12, 2024
6379805
update EQ
Whsjrczr Aug 12, 2024
e90891f
update {-1}
Whsjrczr Aug 12, 2024
8b3ad1d
update binary with output_size
Whsjrczr Aug 12, 2024
949b95c
undo change
Whsjrczr Aug 12, 2024
f2c974a
{-1}
Whsjrczr Aug 12, 2024
3f1c49f
add Bincount
Whsjrczr Aug 12, 2024
7a5a3ed
update class_center_sample
Whsjrczr Aug 12, 2024
eef3cd6
delete batch_fc
Whsjrczr Aug 12, 2024
28a64d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 13, 2024
5d447ea
update batchnormop
Whsjrczr Aug 13, 2024
7c24bfe
update batchnormop
Whsjrczr Aug 13, 2024
877670c
changed bn
Whsjrczr Aug 13, 2024
4b2cc16
changed bn
Whsjrczr Aug 13, 2024
9685c7d
unduo _
Whsjrczr Aug 14, 2024
a8f1277
Update DimExpr
Whsjrczr Aug 15, 2024
2774f51
Update unary_infer_sym.cc
Whsjrczr Aug 15, 2024
52ef749
out_unknown
Whsjrczr Aug 16, 2024
e53bc52
Expr -> Exprs
Whsjrczr Aug 16, 2024
b09ad2c
{{out_dims}} -> out_dims
Whsjrczr Aug 16, 2024
a010e0f
delete ()
Whsjrczr Aug 16, 2024
88bb994
Merge branch 'api1' of https://github.com/Whsjrczr/Paddle into develop
Whsjrczr Aug 19, 2024
eacc7f9
Merge branch 'api2' of https://github.com/Whsjrczr/Paddle into develop
Whsjrczr Aug 19, 2024
95bb57b
update dimexpr
Whsjrczr Aug 19, 2024
f58bd89
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 20, 2024
de5ed23
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 21, 2024
ad1f042
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 22, 2024
1fa61f6
restore changes
Whsjrczr Aug 22, 2024
f7a4020
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Aug 22, 2024
43a0978
undo
Whsjrczr Aug 22, 2024
7b56613
add stft
Whsjrczr Aug 22, 2024
78b5545
update annotations
Whsjrczr Aug 22, 2024
200b27d
undo some changes
Whsjrczr Aug 22, 2024
1a195a9
add constrain
Whsjrczr Aug 26, 2024
ab103c6
Update type
Whsjrczr Aug 31, 2024
4e0d9f0
Update binary_infer_sym.cc
Whsjrczr Aug 31, 2024
05c741c
Update binary_infer_sym.cc
Whsjrczr Aug 31, 2024
77b1194
add if
Whsjrczr Sep 2, 2024
5d97a26
change place of constrain
Whsjrczr Sep 2, 2024
fdd9b3f
add defination of seq_length
Whsjrczr Sep 3, 2024
c7afa9c
LE into if
Whsjrczr Sep 3, 2024
4f44d79
change get
Whsjrczr Sep 4, 2024
c8eba0a
add symbol::DimExpr onto 1
Whsjrczr Sep 4, 2024
c116940
rerun
Whsjrczr Sep 11, 2024
05c2dfe
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Whsjrczr Sep 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1546,11 +1546,68 @@ bool SequenceMaskOpInferSymbolicShape(
// return true;
// }

// bool StftOpInferSymbolicShape(
// pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// // pass
// return true;
// }
bool StftOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &window_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
const std::vector<symbol::DimExpr> &window_shape =
window_shape_or_data.shape();

int n_fft = op->attribute<pir::Int32Attribute>("n_fft").data();
int hop_length = op->attribute<pir::Int32Attribute>("hop_length").data();
bool onesided = op->attribute<pir::BoolAttribute>("onesided").data();

const int x_rank = x_shape.size();

PADDLE_ENFORCE_EQ(
x_rank,
2,
common::errors::InvalidArgument(
"Input(X) of StftOp should be a tensor with shape [N, T], "
"but got rank %s.",
x_rank));

PADDLE_ENFORCE_GT(
hop_length,
0,
common::errors::InvalidArgument(
"Attribute(hop_length) should be greater than 0, but got %s.",
hop_length));

infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft});
const symbol::DimExpr seq_length = x_shape[x_rank - 1];
const symbol::DimExpr n_frames =
(symbol::DimExpr{1}) +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里多了个括号

(seq_length - symbol::DimExpr{n_fft}) / symbol::DimExpr{hop_length};

if (seq_length.isa<int64_t>()) {
PADDLE_ENFORCE_LE(n_fft,
seq_length.Get<std::int64_t>(),
common::errors::InvalidArgument(
"Attribute(frame_length) should be less equal than "
"sequence length, but got (%s) > (%s).",
n_fft,
seq_length.Get<std::int64_t>()));
}

std::vector<symbol::DimExpr> output_shape;
output_shape.push_back(x_shape[0]);
if (onesided) {
output_shape.push_back(symbol::DimExpr{n_fft / 2 + 1});
} else {
output_shape.push_back(symbol::DimExpr{n_fft});
}
output_shape.push_back(n_frames);

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_shape)});
return true;
}

bool SwigluOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(SegmentPool)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SequenceMask)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleBatch)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Solve)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stft)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stft)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Swiglu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopPSampling)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4592,7 +4592,7 @@
func: stft
data_type: x
backward: stft_grad
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : strided_slice
args : (Tensor x, int[] axes, IntArray starts, IntArray ends, IntArray strides)
Expand Down