Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Infer Symbolic Shape No.127】[BUAA] CrfDecoding op #67201

Merged
merged 6 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -664,12 +664,104 @@ bool CheckFiniteAndUnscale_OpInferSymbolicShape(
return CheckFiniteAndUnscaleOpInferSymbolicShape(op, infer_context);
}

// bool CrfDecodingOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool CrfDecodingOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &emission_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> emission_dims =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名不规范,emission_dims

emission_shape_or_data.shape();

const auto &transition_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const std::vector<symbol::DimExpr> transition_dims =
BHmingyang marked this conversation as resolved.
Show resolved Hide resolved
transition_shape_or_data.shape();

const auto &label_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const std::vector<symbol::DimExpr> label_dims = label_shape_or_data.shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上


const auto &length_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(3));

const auto one = symbol::DimExpr{1};

if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
PADDLE_ENFORCE_EQ(emission_dims.size(),
3,
common::errors::InvalidArgument(
"The Input(Emission) should be a 3-D tensor. But "
"received: input rank %u, input shape [%s]. ",
emission_dims.size(),
emission_dims));
} else {
PADDLE_ENFORCE_EQ(emission_dims.size(),
2,
common::errors::InvalidArgument(
"The Input(Emission) should be a 2-D tensor. But "
"received: input rank %u, input shape [%s].",
emission_dims.size(),
emission_dims));
}

PADDLE_ENFORCE_EQ(transition_dims.size(),
2UL,
common::errors::InvalidArgument(
"The Input(Transition) should be a 2-D tensor. But "
"received: input rank %u, input shape [%s].",
transition_dims.size(),
transition_dims));
infer_context->AddEqualCstr(transition_dims[0] - 2, transition_dims[1]);

infer_context->AddEqualCstr(emission_dims[emission_dims.size() - 1],
transition_dims[transition_dims.size() - 1]);

if (!label_dims.empty()) {
if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
if (label_dims.size() == 3UL) {
infer_context->AddEqualCstr(label_dims[2], one);
} else {
PADDLE_ENFORCE_EQ(
label_dims.size(),
2UL,
common::errors::InvalidArgument(
"The Input(Label) should be a 3-D tensor with last dimension "
"fixed to 1 or a 2-D tensor in padding mode. But received: "
"input "
"rank %u, input shape [%s].",
label_dims.size(),
label_dims));
}
} else {
if (label_dims.size() == 2UL) {
infer_context->AddEqualCstr(label_dims[2], one);
} else {
PADDLE_ENFORCE_EQ(
label_dims.size(),
1UL,
common::errors::InvalidArgument(
"The Input(Label) should be a 2-D tensor with last "
"dimension fixed to 1 or a 1-D tensor. But received: "
"input rank %u, input shape [%s].",
label_dims.size(),
label_dims));
}
}

infer_context->AddEqualCstr(emission_dims[0], label_dims[0]);

std::vector<symbol::DimExpr> viterbi_path_dims;
viterbi_path_dims.push_back(emission_dims[0]);
if (!length_shape_or_data.isa<symbol::NullShapeOrDataDimExpr>()) {
viterbi_path_dims.push_back(emission_dims[1]);
} else {
viterbi_path_dims.push_back(symbol::DimExpr(1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面已经定义过one了

}

infer_context->SetShapeOrDataForValue(
op->result(0), symbol::TensorShapeOrDataDimExprs(viterbi_path_dims));
}
return true;
}

// bool CoalesceTensorOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
Expand Down Expand Up @@ -971,7 +1063,8 @@ bool FlashAttnOpInferSymbolicShape(
// }

// bool GruOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// pir::InferSymbolicShapeContext *infer_context)
// {
// // pass
// return true;
// }
Expand Down Expand Up @@ -1242,7 +1335,8 @@ bool RoiAlignOpInferSymbolicShape(
}

// bool LstmOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context)
// pir::InferSymbolicShapeContext
// *infer_context)
Comment on lines 1337 to +1339
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

自动format的工具注意一下,PR里尽量不要做出无关修改

// {
// // pass
// return true;
Expand Down Expand Up @@ -1275,7 +1369,8 @@ bool RoiAlignOpInferSymbolicShape(
// }

// bool MoeOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// pir::InferSymbolicShapeContext *infer_context)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

无关修改尽量不要引入

// {
// // pass
// return true;
// }
Expand Down Expand Up @@ -1345,7 +1440,8 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape(
}

// bool NceOpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext *infer_context) {
// pir::InferSymbolicShapeContext *infer_context)
// {
// // pass
// return true;
// }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale_)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrfDecoding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrfDecoding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CrossEntropyWithSoftmax_)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,7 @@
func: crf_decoding
data_type: emission
optional: label, length
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : crop
args : (Tensor x, IntArray shape = {}, IntArray offsets = {})
Expand Down