-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Infer Symbolic Shape No.230】【BUAA】Add stft, changed 3 files #67663
Conversation
|
||
infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft}); | ||
|
||
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里加个判断吧,是int类型再get
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>(); | ||
symbol::DimExpr n_frames = 1 + (seq_length - n_fft) / hop_length; | ||
|
||
PADDLE_ENFORCE_LE(n_fft, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
放在上面的if分支里
|
||
if (x_shape[x_rank - 1].isa<int64_t>()) { | ||
int seq_length = x_shape[x_rank - 1].Get<std::int64_t>(); | ||
int n_frames = 1 + (seq_length - n_fft) / hop_length; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seq_length 和 n_frames放在if外面,直接声明为DimExpr类型,if判断是否是int只服务于ENFORCE_LE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
} else { | ||
output_shape.push_back(symbol::DimExpr{n_fft}); | ||
} | ||
output_shape.push_back(symbol::DimExpr{n_frames}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面声明用的是DimExpr的话,这里就不用转换了
infer_context->AddEqualCstr(window_shape[0], symbol::DimExpr{n_fft}); | ||
const symbol::DimExpr seq_length = x_shape[x_rank - 1]; | ||
const symbol::DimExpr n_frames = | ||
(symbol::DimExpr{1}) + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里多了个括号
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
CINN
PR Types
Others
Description
加入stft,单测位于
test\legacy_test\test_stft_op.py
check_dygraph=False
,默认的check_pir=False, check_symbol_infer=True