Skip to content

Commit

Permalink
recover infershape in fill_and_like
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Feb 22, 2022
1 parent 14a51fc commit 9250e87
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 44 deletions.
36 changes: 0 additions & 36 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,42 +377,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
}

} else if (attr_defs[i].type_index ==
std::type_index(typeid(phi::Scalar))) {
if (ctx->HasInput(attr_name)) {
const auto& scalar_var = ctx->GetInputVarPtrs(attr_name);
if (ctx->IsRuntime()) {
// If is in runtime, we will get tensor's value for Scala
// and push it into attrs
Variable* var = BOOST_GET_CONST(Variable*, scalar_var[0]);
infer_meta_context.EmplaceBackAttr(
experimental::MakePtenScalarFromVar(*var));
} else {
// If is not in runtime, we will set default value(-1) for Scalar
phi::Scalar scalar_attr(-1);
scalar_attr.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(scalar_attr);
}
} else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(float, attr)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(std::string, attr)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(int, attr)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"InferMetaContext.",
attr_name));
}
}
} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/operators/fill_any_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ limitations under the License. */

#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/generated.h"

namespace paddle {
namespace operators {
Expand All @@ -26,6 +23,13 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fill_any_like");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fill_any_like");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
Expand Down Expand Up @@ -83,12 +87,8 @@ class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
} // namespace paddle

namespace ops = paddle::operators;

DELCARE_INFER_SHAPE_FUNCTOR(fill_any_like, FillAnyLikeInferShapeFunctor,
PT_INFER_META(phi::Full_likeInferMeta));

REGISTER_OPERATOR(
fill_any_like, ops::FillAnyLikeOp, ops::FillAnyLikeOpMaker,
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
FillAnyLikeInferShapeFunctor, ops::FillAnyLikeVarTypeInference)
ops::FillAnyLikeVarTypeInference)

1 comment on commit 9250e87

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.