Skip to content

Commit

Permalink
【Infer Symbolic Shape No.92】【BUAA】 Add nll_loss op (#67884)
Browse files Browse the repository at this point in the history
* Unfinished yet

* Finished nll loss op

* Finished nll loss op

* Fixed errors

* Added return true

* Resolved suggested changes
  • Loading branch information
MufanColin authored Sep 6, 2024
1 parent 15a1462 commit 9aacea9
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2252,6 +2252,75 @@ bool MemoryEfficientAttentionOpInferSymbolicShape(

return true;
}

bool NllLossOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &x_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &x_shape = x_shape_or_data.shape();
const symbol::ShapeOrDataDimExprs &label_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const std::vector<symbol::DimExpr> &label_shape = label_shape_or_data.shape();
PADDLE_ENFORCE_EQ(x_shape.size() == 2 || x_shape.size() == 4,
true,
phi::errors::InvalidArgument(
"The tensor rank of Input(X) must be 2 or 4."));
infer_context->AddEqualCstr(x_shape[0], label_shape[0]);

if (op->operand_source(2)) {
const symbol::ShapeOrDataDimExprs &w_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const std::vector<symbol::DimExpr> &w_shape = w_shape_or_data.shape();
PADDLE_ENFORCE_EQ(
w_shape.size(),
1,
phi::errors::InvalidArgument("Input(Weight) should be a 1D tensor."));

infer_context->AddEqualCstr(x_shape[1], w_shape[0]);
}

const std::string &reduction =
op->attribute<pir::StrAttribute>("reduction").AsString();

std::vector<symbol::DimExpr> out_shape;
if (x_shape.size() == 2) {
if (reduction == "none") {
out_shape = {x_shape[0]};
} else {
out_shape = std::vector<symbol::DimExpr>{};
}
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
} else if (x_shape.size() == 4) {
PADDLE_ENFORCE_EQ(label_shape.size(),
3,
phi::errors::InvalidArgument(
"Expected Input(Label) dimensions=3, received %d.",
label_shape.size()));

infer_context->AddEqualCstr(x_shape[0], label_shape[0]);
infer_context->AddEqualCstr(x_shape[2], label_shape[1]);
infer_context->AddEqualCstr(x_shape[3], label_shape[2]);

if (reduction == "none") {
out_shape = {x_shape[0], x_shape[2], x_shape[3]};
} else {
out_shape = std::vector<symbol::DimExpr>{};
}
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
}
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(std::vector<symbol::DimExpr>{})});
return true;
}

bool RoiPoolOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NllLoss)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear_)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3449,6 +3449,7 @@
data_type : input
optional : weight
backward : nll_loss_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : nms
args : (Tensor x, float threshold = 1.0f)
Expand Down

0 comments on commit 9aacea9

Please sign in to comment.