diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 6005f46f24d3a5..d1b258c77ab617 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2252,6 +2252,75 @@ bool MemoryEfficientAttentionOpInferSymbolicShape( return true; } + +bool NllLossOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const symbol::ShapeOrDataDimExprs &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shape = x_shape_or_data.shape(); + const symbol::ShapeOrDataDimExprs &label_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const std::vector &label_shape = label_shape_or_data.shape(); + PADDLE_ENFORCE_EQ(x_shape.size() == 2 || x_shape.size() == 4, + true, + phi::errors::InvalidArgument( + "The tensor rank of Input(X) must be 2 or 4.")); + infer_context->AddEqualCstr(x_shape[0], label_shape[0]); + + if (op->operand_source(2)) { + const symbol::ShapeOrDataDimExprs &w_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const std::vector &w_shape = w_shape_or_data.shape(); + PADDLE_ENFORCE_EQ( + w_shape.size(), + 1, + phi::errors::InvalidArgument("Input(Weight) should be a 1D tensor.")); + + infer_context->AddEqualCstr(x_shape[1], w_shape[0]); + } + + const std::string &reduction = + op->attribute("reduction").AsString(); + + std::vector out_shape; + if (x_shape.size() == 2) { + if (reduction == "none") { + out_shape = {x_shape[0]}; + } else { + out_shape = std::vector{}; + } + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + } else if (x_shape.size() == 4) { + PADDLE_ENFORCE_EQ(label_shape.size(), + 3, + phi::errors::InvalidArgument( + "Expected Input(Label) dimensions=3, received %d.", + label_shape.size())); + + infer_context->AddEqualCstr(x_shape[0], label_shape[0]); + infer_context->AddEqualCstr(x_shape[2], label_shape[1]); + infer_context->AddEqualCstr(x_shape[3], label_shape[2]); + + if (reduction == "none") { + out_shape = {x_shape[0], x_shape[2], x_shape[3]}; + } else { + out_shape = std::vector{}; + } + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + } + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(std::vector{})}); + return true; +} + bool RoiPoolOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index 3c582a49d583ef..39fa3a65c44f7a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -91,6 +91,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(NllLoss) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool) OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear) OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 8e11a01498de78..9569bbd10a3edf 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3449,6 +3449,7 @@ data_type : input optional : weight backward : nll_loss_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : nms args : (Tensor x, float threshold = 1.0f)