From 44f89e62faca28d7e8375bc1ca191d2f754fc467 Mon Sep 17 00:00:00 2001
From: YuanRisheng <yuanrisheng@baidu.com>
Date: Wed, 16 Feb 2022 09:56:10 +0000
Subject: [PATCH 1/6] remove infershape and Xshape

---
 paddle/fluid/framework/infershape_utils.cc | 82 +++++++++++++++++-----
 paddle/fluid/operators/reshape_op.cc       | 38 ++++------
 paddle/pten/core/infermeta_utils.h         |  1 -
 paddle/pten/infermeta/unary.cc             | 21 ++++--
 paddle/pten/infermeta/unary.h              |  3 +-
 paddle/pten/kernels/reshape_kernel.cc      | 25 -------
 paddle/pten/kernels/reshape_kernel.h       |  7 --
 python/paddle/utils/code_gen/api.yaml      |  1 +
 8 files changed, 99 insertions(+), 79 deletions(-)

diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index bc0344d405cf7..0284169b4429e 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/compat/arg_map_context.h"
 #include "paddle/pten/core/compat/convert_utils.h"
 #include "paddle/pten/core/compat/op_utils.h"
@@ -54,7 +55,12 @@ class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
   }
 
   size_t InputSize(const std::string& name) const override {
-    return ctx_.Inputs(name).size();
+    if (ctx_.HasInputs(name)) {
+      return ctx_.Inputs(name).size();
+    } else if (ctx_.HasInput(name)) {
+      return 1;
+    }
+    return 0;
   }
 
   size_t OutputSize(const std::string& name) const override {
@@ -288,6 +294,15 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   auto& attr_names = std::get<1>(signature.args);
   auto& output_names = std::get<2>(signature.args);
 
+  auto kernels = pten::KernelFactory::Instance().kernels().find(signature.name);
+  if (kernels == pten::KernelFactory::Instance().kernels().end()) {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("Not find `%s` kernels when construct "
+                                        "InferMetaContext.",
+                                        signature.name));
+  }
+  auto attr_defs = kernels->second.cbegin()->second.args_def().attribute_defs();
+
   // TODO(chenweihang): support multiple inputs and outputs later
   pten::InferMetaContext infer_mete_context;
   for (auto& in_name : input_names) {
@@ -299,11 +314,55 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
     }
   }
 
-  auto attr_reader = ctx->Attrs();
-  for (auto& attr_name : attr_names) {
-    if (ctx->HasAttr(attr_name)) {
-      auto& attr = attr_reader.GetAttr(attr_name);
-      if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
+  for (auto& out_name : output_names) {
+    if (ctx->HasOutput(out_name)) {
+      infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
+          ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
+    } else {
+      infer_meta_context.EmplaceBackOutput({nullptr});
+    }
+  }
+  for (size_t i = 0; i < attr_names.size(); ++i) {
+    auto attr_name = attr_names[i];
+    // When attr is a vector_tensor or tensor, transform it to ScalarArray
+    if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
+      const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
+      if (ctx->IsRuntime()) {
+        std::vector<Variable*> vars;
+        vars.reserve(infershape_inputs.size());
+        for (size_t i = 0; i < infershape_inputs.size(); i++) {
+          vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
+        }
+
+        if (infershape_inputs.size() != 1) {
+          infer_meta_context.EmplaceBackAttr(
+              std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
+        } else {
+          infer_meta_context.EmplaceBackAttr(
+              std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
+        }
+      } else {
+        VLOG(6) << "Not in Runtime, the Attr( " << attr_name
+                << ") ScalarArray value will be set empty";
+        infer_meta_context.EmplaceBackAttr(std::move(pten::ScalarArray()));
+      }
+    } else {
+      // Emplace Back Attr according to the type of attr.
+      auto& attr = ctx->Attrs().GetAttr(attr_name);
+      if (attr_defs[i].type_index ==
+          std::type_index(typeid(pten::ScalarArray))) {
+        if (std::type_index(attr.type()) ==
+            std::type_index(typeid(std::vector<int32_t>))) {
+          infer_meta_context.EmplaceBackAttr(std::move(
+              pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
+        } else {
+          PADDLE_THROW(platform::errors::Unimplemented(
+              "Unsupported cast op attribute `%s` to ScalarArray when "
+              "construct KernelContext.",
+              attr_name));
+        }
+      } else if (std::type_index(attr.type()) ==
+                 std::type_index(typeid(bool))) {
         infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else if (std::type_index(attr.type()) == std::type_index(typeid(int))) {
         infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
@@ -345,17 +404,6 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
             "Unsupported attribute type is received when call "
             "InferShapeFunctor."));
       }
-    } else {
-      // do nothing
-    }
-  }
-
-  for (auto& out_name : output_names) {
-    if (ctx->HasOutput(out_name)) {
-      infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
-          ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
-    } else {
-      infer_meta_context.EmplaceBackOutput({nullptr});
     }
   }
 
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 74095d2ce4e65..4a4b1c14fd2f3 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <string>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/pten_utils.h"
 
@@ -21,8 +22,11 @@ limitations under the License. */
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/common/scalar_array.h"
+#include "paddle/pten/core/infermeta_utils.h"
+#include "paddle/pten/infermeta/unary.h"
 #include "paddle/pten/kernels/reshape_grad_kernel.h"
 #include "paddle/pten/kernels/reshape_kernel.h"
+
 namespace paddle {
 namespace framework {
 class InferShapeContext;
@@ -472,22 +476,6 @@ class Reshape2Op : public ReshapeOp {
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
       : ReshapeOp(type, inputs, outputs, attrs) {}
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
-                      platform::errors::InvalidArgument(
-                          "Output(XShape) of ReshapeOp should not be null."));
-    const auto &x_dims = ctx->GetInputDim("X");
-    std::vector<int64_t> xshape_dims(x_dims.size() + 1);
-    xshape_dims[0] = 0;
-    for (int i = 0; i < x_dims.size(); ++i) {
-      xshape_dims[i + 1] = x_dims[i];
-    }
-    ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
-    ctx->ShareLoD("X", /*->*/ "XShape");
-
-    ReshapeOp::InferShape(ctx);
-  }
 };
 
 class Reshape2OpMaker : public ReshapeOpMaker {
@@ -519,7 +507,7 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> {
 
   void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("reshape2_grad");
-    grad_op->SetInput("XShape", this->Output("XShape"));
+    grad_op->SetInput("X", this->Input("X"));
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
@@ -550,15 +538,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE_EQ(
-        ctx->HasInput("XShape"), true,
-        platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
+        ctx->HasInput("X"), true,
+        platform::errors::InvalidArgument("Input(X) shouldn't be null."));
     PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                       platform::errors::InvalidArgument(
                           "Input(Out@GRAD) shouldn't be null."));
-    auto xshape_dims = ctx->GetInputDim("XShape");
-    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    ctx->ShareLoD("XShape", framework::GradVarName("X"));
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", framework::GradVarName("X"));
   }
 
  protected:
@@ -645,10 +631,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel);
+
+DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
+                            PT_INFER_META(pten::ReshapeInferMeta));
+
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2GradMaker<paddle::imperative::OpBase>,
-                  ops::ReshapeOpInplaceInferer);
+                  ReshapeInferShapeFunctor, ops::ReshapeOpInplaceInferer);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,
                   ops::Reshape2DoubleGradMaker<paddle::framework::OpDesc>,
                   ops::Reshape2DoubleGradMaker<paddle::imperative::OpBase>,
diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h
index 59d2a4ed3c089..e6d8b1499071c 100644
--- a/paddle/pten/core/infermeta_utils.h
+++ b/paddle/pten/core/infermeta_utils.h
@@ -27,7 +27,6 @@ limitations under the License. */
 #include "paddle/pten/core/type_defs.h"
 #include "paddle/utils/flat_hash_map.h"
 #include "paddle/utils/small_vector.h"
-
 namespace pten {
 
 class InferMetaContext {
diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc
index ca59937399a22..b2f29b7eb6a6f 100644
--- a/paddle/pten/infermeta/unary.cc
+++ b/paddle/pten/infermeta/unary.cc
@@ -15,8 +15,8 @@ limitations under the License. */
 #include "paddle/pten/infermeta/unary.h"
 
 #include <set>
-
 #include "paddle/pten/common/data_type.h"
+#include "paddle/pten/core/enforce.h"
 #include "paddle/pten/core/infermeta_utils.h"
 
 namespace pten {
@@ -213,7 +213,7 @@ void InferMetaFromVecValue(const MetaTensor& x,
                            MetaTensor* out) {
   PADDLE_ENFORCE_EQ(!shape.empty(),
                     true,
-                    paddle::platform::errors::InvalidArgument(
+                    pten::errors::InvalidArgument(
                         "The parameter 'shape' in ReshapeOp must be set. "
                         "But received 'shape' is empty."));
   auto x_dims = x.dims();
@@ -230,8 +230,21 @@ void InferMetaFromVecValue(const MetaTensor& x,
 
 void ReshapeInferMeta(const MetaTensor& x,
                       const ScalarArray& shape,
-                      MetaTensor* out) {
-  InferMetaFromVecValue(x, shape.GetData(), out);
+                      MetaTensor* out,
+                      MetaConfig config) {
+  auto& shape_data = shape.GetData();
+  PADDLE_ENFORCE_NOT_NULL(out,
+                          pten::errors::InvalidArgument(
+                              "Output(Out) of ReshapeOp should not be null."));
+  if (!config.is_runtime && shape_data.size() == 0) {
+    out->share_lod(x);
+    return;
+  }
+  PADDLE_ENFORCE_GT(shape_data.size(),
+                    0,
+                    pten::errors::InvalidArgument(
+                        "The shape's size in ReshapeOp can't be zero."));
+  InferMetaFromVecValue(x, shape_data, out);
 }
 
 /*  Why not use ReduceInferMeta directly?
diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h
index 4c816c4adbc23..19841c2c6c938 100644
--- a/paddle/pten/infermeta/unary.h
+++ b/paddle/pten/infermeta/unary.h
@@ -52,7 +52,8 @@ void InferMetaFromVecValue(const MetaTensor& x,
 
 void ReshapeInferMeta(const MetaTensor& x,
                       const ScalarArray& shape,
-                      MetaTensor* out);
+                      MetaTensor* out,
+                      MetaConfig config = MetaConfig());
 
 void ReduceInferMetaBase(const MetaTensor& x,
                          const std::vector<int64_t>& axis,
diff --git a/paddle/pten/kernels/reshape_kernel.cc b/paddle/pten/kernels/reshape_kernel.cc
index c52d251582bb5..f25d1d63388c9 100644
--- a/paddle/pten/kernels/reshape_kernel.cc
+++ b/paddle/pten/kernels/reshape_kernel.cc
@@ -41,16 +41,6 @@ void ReshapeKernel(const Context& dev_ctx,
   out->ResetLoD(x.lod());
 }
 
-template <typename Context>
-void ReshapeWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       const ScalarArray& shape,
-                       DenseTensor* xshape,
-                       DenseTensor* out) {
-  funcs::SetXShape(x, xshape);
-  ReshapeKernel(dev_ctx, x, shape, out);
-}
-
 }  // namespace pten
 
 PT_REGISTER_GENERAL_KERNEL(reshape,
@@ -58,11 +48,6 @@ PT_REGISTER_GENERAL_KERNEL(reshape,
                            ALL_LAYOUT,
                            pten::ReshapeKernel<pten::CPUContext>,
                            ALL_DTYPE) {}
-PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
-                           CPU,
-                           ALL_LAYOUT,
-                           pten::ReshapeWithXShape<pten::CPUContext>,
-                           ALL_DTYPE) {}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PT_REGISTER_GENERAL_KERNEL(reshape,
@@ -70,11 +55,6 @@ PT_REGISTER_GENERAL_KERNEL(reshape,
                            ALL_LAYOUT,
                            pten::ReshapeKernel<pten::GPUContext>,
                            ALL_DTYPE) {}
-PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
-                           GPU,
-                           ALL_LAYOUT,
-                           pten::ReshapeWithXShape<pten::GPUContext>,
-                           ALL_DTYPE) {}
 #endif
 
 #ifdef PADDLE_WITH_XPU
@@ -83,9 +63,4 @@ PT_REGISTER_GENERAL_KERNEL(reshape,
                            ALL_LAYOUT,
                            pten::ReshapeKernel<pten::XPUContext>,
                            ALL_DTYPE) {}
-PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
-                           XPU,
-                           ALL_LAYOUT,
-                           pten::ReshapeWithXShape<pten::XPUContext>,
-                           ALL_DTYPE) {}
 #endif
diff --git a/paddle/pten/kernels/reshape_kernel.h b/paddle/pten/kernels/reshape_kernel.h
index a5672ad6e5b04..75f3e6595472b 100644
--- a/paddle/pten/kernels/reshape_kernel.h
+++ b/paddle/pten/kernels/reshape_kernel.h
@@ -27,13 +27,6 @@ void ReshapeKernel(const Context& dev_ctx,
                    const ScalarArray& shape,
                    DenseTensor* out);
 
-template <typename Context>
-void ReshapeWithXShape(const Context& dev_ctx,
-                       const DenseTensor& x,
-                       const ScalarArray& shape,
-                       DenseTensor* xshape,
-                       DenseTensor* out);
-
 template <typename T, typename Context>
 DenseTensor Reshape(const Context& dev_ctx,
                     const DenseTensor& x,
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 66411d00f1517..27db0b65e1a11 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -145,6 +145,7 @@
   output : Tensor
   infer_meta :
     func : ReshapeInferMeta
+    param : [x, shape]
   kernel :
     func : reshape
 

From d099835da504a7c2f8aee65c4ebb3b4d836dbdfe Mon Sep 17 00:00:00 2001
From: YuanRisheng <yuanrisheng@baidu.com>
Date: Thu, 17 Feb 2022 10:14:36 +0000
Subject: [PATCH 2/6] add xshape

---
 paddle/fluid/framework/infershape_utils.cc    |  5 ++--
 .../fluid/framework/infershape_utils_test.cc  | 17 +++++++++++++
 paddle/fluid/operators/reshape_op.cc          | 14 ++++++-----
 paddle/pten/core/kernel_utils.h               |  4 ++++
 paddle/pten/infermeta/unary.cc                | 20 ++++++++++++++++
 paddle/pten/infermeta/unary.h                 |  6 +++++
 paddle/pten/kernels/reshape_kernel.cc         | 24 +++++++++++++++++++
 paddle/pten/kernels/reshape_kernel.h          |  7 ++++++
 paddle/pten/ops/compat/reshape_sig.cc         | 23 ++++++++++++++----
 9 files changed, 107 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index 0284169b4429e..601b425b7f474 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -322,6 +322,7 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
       infer_meta_context.EmplaceBackOutput({nullptr});
     }
   }
+  auto attr_reader = ctx->Attrs();
   for (size_t i = 0; i < attr_names.size(); ++i) {
     auto attr_name = attr_names[i];
     // When attr is a vector_tensor or tensor, transform it to ScalarArray
@@ -346,9 +347,9 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                 << ") ScalarArray value will be set empty";
         infer_meta_context.EmplaceBackAttr(std::move(pten::ScalarArray()));
       }
-    } else {
+    } else if (ctx->HasAttr(attr_name)) {
       // Emplace Back Attr according to the type of attr.
-      auto& attr = ctx->Attrs().GetAttr(attr_name);
+      auto& attr = attr_reader.GetAttr(attr_name);
       if (attr_defs[i].type_index ==
           std::type_index(typeid(pten::ScalarArray))) {
         if (std::type_index(attr.type()) ==
diff --git a/paddle/fluid/framework/infershape_utils_test.cc b/paddle/fluid/framework/infershape_utils_test.cc
index 755ca3f5ce90b..fc816f6f4a1e1 100644
--- a/paddle/fluid/framework/infershape_utils_test.cc
+++ b/paddle/fluid/framework/infershape_utils_test.cc
@@ -23,8 +23,11 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/core/compat/op_utils.h"
+#include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/infermeta_utils.h"
+#include "paddle/pten/core/kernel_registry.h"
 
 namespace paddle {
 namespace framework {
@@ -93,6 +96,17 @@ pten::KernelSignature InferShapeUtilsTestOpArgumentMapping(
       {});
 }
 
+template <typename T, typename Context>
+void InferShapeUtilsTestKernel(
+    const Context& dev_ctx, const pten::DenseTensor& x, bool attr1, int attr2,
+    int64_t attr3, float attr4, const std::string& attr5,
+    const std::vector<bool>& attr6, const std::vector<int>& attr7,
+    const std::vector<int64_t>& attr8, const std::vector<float>& attr9,
+    const std::vector<double>& attr10, const std::vector<std::string>& attr11,
+    pten::DenseTensor* out) {
+  VLOG(6) << "Come into InferShapeUtilsTestKernel";
+}
+
 }  // namespace framework
 }  // namespace paddle
 
@@ -104,6 +118,9 @@ REGISTER_OPERATOR(infer_shape_utils_test,
                   paddle::framework::InferShapeUtilsTestOpMaker,
                   InferShapeUtilsTestInferShapeFunctor);
 
+PT_REGISTER_KERNEL(infer_shape_utils_test, CPU, ALL_LAYOUT,
+                   paddle::framework::InferShapeUtilsTestKernel, int) {}
+
 TEST(InferShapeUtilsTest, ALL) {
   paddle::framework::ProgramDesc prog;
   paddle::framework::proto::BlockDesc proto_block;
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index 4a4b1c14fd2f3..3c3d00c9f1c91 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -507,7 +507,7 @@ class Reshape2GradMaker : public framework::SingleGradOpMaker<T> {
 
   void Apply(GradOpPtr<T> grad_op) const override {
     grad_op->SetType("reshape2_grad");
-    grad_op->SetInput("X", this->Input("X"));
+    grad_op->SetInput("XShape", this->Output("XShape"));
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
@@ -538,13 +538,15 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE_EQ(
-        ctx->HasInput("X"), true,
-        platform::errors::InvalidArgument("Input(X) shouldn't be null."));
+        ctx->HasInput("XShape"), true,
+        platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
     PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
                       platform::errors::InvalidArgument(
                           "Input(Out@GRAD) shouldn't be null."));
-    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
-    ctx->ShareLoD("X", framework::GradVarName("X"));
+    auto xshape_dims = ctx->GetInputDim("XShape");
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->ShareLoD("XShape", framework::GradVarName("X"));
   }
 
  protected:
@@ -633,7 +635,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
                                ops::ReshapeGradKernel);
 
 DELCARE_INFER_SHAPE_FUNCTOR(reshape2, ReshapeInferShapeFunctor,
-                            PT_INFER_META(pten::ReshapeInferMeta));
+                            PT_INFER_META(pten::ReshapeWithXShapeInferMeta));
 
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker<paddle::framework::OpDesc>,
diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h
index 7c611d7eccd11..8049c6f2635da 100644
--- a/paddle/pten/core/kernel_utils.h
+++ b/paddle/pten/core/kernel_utils.h
@@ -239,6 +239,10 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
   PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::string&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
+  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<std::string>&);
 
   /* Output Helpers */
 
diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc
index b2f29b7eb6a6f..5e1ce67572e27 100644
--- a/paddle/pten/infermeta/unary.cc
+++ b/paddle/pten/infermeta/unary.cc
@@ -247,6 +247,26 @@ void ReshapeInferMeta(const MetaTensor& x,
   InferMetaFromVecValue(x, shape_data, out);
 }
 
+void ReshapeWithXShapeInferMeta(const MetaTensor& x,
+                                const ScalarArray& shape,
+                                MetaTensor* xshape,
+                                MetaTensor* out,
+                                MetaConfig config) {
+  PADDLE_ENFORCE_NOT_NULL(
+      xshape,
+      pten::errors::InvalidArgument(
+          "Output(XShape) of ReshapeOp should not be null."));
+  const auto& x_dims = x.dims();
+  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
+  xshape_dims[0] = 0;
+  for (int i = 0; i < x_dims.size(); ++i) {
+    xshape_dims[i + 1] = x_dims[i];
+  }
+  xshape->set_dims(pten::framework::make_ddim(xshape_dims));
+  xshape->share_lod(x);
+  ReshapeInferMeta(x, shape, out, config);
+}
+
 /*  Why not use ReduceInferMeta directly?
     Because we need make InferMetaFunction's args follow the design of api.yaml
 */
diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h
index 19841c2c6c938..b5dca3203f484 100644
--- a/paddle/pten/infermeta/unary.h
+++ b/paddle/pten/infermeta/unary.h
@@ -55,6 +55,12 @@ void ReshapeInferMeta(const MetaTensor& x,
                       MetaTensor* out,
                       MetaConfig config = MetaConfig());
 
+void ReshapeWithXShapeInferMeta(const MetaTensor& x,
+                                const ScalarArray& shape,
+                                MetaTensor* xshape,
+                                MetaTensor* out,
+                                MetaConfig config = MetaConfig());
+
 void ReduceInferMetaBase(const MetaTensor& x,
                          const std::vector<int64_t>& axis,
                          bool keep_dim,
diff --git a/paddle/pten/kernels/reshape_kernel.cc b/paddle/pten/kernels/reshape_kernel.cc
index f25d1d63388c9..d9501d168b21c 100644
--- a/paddle/pten/kernels/reshape_kernel.cc
+++ b/paddle/pten/kernels/reshape_kernel.cc
@@ -41,6 +41,15 @@ void ReshapeKernel(const Context& dev_ctx,
   out->ResetLoD(x.lod());
 }
 
+template <typename Context>
+void ReshapeWithXShape(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out) {
+  ReshapeKernel(dev_ctx, x, shape, out);
+}
+
 }  // namespace pten
 
 PT_REGISTER_GENERAL_KERNEL(reshape,
@@ -48,6 +57,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape,
                            ALL_LAYOUT,
                            pten::ReshapeKernel<pten::CPUContext>,
                            ALL_DTYPE) {}
+PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
+                           CPU,
+                           ALL_LAYOUT,
+                           pten::ReshapeWithXShape<pten::CPUContext>,
+                           ALL_DTYPE) {}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PT_REGISTER_GENERAL_KERNEL(reshape,
@@ -55,6 +69,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape,
                            ALL_LAYOUT,
                            pten::ReshapeKernel<pten::GPUContext>,
                            ALL_DTYPE) {}
+PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
+                           GPU,
+                           ALL_LAYOUT,
+                           pten::ReshapeWithXShape<pten::GPUContext>,
+                           ALL_DTYPE) {}
 #endif
 
 #ifdef PADDLE_WITH_XPU
@@ -63,4 +82,9 @@ PT_REGISTER_GENERAL_KERNEL(reshape,
                            ALL_LAYOUT,
                            pten::ReshapeKernel<pten::XPUContext>,
                            ALL_DTYPE) {}
+PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
+                           XPU,
+                           ALL_LAYOUT,
+                           pten::ReshapeWithXShape<pten::XPUContext>,
+                           ALL_DTYPE) {}
 #endif
diff --git a/paddle/pten/kernels/reshape_kernel.h b/paddle/pten/kernels/reshape_kernel.h
index 75f3e6595472b..a5672ad6e5b04 100644
--- a/paddle/pten/kernels/reshape_kernel.h
+++ b/paddle/pten/kernels/reshape_kernel.h
@@ -27,6 +27,13 @@ void ReshapeKernel(const Context& dev_ctx,
                    const ScalarArray& shape,
                    DenseTensor* out);
 
+template <typename Context>
+void ReshapeWithXShape(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out);
+
 template <typename T, typename Context>
 DenseTensor Reshape(const Context& dev_ctx,
                     const DenseTensor& x,
diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc
index 823fb5d3cdd41..0316a73e55fa0 100644
--- a/paddle/pten/ops/compat/reshape_sig.cc
+++ b/paddle/pten/ops/compat/reshape_sig.cc
@@ -17,12 +17,25 @@ limitations under the License. */
 namespace pten {
 
 KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  if (ctx.InputSize("ShapeTensor") > 0) {
-    return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
-  } else if (ctx.HasInput("Shape")) {
-    return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
+  if (ctx.HasOutput("XShape")) {
+    if (ctx.InputSize("ShapeTensor") > 0) {
+      return KernelSignature(
+          "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"XShape", "Out"});
+    } else if (ctx.HasInput("Shape")) {
+      return KernelSignature(
+          "reshape_with_xshape", {"X"}, {"Shape"}, {"XShape", "Out"});
+    } else {
+      return KernelSignature(
+          "reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
+    }
   } else {
-    return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+    if (ctx.InputSize("ShapeTensor") > 0) {
+      return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
+    } else if (ctx.HasInput("Shape")) {
+      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
+    } else {
+      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+    }
   }
 }
 

From fcc29905b091b88da9e24a59fd701e3a219df3d7 Mon Sep 17 00:00:00 2001
From: YuanRisheng <yuanrisheng@baidu.com>
Date: Thu, 17 Feb 2022 13:27:34 +0000
Subject: [PATCH 3/6] fix bugs when run ci

---
 paddle/fluid/framework/infershape_utils.cc | 56 +++++++++++-----------
 1 file changed, 29 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index 601b425b7f474..badaeccaf2daa 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -325,33 +325,31 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   auto attr_reader = ctx->Attrs();
   for (size_t i = 0; i < attr_names.size(); ++i) {
     auto attr_name = attr_names[i];
-    // When attr is a vector_tensor or tensor, transform it to ScalarArray
-    if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
-      const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
-      if (ctx->IsRuntime()) {
-        std::vector<Variable*> vars;
-        vars.reserve(infershape_inputs.size());
-        for (size_t i = 0; i < infershape_inputs.size(); i++) {
-          vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
-        }
-
-        if (infershape_inputs.size() != 1) {
-          infer_meta_context.EmplaceBackAttr(
-              std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
+    if (attr_defs[i].type_index == std::type_index(typeid(pten::ScalarArray))) {
+      // When attr is a vector_tensor or tensor, transform it to ScalarArray
+      if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
+        const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
+        if (ctx->IsRuntime()) {
+          std::vector<Variable*> vars;
+          vars.reserve(infershape_inputs.size());
+          for (size_t i = 0; i < infershape_inputs.size(); i++) {
+            vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
+          }
+
+          if (infershape_inputs.size() != 1) {
+            infer_meta_context.EmplaceBackAttr(
+                std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
+          } else {
+            infer_meta_context.EmplaceBackAttr(
+                std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
+          }
         } else {
-          infer_meta_context.EmplaceBackAttr(
-              std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
+          VLOG(6) << "Not in Runtime, the Attr( " << attr_name
+                  << ") ScalarArray value will be set empty";
+          infer_meta_context.EmplaceBackAttr(std::move(pten::ScalarArray()));
         }
-      } else {
-        VLOG(6) << "Not in Runtime, the Attr( " << attr_name
-                << ") ScalarArray value will be set empty";
-        infer_meta_context.EmplaceBackAttr(std::move(pten::ScalarArray()));
-      }
-    } else if (ctx->HasAttr(attr_name)) {
-      // Emplace Back Attr according to the type of attr.
-      auto& attr = attr_reader.GetAttr(attr_name);
-      if (attr_defs[i].type_index ==
-          std::type_index(typeid(pten::ScalarArray))) {
+      } else if (ctx->HasAttr(attr_name)) {
+        auto& attr = attr_reader.GetAttr(attr_name);
         if (std::type_index(attr.type()) ==
             std::type_index(typeid(std::vector<int32_t>))) {
           infer_meta_context.EmplaceBackAttr(std::move(
@@ -362,8 +360,12 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
               "construct KernelContext.",
               attr_name));
         }
-      } else if (std::type_index(attr.type()) ==
-                 std::type_index(typeid(bool))) {
+      }
+
+    } else if (ctx->HasAttr(attr_name)) {
+      // Emplace Back Attr according to the type of attr.
+      auto& attr = attr_reader.GetAttr(attr_name);
+      if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
         infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else if (std::type_index(attr.type()) == std::type_index(typeid(int))) {
         infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));

From e593a1596de311b1a47b05bb7ebf55fb00b832cf Mon Sep 17 00:00:00 2001
From: YuanRisheng <yuanrisheng@baidu.com>
Date: Fri, 18 Feb 2022 07:34:23 +0000
Subject: [PATCH 4/6] fix bugs when run ci

---
 paddle/fluid/framework/infershape_utils.cc | 22 ++++++++++++++++++----
 paddle/pten/api/lib/utils/tensor_utils.cc  |  2 +-
 paddle/pten/common/scalar.h                |  6 +++---
 paddle/pten/common/scalar_array.h          | 10 +++++-----
 paddle/pten/infermeta/unary.cc             |  3 ++-
 paddle/pten/kernels/cpu/split_kernel.cc    |  2 +-
 paddle/pten/kernels/gpu/split_kernel.cu    |  2 +-
 7 files changed, 31 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index badaeccaf2daa..89b25bfea28a9 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -330,12 +330,13 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
       if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
         const auto& infershape_inputs = ctx->GetInputVarPtrs(attr_name);
         if (ctx->IsRuntime()) {
+          // If is in runtime, we will get tensor's value for ScalarArray
+          // and push it into attrs
           std::vector<Variable*> vars;
           vars.reserve(infershape_inputs.size());
           for (size_t i = 0; i < infershape_inputs.size(); i++) {
             vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
           }
-
           if (infershape_inputs.size() != 1) {
             infer_meta_context.EmplaceBackAttr(
                 std::move(experimental::MakePtenScalarArrayFromVarList(vars)));
@@ -344,9 +345,22 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                 std::move(experimental::MakePtenScalarArrayFromVar(*vars[0])));
           }
         } else {
-          VLOG(6) << "Not in Runtime, the Attr( " << attr_name
-                  << ") ScalarArray value will be set empty";
-          infer_meta_context.EmplaceBackAttr(std::move(pten::ScalarArray()));
+          // If is not in runtime, we will set default value(-1) for ScalarArray
+          int64_t num_ele = 1;
+          std::vector<VarDesc*> vars;
+          vars.reserve(infershape_inputs.size());
+          for (size_t i = 0; i < infershape_inputs.size(); i++) {
+            vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
+          }
+          for (auto& var : vars) {
+            const auto& tensor_dims = var->GetShape();
+            for (size_t i = 0; i < tensor_dims.size(); ++i) {
+              num_ele *= tensor_dims[i];
+            }
+          }
+          pten::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
+          tensor_attr.SetInitFromTensor(true);
+          infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
         }
       } else if (ctx->HasAttr(attr_name)) {
         auto& attr = attr_reader.GetAttr(attr_name);
diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc
index ea0e9d8cd3e0a..1f16773d3ef87 100644
--- a/paddle/pten/api/lib/utils/tensor_utils.cc
+++ b/paddle/pten/api/lib/utils/tensor_utils.cc
@@ -131,7 +131,7 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
   }
 
   pten::ScalarArray result{vector_data};
-  result.setInitByTensor(true);
+  result.SetInitFromTensor(true);
 
   return result;
 }
diff --git a/paddle/pten/common/scalar.h b/paddle/pten/common/scalar.h
index 0ab880d6218f8..1331c3365db48 100644
--- a/paddle/pten/common/scalar.h
+++ b/paddle/pten/common/scalar.h
@@ -25,7 +25,7 @@ namespace experimental {
 template <typename T>
 class ScalarBase {
  public:
-  bool IsInitByTensor() const { return is_init_by_tensor_; }
+  bool FromTensor() const { return is_from_tensor_; }
   // Constructor support implicit
   ScalarBase(double val) : dtype_(DataType::FLOAT64) {  // NOLINT
     data_.f64 = val;
@@ -104,7 +104,7 @@ class ScalarBase {
 
   // The Tensor must have one dim
   ScalarBase(const T& tensor) : dtype_(tensor.dtype()) {  // NOLINT
-    is_init_by_tensor_ = true;
+    is_from_tensor_ = true;
     PD_CHECK(
         tensor.numel() == 1,
         "The Scalar only supports Tensor with 1 element, but now Tensor has `",
@@ -196,7 +196,7 @@ class ScalarBase {
   friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
 
  private:
-  bool is_init_by_tensor_{false};
+  bool is_from_tensor_{false};
   DataType dtype_;
   union data {
     bool b;
diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h
index dcc8ff6748b86..87fd7342c3624 100644
--- a/paddle/pten/common/scalar_array.h
+++ b/paddle/pten/common/scalar_array.h
@@ -43,13 +43,13 @@ class ScalarArrayBase {
     AssignData(date_value, n);
   }
 
-  bool IsInitByTensor() const { return is_init_by_tensor_; }
+  bool FromTensor() const { return is_from_tensor_; }
 
-  void setInitByTensor(bool val) { is_init_by_tensor_ = val; }
+  void SetInitFromTensor(bool val) { is_from_tensor_ = val; }
 
   // The Tensor must have one dim
   ScalarArrayBase(const T& tensor) {  // NOLINT
-    is_init_by_tensor_ = true;
+    is_from_tensor_ = true;
     size_t n = tensor.numel();
     array_.reserve(n);
     switch (tensor.dtype()) {
@@ -71,7 +71,7 @@ class ScalarArrayBase {
 
   // The Tensor in vec must have only one element
   ScalarArrayBase(const std::vector<T>& tensor_list) {  // NOLINT
-    is_init_by_tensor_ = true;
+    is_from_tensor_ = true;
 
     for (size_t i = 0; i < tensor_list.size(); ++i) {
       DataType data_type = tensor_list[i].dtype();
@@ -117,7 +117,7 @@ class ScalarArrayBase {
   // TODO(zhangyunfei) Replace std::vector with a more efficient container
   // structure.
   std::vector<int64_t> array_;
-  bool is_init_by_tensor_{false};
+  bool is_from_tensor_{false};
 };
 
 using ScalarArray =
diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc
index 5e1ce67572e27..f12ce51f95ab5 100644
--- a/paddle/pten/infermeta/unary.cc
+++ b/paddle/pten/infermeta/unary.cc
@@ -236,7 +236,8 @@ void ReshapeInferMeta(const MetaTensor& x,
   PADDLE_ENFORCE_NOT_NULL(out,
                           pten::errors::InvalidArgument(
                               "Output(Out) of ReshapeOp should not be null."));
-  if (!config.is_runtime && shape_data.size() == 0) {
+  if (!config.is_runtime && shape.FromTensor()) {
+    out->set_dims(pten::framework::make_ddim(shape_data));
     out->share_lod(x);
     return;
   }
diff --git a/paddle/pten/kernels/cpu/split_kernel.cc b/paddle/pten/kernels/cpu/split_kernel.cc
index 78fcdcb155cf2..f3d29247d5421 100644
--- a/paddle/pten/kernels/cpu/split_kernel.cc
+++ b/paddle/pten/kernels/cpu/split_kernel.cc
@@ -29,7 +29,7 @@ void SplitKernel(const Context& dev_ctx,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
   // need to infershape output
-  if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
     std::vector<MetaTensor> out_metas;
     for (size_t i = 0; i < outs.size(); ++i) {
       out_metas.push_back(outs[i]);
diff --git a/paddle/pten/kernels/gpu/split_kernel.cu b/paddle/pten/kernels/gpu/split_kernel.cu
index 46d18b75b611b..052c5d3f58316 100644
--- a/paddle/pten/kernels/gpu/split_kernel.cu
+++ b/paddle/pten/kernels/gpu/split_kernel.cu
@@ -28,7 +28,7 @@ void SplitKernel(const Context& dev_ctx,
                  const Scalar& axis_scalar,
                  std::vector<DenseTensor*> outs) {
   // need to infershape output
-  if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
+  if (num_or_sections.FromTensor() || axis_scalar.FromTensor()) {
     std::vector<MetaTensor> out_metas;
     for (size_t i = 0; i < outs.size(); ++i) {
       out_metas.push_back(outs[i]);

From cdd4b679ebbd314450a48175992ac2be64d4d8d0 Mon Sep 17 00:00:00 2001
From: YuanRisheng <yuanrisheng@baidu.com>
Date: Sat, 19 Feb 2022 05:27:56 +0000
Subject: [PATCH 5/6] fix bugs when run infrt test

---
 paddle/fluid/framework/infershape_utils.cc               | 9 +++++----
 paddle/pten/api/lib/utils/tensor_utils.cc                | 2 +-
 paddle/pten/common/scalar_array.h                        | 2 +-
 .../tests/unittests/ir/inference/test_trt_reshape_op.py  | 8 ++------
 python/paddle/utils/code_gen/api.yaml                    | 1 -
 5 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc
index 89b25bfea28a9..013438f57e835 100644
--- a/paddle/fluid/framework/infershape_utils.cc
+++ b/paddle/fluid/framework/infershape_utils.cc
@@ -294,14 +294,15 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
   auto& attr_names = std::get<1>(signature.args);
   auto& output_names = std::get<2>(signature.args);
 
-  auto kernels = pten::KernelFactory::Instance().kernels().find(signature.name);
-  if (kernels == pten::KernelFactory::Instance().kernels().end()) {
+  auto kernels_map =
+      pten::KernelFactory::Instance().SelectKernelMap(signature.name);
+  if (kernels_map.size() == 0) {
     PADDLE_THROW(
         platform::errors::Unimplemented("Not find `%s` kernels when construct "
                                         "InferMetaContext.",
                                         signature.name));
   }
-  auto attr_defs = kernels->second.cbegin()->second.args_def().attribute_defs();
+  auto attr_defs = kernels_map.cbegin()->second.args_def().attribute_defs();
 
   // TODO(chenweihang): support multiple inputs and outputs later
   pten::InferMetaContext infer_mete_context;
@@ -359,7 +360,7 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
             }
           }
           pten::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
-          tensor_attr.SetInitFromTensor(true);
+          tensor_attr.SetFromTensor(true);
           infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
         }
       } else if (ctx->HasAttr(attr_name)) {
diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc
index 1f16773d3ef87..26dd06921f16a 100644
--- a/paddle/pten/api/lib/utils/tensor_utils.cc
+++ b/paddle/pten/api/lib/utils/tensor_utils.cc
@@ -131,7 +131,7 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
   }
 
   pten::ScalarArray result{vector_data};
-  result.SetInitFromTensor(true);
+  result.SetFromTensor(true);
 
   return result;
 }
diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h
index 87fd7342c3624..22b93448f8a46 100644
--- a/paddle/pten/common/scalar_array.h
+++ b/paddle/pten/common/scalar_array.h
@@ -45,7 +45,7 @@ class ScalarArrayBase {
 
   bool FromTensor() const { return is_from_tensor_; }
 
-  void SetInitFromTensor(bool val) { is_from_tensor_ = val; }
+  void SetFromTensor(bool val) { is_from_tensor_ = val; }
 
   // The Tensor must have one dim
   ScalarArrayBase(const T& tensor) {  // NOLINT
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py
index b23c7d9b493d0..0522df3a9219d 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape_op.py
@@ -91,14 +91,10 @@ def setUp(self):
         with fluid.program_guard(self.main_program, self.startup_program):
             data = fluid.data(
                 name='data', shape=self.data_shape, dtype='float32')
-            actual_reshape = fluid.data(
-                name='actual_reshape', shape=[4], dtype='int32')
-            reshape_out = fluid.layers.reshape(
-                x=data, shape=self.reshape, actual_shape=actual_reshape)
+            reshape_out = fluid.layers.reshape(x=data, shape=self.reshape)
             out = fluid.layers.batch_norm(reshape_out, is_test=True)
         self.feeds = {
-            'data': np.random.random(self.data_shape).astype('float32'),
-            'actual_reshape': np.array([2, 0, -1, 6]).astype('int32')
+            'data': np.random.random(self.data_shape).astype('float32')
         }
         self.enable_trt = True
         self.trt_parameters = TRTReshapeTest.TensorRTParam(
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index e07271ad9c81f..60e64c028430c 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -145,7 +145,6 @@
   output : Tensor(out)
   infer_meta :
     func : ReshapeInferMeta
-    param : [x, shape]
   kernel :
     func : reshape
   inplace : (x -> out)

From b731e5651cbacfe4c50c88d88aaee83ab07e911d Mon Sep 17 00:00:00 2001
From: YuanRisheng <yuanrisheng@baidu.com>
Date: Sun, 20 Feb 2022 04:05:59 +0000
Subject: [PATCH 6/6] pass converage

---
 paddle/pten/ops/compat/reshape_sig.cc | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/paddle/pten/ops/compat/reshape_sig.cc b/paddle/pten/ops/compat/reshape_sig.cc
index 0316a73e55fa0..e03338440f214 100644
--- a/paddle/pten/ops/compat/reshape_sig.cc
+++ b/paddle/pten/ops/compat/reshape_sig.cc
@@ -28,15 +28,8 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) {
       return KernelSignature(
           "reshape_with_xshape", {"X"}, {"shape"}, {"XShape", "Out"});
     }
-  } else {
-    if (ctx.InputSize("ShapeTensor") > 0) {
-      return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"});
-    } else if (ctx.HasInput("Shape")) {
-      return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"});
-    } else {
-      return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
-    }
   }
+  return KernelSignature("unregistered", {}, {}, {});
 }
 
 KernelSignature ReshapeGradOpArgumentMapping(