From 2a2ca540b63f44eb286a9732338a4348e7dbeb1b Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Wed, 7 Aug 2024 14:37:17 +0800
Subject: [PATCH 1/8] Finished norm and p_norm

---
 .../infer_symbolic_shape/unary_infer_sym.cc   | 109 ++++++++++++++++--
 .../infer_symbolic_shape/unary_infer_sym.h    |   4 +-
 paddle/phi/ops/yaml/ops.yaml                  |   2 +
 3 files changed, 101 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 5f73602b90e9a..66952b2985671 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1158,12 +1158,33 @@ bool MeanAllOpInferSymbolicShape(
 //   return true;
 // }
 
-// bool NormOpInferSymbolicShape(pir::Operation *op,
-//                               pir::InferSymbolicShapeContext *infer_context)
-//                               {
-//   // pass
-//   return true;
-// }
+bool NormOpInferSymbolicShape(pir::Operation *op,
+                              pir::InferSymbolicShapeContext *infer_context) {
+  auto x_shape_or_data =
+      infer_context->GetShapeOrDataForValue(op->operand_source(0));
+  const auto &x_shape = x_shape_or_data.shape();
+
+  infer_context->SetShapeOrDataForValue(
+      op->result(0),
+      symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});
+
+  int axis = op->attribute<pir::Int32Attribute>("axis").data();
+  float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data();
+  bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();
+
+  if (!is_test) {
+    if (axis < 0) axis += x_shape.size();
+
+    auto norm_shape = x_shape;
+    norm_shape[axis] = symbol::DimExpr(1);
+    infer_context->SetShapeOrDataForValue(
+        op->result(1),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(norm_shape)});
+  }
+
+  return true;
+}
 
 bool NonzeroOpInferSymbolicShape(
     pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
@@ -1198,12 +1219,76 @@ bool NumelOpInferSymbolicShape(pir::Operation *op,
 
   return true;
 }
-// bool P_NormOpInferSymbolicShape(pir::Operation *op,
-//                                 pir::InferSymbolicShapeContext
-//                                 *infer_context) {
-//   // pass
-//   return true;
-// }
+
+bool P_NormOpInferSymbolicShape(pir::Operation *op,
+                                pir::InferSymbolicShapeContext *infer_context) {
+  auto x_shape_or_data =
+      infer_context->GetShapeOrDataForValue(op->operand_source(0));
+  const auto &x_shape = x_shape_or_data.shape();
+  auto x_rank = x_shape.size();
+
+  float porder = op->attribute<pir::FloatAttribute>("porder").data();
+  int axis = op->attribute<pir::Int32Attribute>("axis").data();
+  float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data();
+  bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data();
+  bool asvector = op->attribute<pir::BoolAttribute>("asvector").data();
+
+  PADDLE_ENFORCE_GE(axis,
+                    -x_rank,
+                    common::errors::InvalidArgument(
+                        "Attr(axis) value should be in range [-R, R-1], R is "
+                        "the rank of Input(X). But received axis: %d, R: %d. "
+                        "Current Input(X)'s shape is=[%s].",
+                        axis,
+                        x_rank,
+                        x_shape));
+  PADDLE_ENFORCE_LT(axis,
+                    x_rank,
+                    common::errors::InvalidArgument(
+                        "Attr(axis) value should be in range [-R, R-1], R is "
+                        "the rank of Input(X). But received axis: %d, R: %d. "
+                        "Current Input(X)'s shape is=[%s].",
+                        axis,
+                        x_rank,
+                        x_shape));
+
+  std::vector<symbol::DimExpr> out_dim_vector;
+
+  if (asvector) {
+    if (keepdim) {
+      for (size_t i = 0; i < x_rank; ++i) {
+        out_dim_vector.emplace_back(symbol::DimExpr(1));
+      }
+    } else {
+      out_dim_vector = {};
+    }
+  } else {
+    if (axis < 0) axis += x_rank;
+
+    if (keepdim) {
+      for (size_t i = 0; i < x_rank; ++i) {
+        if (static_cast<int>(i) == axis) {
+          out_dim_vector.emplace_back(symbol::DimExpr(1));
+        } else {
+          out_dim_vector.emplace_back(x_shape[i]);
+        }
+      }
+    } else {
+      for (size_t i = 0; i < x_rank; ++i) {
+        if (static_cast<int>(i) != axis) {
+          out_dim_vector.emplace_back(x_shape[i]);
+        }
+      }
+    }
+  }
+
+  infer_context->SetShapeOrDataForValue(
+      op->result(0),
+      symbol::ShapeOrDataDimExprs{
+          symbol::TensorShapeOrDataDimExprs(out_dim_vector)});
+
+  return true;
+}
 
 // bool PartialSumOpInferSymbolicShape(pir::Operation *op,
 //                                     pir::InferSymbolicShapeContext
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
index 9ffec0c4edb10..fe09591a9f8dc 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
@@ -79,10 +79,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3DWithIndex)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian)
-// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel)
-// OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d)
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 319d8bea6d9cd..03555972ad86f 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3367,6 +3367,7 @@
   kernel :
     func : norm
   backward : norm_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : npu_identity
   args : (Tensor x, int format = -1)
