Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape No.127】[BUAA] CrfDecoding op #67201

Merged
merged 6 commits into from
Aug 13, 2024

Conversation

BHmingyang
Copy link
Contributor

PR Category

CINN

PR Types

Others

Description

添加了CrfDecoding算子,感觉中等难度的比较难,所以先提交了一个。
test/legacy_test/test_crf_decoding_op.py 已经包含 check_test_output

Copy link

paddle-bot bot commented Aug 8, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@luotao1 luotao1 added contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务 labels Aug 8, 2024
@BHmingyang BHmingyang closed this Aug 8, 2024
@BHmingyang BHmingyang reopened this Aug 8, 2024
@BHmingyang BHmingyang closed this Aug 8, 2024
@BHmingyang BHmingyang reopened this Aug 8, 2024
@BHmingyang BHmingyang closed this Aug 8, 2024
@BHmingyang BHmingyang reopened this Aug 8, 2024
@BHmingyang
Copy link
Contributor Author

截屏2024-08-09 10 21 06 test_crf_decoding_op.py 中self.check_output() 中 check_pir 默认关闭,申请豁免coverage流水线。

Comment on lines 682 to 688
const auto &length_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(3));
bool has_length = !length_shape_or_data.shape().empty();

const auto one = symbol::DimExpr{1};

if (has_length) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

判断has_length的逻辑需修改,可以使用是否为NullShapeOrDataDimExpr来判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 717 to 724
if (m.isa<int64_t>() && n.isa<int64_t>()) {
int m_value = static_cast<int>(m.Get<std::int64_t>());
int n_value = static_cast<int>(n.Get<std::int64_t>());
if (m_value > 0 && n_value > 0) {
infer_context->AddEqualCstr(emission_dims[emission_dims.size() - 1],
transition_dims[transition_dims.size() - 1]);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

infer meta里需要判断>0是因为-1为动态shape,这里直接添加对应约束即可

Comment on lines 759 to 766
const symbol::DimExpr n1 = label_dims[0];
if (m1.isa<int64_t>() && n1.isa<int64_t>()) {
int m1_value = static_cast<int>(m1.Get<std::int64_t>());
int n1_value = static_cast<int>(n1.Get<std::int64_t>());
if (m1_value > 0 && n1_value > 0) {
infer_context->AddEqualCstr(emission_dims[0], label_dims[0]);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Comment on lines 1353 to +1355
// bool LstmOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// pir::InferSymbolicShapeContext
// *infer_context)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

自动format的工具注意一下,PR里尽量不要做出无关修改

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &emission_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
std::vector<symbol::DimExpr> emission_dims = emission_shape_or_data.shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emission_dims没有修改的话建议使用const std::vectorsymbol::DimExpr&, 以避免vector构建开销

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


const auto &label_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
std::vector<symbol::DimExpr> label_dims = label_shape_or_data.shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

@@ -1275,7 +1369,8 @@ bool RoiAlignOpInferSymbolicShape(
// }

// bool MoeOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// pir::InferSymbolicShapeContext *infer_context)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无关修改尽量不要引入

if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
viterbi_path_dims.push_back(emission_dims[1]);
} else {
viterbi_path_dims.push_back(symbol::DimExpr(1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面已经定义过one了

pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &emission_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> emission_dims =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名不规范,emission_dims


const auto &label_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const std::vector<symbol::DimExpr> label_dims = label_shape_or_data.shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor

@gongshaotian gongshaotian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
规范性问题可以放到下个pr修改,贴出修改pr链接

@gongshaotian
Copy link
Contributor

Description 描述不完善
check_output() 中未开启 check_pir flag,所以coverage 流水线检查才没过

@luotao1 luotao1 merged commit 6754aa0 into PaddlePaddle:develop Aug 13, 2024
31 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants