-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Infer Symbolic Shape No.127】[BUAA] CrfDecoding op #67201
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
const auto &length_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(3)); | ||
bool has_length = !length_shape_or_data.shape().empty(); | ||
|
||
const auto one = symbol::DimExpr{1}; | ||
|
||
if (has_length) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
判断has_length的逻辑需修改,可以使用是否为NullShapeOrDataDimExpr来判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
if (m.isa<int64_t>() && n.isa<int64_t>()) { | ||
int m_value = static_cast<int>(m.Get<std::int64_t>()); | ||
int n_value = static_cast<int>(n.Get<std::int64_t>()); | ||
if (m_value > 0 && n_value > 0) { | ||
infer_context->AddEqualCstr(emission_dims[emission_dims.size() - 1], | ||
transition_dims[transition_dims.size() - 1]); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
infer meta里需要判断>0是因为-1为动态shape,这里直接添加对应约束即可
const symbol::DimExpr n1 = label_dims[0]; | ||
if (m1.isa<int64_t>() && n1.isa<int64_t>()) { | ||
int m1_value = static_cast<int>(m1.Get<std::int64_t>()); | ||
int n1_value = static_cast<int>(n1.Get<std::int64_t>()); | ||
if (m1_value > 0 && n1_value > 0) { | ||
infer_context->AddEqualCstr(emission_dims[0], label_dims[0]); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
// bool LstmOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
// pir::InferSymbolicShapeContext | ||
// *infer_context) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
自动format的工具注意一下,PR里尽量不要做出无关修改
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &emission_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
std::vector<symbol::DimExpr> emission_dims = emission_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
emission_dims没有修改的话建议使用const std::vectorsymbol::DimExpr&, 以避免vector构建开销
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
const auto &label_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
std::vector<symbol::DimExpr> label_dims = label_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
@@ -1275,7 +1369,8 @@ bool RoiAlignOpInferSymbolicShape( | |||
// } | |||
|
|||
// bool MoeOpInferSymbolicShape(pir::Operation *op, | |||
// pir::InferSymbolicShapeContext *infer_context) { | |||
// pir::InferSymbolicShapeContext *infer_context) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
无关修改尽量不要引入
if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) { | ||
viterbi_path_dims.push_back(emission_dims[1]); | ||
} else { | ||
viterbi_path_dims.push_back(symbol::DimExpr(1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上面已经定义过one了
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &emission_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const std::vector<symbol::DimExpr> emission_dims = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,emission_dims
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
Show resolved
Hide resolved
|
||
const auto &label_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
const std::vector<symbol::DimExpr> label_dims = label_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
规范性问题可以放到下个pr修改,贴出修改pr链接
Description 描述不完善 |
PR Category
CINNPR Types
OthersDescription
添加了CrfDecoding算子,感觉中等难度的比较难,所以先提交了一个。
test/legacy_test/test_crf_decoding_op.py 已经包含 check_test_output