@@ -3428,6 +3429,7 @@
   kernel :
     func : p_norm
   backward : p_norm_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : pad
   args : (Tensor x, int[] paddings, Scalar pad_value)

From 407ee474ed9e71ff241e8166244f1fa8ebb37250 Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Wed, 7 Aug 2024 15:27:44 +0800
Subject: [PATCH 2/8] Fixed name errors

---
 .../interface/infer_symbolic_shape/unary_infer_sym.cc         | 4 ++--
 .../operator/interface/infer_symbolic_shape/unary_infer_sym.h | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 66952b2985671..116938b6d8af6 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1220,8 +1220,8 @@ bool NumelOpInferSymbolicShape(pir::Operation *op,
   return true;
 }
 
-bool P_NormOpInferSymbolicShape(pir::Operation *op,
-                                pir::InferSymbolicShapeContext *infer_context) {
+bool PNormOpInferSymbolicShape(pir::Operation *op,
+                               pir::InferSymbolicShapeContext *infer_context) {
   auto x_shape_or_data =
       infer_context->GetShapeOrDataForValue(op->operand_source(0));
   const auto &x_shape = x_shape_or_data.shape();
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
index fe09591a9f8dc..a34f4768fbe13 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h
@@ -82,7 +82,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel)
-OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(PNorm)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d)

From 85ce4b32268e2302501f1351f86ea84b858f2bd4 Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Wed, 7 Aug 2024 17:15:56 +0800
Subject: [PATCH 3/8] Removed unused variables

---
 .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 116938b6d8af6..6a9e3a91ed066 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1169,7 +1169,6 @@ bool NormOpInferSymbolicShape(pir::Operation *op,
       symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});
 
   int axis = op->attribute<pir::Int32Attribute>("axis").data();
-  float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data();
   bool is_test = op->attribute<pir::BoolAttribute>("is_test").data();
 
   if (!is_test) {
@@ -1227,9 +1226,7 @@ bool PNormOpInferSymbolicShape(pir::Operation *op,
   const auto &x_shape = x_shape_or_data.shape();
   auto x_rank = x_shape.size();
 
-  float porder = op->attribute<pir::FloatAttribute>("porder").data();
   int axis = op->attribute<pir::Int32Attribute>("axis").data();
-  float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data();
   bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data();
   bool asvector = op->attribute<pir::BoolAttribute>("asvector").data();
 

From 5ffe7905ebf20e5f0d66f271e8b3a1cbd3e5236f Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Fri, 9 Aug 2024 12:36:29 +0800
Subject: [PATCH 4/8] Updated norm and pnorm in unary_infer_sym.cc according to
 suggested changes

---
 .../infer_symbolic_shape/unary_infer_sym.cc   | 79 ++++++++++---------
 1 file changed, 42 insertions(+), 37 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 6a9e3a91ed066..7d4b8080c36ac 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1174,12 +1174,12 @@ bool NormOpInferSymbolicShape(pir::Operation *op,
   if (!is_test) {
     if (axis < 0) axis += x_shape.size();
 
-    auto norm_shape = x_shape;
-    norm_shape[axis] = symbol::DimExpr(1);
+    // Directly modify x_shape at the specified axis.
     infer_context->SetShapeOrDataForValue(
         op->result(1),
         symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(norm_shape)});
+            symbol::TensorShapeOrDataDimExprs(x_shape).SetDim(
+                axis, symbol::DimExpr(1))});
   }
 
   return true;
@@ -1221,59 +1221,64 @@ bool NumelOpInferSymbolicShape(pir::Operation *op,
 
 bool PNormOpInferSymbolicShape(pir::Operation *op,
                                pir::InferSymbolicShapeContext *infer_context) {
-  auto x_shape_or_data =
+  const auto &x_shape_or_data =
       infer_context->GetShapeOrDataForValue(op->operand_source(0));
   const auto &x_shape = x_shape_or_data.shape();
-  auto x_rank = x_shape.size();
+  int x_rank = x_shape.size();
 
   int axis = op->attribute<pir::Int32Attribute>("axis").data();
   bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data();
   bool asvector = op->attribute<pir::BoolAttribute>("asvector").data();
 
-  PADDLE_ENFORCE_GE(axis,
-                    -x_rank,
-                    common::errors::InvalidArgument(
-                        "Attr(axis) value should be in range [-R, R-1], R is "
-                        "the rank of Input(X). But received axis: %d, R: %d. "
-                        "Current Input(X)'s shape is=[%s].",
-                        axis,
-                        x_rank,
-                        x_shape));
-  PADDLE_ENFORCE_LT(axis,
-                    x_rank,
-                    common::errors::InvalidArgument(
-                        "Attr(axis) value should be in range [-R, R-1], R is "
-                        "the rank of Input(X). But received axis: %d, R: %d. "
-                        "Current Input(X)'s shape is=[%s].",
-                        axis,
-                        x_rank,
-                        x_shape));
+  if (axis < 0) {
+    axis += x_rank;
+  }
+
+  PADDLE_ENFORCE_GE(
+      axis,
+      0,
+      common::errors::InvalidArgument(
+          "Attr(axis) value should be in range [-R, R-1], R is the rank of "
+          "Input(X). "
+          "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].",
+          axis,
+          x_rank,
+          x_shape));
+
+  PADDLE_ENFORCE_LT(
+      axis,
+      x_rank,
+      common::errors::InvalidArgument(
+          "Attr(axis) value should be in range [-R, R-1], R is the rank of "
+          "Input(X). "
+          "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].",
+          axis,
+          x_rank,
+          x_shape));
 
-  std::vector<symbol::DimExpr> out_dim_vector;
+  std::vector<symbol::DimExpr> out_shape;
 
   if (asvector) {
     if (keepdim) {
-      for (size_t i = 0; i < x_rank; ++i) {
-        out_dim_vector.emplace_back(symbol::DimExpr(1));
+      for (int i = 0; i < x_rank; ++i) {
+        out_shape.emplace_back(symbol::DimExpr(1));
       }
     } else {
-      out_dim_vector = {};
+      out_shape = {};
     }
   } else {
-    if (axis < 0) axis += x_rank;
-
     if (keepdim) {
-      for (size_t i = 0; i < x_rank; ++i) {
-        if (static_cast<int>(i) == axis) {
-          out_dim_vector.emplace_back(symbol::DimExpr(1));
+      for (int i = 0; i < x_rank; ++i) {
+        if (i == axis) {
+          out_shape.emplace_back(symbol::DimExpr(1));
         } else {
-          out_dim_vector.emplace_back(x_shape[i]);
+          out_shape.emplace_back(x_shape[i]);
         }
       }
     } else {
-      for (size_t i = 0; i < x_rank; ++i) {
-        if (static_cast<int>(i) != axis) {
-          out_dim_vector.emplace_back(x_shape[i]);
+      for (int i = 0; i < x_rank; ++i) {
+        if (i != axis) {
+          out_shape.emplace_back(x_shape[i]);
         }
       }
     }
@@ -1282,7 +1287,7 @@ bool PNormOpInferSymbolicShape(pir::Operation *op,
   infer_context->SetShapeOrDataForValue(
       op->result(0),
       symbol::ShapeOrDataDimExprs{
-          symbol::TensorShapeOrDataDimExprs(out_dim_vector)});
+          symbol::TensorShapeOrDataDimExprs(out_shape)});
 
   return true;
 }

From 9919dbdec712fe80245ae94e5614a857fd55745f Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Fri, 9 Aug 2024 12:36:59 +0800
Subject: [PATCH 5/8] Removed comments

---
 .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc   | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 7d4b8080c36ac..e876f00cc8531 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1174,7 +1174,6 @@ bool NormOpInferSymbolicShape(pir::Operation *op,
   if (!is_test) {
     if (axis < 0) axis += x_shape.size();
 
-    // Directly modify x_shape at the specified axis.
     infer_context->SetShapeOrDataForValue(
         op->result(1),
         symbol::ShapeOrDataDimExprs{

From 19af8426ffb1af45dc0a314df39a81c8c7b83e93 Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Fri, 9 Aug 2024 16:28:54 +0800
Subject: [PATCH 6/8] Fixed errors returned by CI

---
 .../interface/infer_symbolic_shape/unary_infer_sym.cc         | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index e876f00cc8531..13f941b8d5b2d 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1174,11 +1174,11 @@ bool NormOpInferSymbolicShape(pir::Operation *op,
   if (!is_test) {
     if (axis < 0) axis += x_shape.size();
 
+    x_shape[axis] = symbol::DimExpr(1);
     infer_context->SetShapeOrDataForValue(
         op->result(1),
         symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(x_shape).SetDim(
-                axis, symbol::DimExpr(1))});
+            symbol::TensorShapeOrDataDimExprs(x_shape)});
   }
 
   return true;

From 08663f196a4d14128a894ed62b978b0259b1af4d Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Fri, 9 Aug 2024 16:46:32 +0800
Subject: [PATCH 7/8] Put some old code back

---
 .../interface/infer_symbolic_shape/unary_infer_sym.cc        | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index 13f941b8d5b2d..b7df8e607d0f8 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1174,11 +1174,12 @@ bool NormOpInferSymbolicShape(pir::Operation *op,
   if (!is_test) {
     if (axis < 0) axis += x_shape.size();
 
-    x_shape[axis] = symbol::DimExpr(1);
+    auto norm_shape = x_shape;
+    norm_shape[axis] = symbol::DimExpr(1);
     infer_context->SetShapeOrDataForValue(
         op->result(1),
         symbol::ShapeOrDataDimExprs{
-            symbol::TensorShapeOrDataDimExprs(x_shape)});
+            symbol::TensorShapeOrDataDimExprs(norm_shape)});
   }
 
   return true;

From 595694cde0b1612dcaeffc67d16a2e51afb36206 Mon Sep 17 00:00:00 2001
From: MufanColin <76479709+MufanColin@users.noreply.github.com>
Date: Fri, 9 Aug 2024 16:59:00 +0800
Subject: [PATCH 8/8] Use a bool variable to make life easier

---
 .../infer_symbolic_shape/unary_infer_sym.cc     | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
index b7df8e607d0f8..1fab7a25b1cab 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc
@@ -1234,20 +1234,11 @@ bool PNormOpInferSymbolicShape(pir::Operation *op,
     axis += x_rank;
   }
 
-  PADDLE_ENFORCE_GE(
-      axis,
-      0,
-      common::errors::InvalidArgument(
-          "Attr(axis) value should be in range [-R, R-1], R is the rank of "
-          "Input(X). "
-          "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].",
-          axis,
-          x_rank,
-          x_shape));
+  bool axis_valid = (axis >= 0) && (axis < x_rank);
 
-  PADDLE_ENFORCE_LT(
-      axis,
-      x_rank,
+  PADDLE_ENFORCE_EQ(
+      axis_valid,
+      true,
       common::errors::InvalidArgument(
           "Attr(axis) value should be in range [-R, R-1], R is the rank of "
           "Input(X). "