-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Infer Symbolic Shape No.127】[BUAA] CrfDecoding op #67201
Changes from all commits
6a02b5d
f6e98d5
b7dc273
f495f9a
7a5da00
5e85a3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -664,12 +664,104 @@ bool CheckFiniteAndUnscale_OpInferSymbolicShape( | |
return CheckFiniteAndUnscaleOpInferSymbolicShape(op, infer_context); | ||
} | ||
|
||
// bool CrfDecodingOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext | ||
// *infer_context) { | ||
// // pass | ||
// return true; | ||
// } | ||
bool CrfDecodingOpInferSymbolicShape( | ||
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { | ||
const auto &emission_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const std::vector<symbol::DimExpr> emission_dims = | ||
emission_shape_or_data.shape(); | ||
|
||
const auto &transition_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); | ||
const std::vector<symbol::DimExpr> transition_dims = | ||
BHmingyang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
transition_shape_or_data.shape(); | ||
|
||
const auto &label_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(2)); | ||
const std::vector<symbol::DimExpr> label_dims = label_shape_or_data.shape(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
|
||
const auto &length_shape_or_data = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(3)); | ||
|
||
const auto one = symbol::DimExpr{1}; | ||
|
||
if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) { | ||
PADDLE_ENFORCE_EQ(emission_dims.size(), | ||
3, | ||
common::errors::InvalidArgument( | ||
"The Input(Emission) should be a 3-D tensor. But " | ||
"received: input rank %u, input shape [%s]. ", | ||
emission_dims.size(), | ||
emission_dims)); | ||
} else { | ||
PADDLE_ENFORCE_EQ(emission_dims.size(), | ||
2, | ||
common::errors::InvalidArgument( | ||
"The Input(Emission) should be a 2-D tensor. But " | ||
"received: input rank %u, input shape [%s].", | ||
emission_dims.size(), | ||
emission_dims)); | ||
} | ||
|
||
PADDLE_ENFORCE_EQ(transition_dims.size(), | ||
2UL, | ||
common::errors::InvalidArgument( | ||
"The Input(Transition) should be a 2-D tensor. But " | ||
"received: input rank %u, input shape [%s].", | ||
transition_dims.size(), | ||
transition_dims)); | ||
infer_context->AddEqualCstr(transition_dims[0] - 2, transition_dims[1]); | ||
|
||
infer_context->AddEqualCstr(emission_dims[emission_dims.size() - 1], | ||
transition_dims[transition_dims.size() - 1]); | ||
|
||
if (!label_dims.empty()) { | ||
if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) { | ||
if (label_dims.size() == 3UL) { | ||
infer_context->AddEqualCstr(label_dims[2], one); | ||
} else { | ||
PADDLE_ENFORCE_EQ( | ||
label_dims.size(), | ||
2UL, | ||
common::errors::InvalidArgument( | ||
"The Input(Label) should be a 3-D tensor with last dimension " | ||
"fixed to 1 or a 2-D tensor in padding mode. But received: " | ||
"input " | ||
"rank %u, input shape [%s].", | ||
label_dims.size(), | ||
label_dims)); | ||
} | ||
} else { | ||
if (label_dims.size() == 2UL) { | ||
infer_context->AddEqualCstr(label_dims[2], one); | ||
} else { | ||
PADDLE_ENFORCE_EQ( | ||
label_dims.size(), | ||
1UL, | ||
common::errors::InvalidArgument( | ||
"The Input(Label) should be a 2-D tensor with last " | ||
"dimension fixed to 1 or a 1-D tensor. But received: " | ||
"input rank %u, input shape [%s].", | ||
label_dims.size(), | ||
label_dims)); | ||
} | ||
} | ||
|
||
infer_context->AddEqualCstr(emission_dims[0], label_dims[0]); | ||
|
||
std::vector<symbol::DimExpr> viterbi_path_dims; | ||
viterbi_path_dims.push_back(emission_dims[0]); | ||
if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) { | ||
viterbi_path_dims.push_back(emission_dims[1]); | ||
} else { | ||
viterbi_path_dims.push_back(symbol::DimExpr(1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 上面已经定义过one了 |
||
} | ||
|
||
infer_context->SetShapeOrDataForValue( | ||
op->result(0), symbol::TensorShapeOrDataDimExprs(viterbi_path_dims)); | ||
} | ||
return true; | ||
} | ||
|
||
// bool CoalesceTensorOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext | ||
|
@@ -971,7 +1063,8 @@ bool FlashAttnOpInferSymbolicShape( | |
// } | ||
|
||
// bool GruOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) { | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
// { | ||
// // pass | ||
// return true; | ||
// } | ||
|
@@ -1242,7 +1335,8 @@ bool RoiAlignOpInferSymbolicShape( | |
} | ||
|
||
// bool LstmOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
// pir::InferSymbolicShapeContext | ||
// *infer_context) | ||
Comment on lines
1337
to
+1339
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 自动format的工具注意一下,PR里尽量不要做出无关修改 |
||
// { | ||
// // pass | ||
// return true; | ||
|
@@ -1275,7 +1369,8 @@ bool RoiAlignOpInferSymbolicShape( | |
// } | ||
|
||
// bool MoeOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) { | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 无关修改尽量不要引入 |
||
// { | ||
// // pass | ||
// return true; | ||
// } | ||
|
@@ -1345,7 +1440,8 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape( | |
} | ||
|
||
// bool NceOpInferSymbolicShape(pir::Operation *op, | ||
// pir::InferSymbolicShapeContext *infer_context) { | ||
// pir::InferSymbolicShapeContext *infer_context) | ||
// { | ||
// // pass | ||
// return true; | ||
// } | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名不规范,emission_dims