From 9250e87ad89bc4c4d40cd17b6860eca95fb164b7 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 22 Feb 2022 09:50:18 +0000 Subject: [PATCH] recover infershape in fill_and_like --- paddle/fluid/framework/infershape_utils.cc | 36 ---------------------- paddle/fluid/operators/fill_any_like_op.cc | 16 +++++----- 2 files changed, 8 insertions(+), 44 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 7712b5dccd3946..aae36cf455dfee 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -377,42 +377,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { - if (ctx->HasInput(attr_name)) { - const auto& scalar_var = ctx->GetInputVarPtrs(attr_name); - if (ctx->IsRuntime()) { - // If is in runtime, we will get tensor's value for Scala - // and push it into attrs - Variable* var = BOOST_GET_CONST(Variable*, scalar_var[0]); - infer_meta_context.EmplaceBackAttr( - experimental::MakePtenScalarFromVar(*var)); - } else { - // If is not in runtime, we will set default value(-1) for Scalar - phi::Scalar scalar_attr(-1); - scalar_attr.SetFromTensor(true); - infer_meta_context.EmplaceBackAttr(scalar_attr); - } - } else if (ctx->HasAttr(attr_name)) { - auto& attr = attr_reader.GetAttr(attr_name); - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { - infer_meta_context.EmplaceBackAttr( - phi::Scalar(BOOST_GET_CONST(float, attr))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { - infer_meta_context.EmplaceBackAttr( - phi::Scalar(BOOST_GET_CONST(std::string, attr))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int))) { - infer_meta_context.EmplaceBackAttr( - phi::Scalar(BOOST_GET_CONST(int, attr))); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported cast op attribute `%s` to Scalar when construct " - "InferMetaContext.", - attr_name)); - } - } } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc index 9f57e215048c34..e6de430a78c1a3 100644 --- a/paddle/fluid/operators/fill_any_like_op.cc +++ b/paddle/fluid/operators/fill_any_like_op.cc @@ -14,10 +14,7 @@ limitations under the License. */ #include -#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/generated.h" namespace paddle { namespace operators { @@ -26,6 +23,13 @@ class FillAnyLikeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fill_any_like"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fill_any_like"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } + protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -83,12 +87,8 @@ class FillAnyLikeVarTypeInference : public framework::VarTypeInference { } // namespace paddle namespace ops = paddle::operators; - -DELCARE_INFER_SHAPE_FUNCTOR(fill_any_like, FillAnyLikeInferShapeFunctor, - PT_INFER_META(phi::Full_likeInferMeta)); - REGISTER_OPERATOR( fill_any_like, ops::FillAnyLikeOp, ops::FillAnyLikeOpMaker, ::paddle::framework::EmptyGradOpMaker, ::paddle::framework::EmptyGradOpMaker, - FillAnyLikeInferShapeFunctor, ops::FillAnyLikeVarTypeInference) + ops::FillAnyLikeVarTypeInference)