Skip to content

Commit

Permalink
rm opdesc.hasinput (PaddlePaddle#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
gglin001 authored Jan 11, 2022
1 parent 504e02d commit d5e5d5c
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 16 deletions.
4 changes: 0 additions & 4 deletions paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,6 @@ const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
return it->second;
}

const bool OpDesc::HasInput(const std::string &name) const {
  // True iff this op has an input slot registered under `name`; the slot may
  // still map to an empty argument-name list, which callers check separately
  // via Input(name).empty().
  // NOTE(review): the `const` on the by-value bool return is meaningless
  // (-Wignored-qualifiers); it is kept because the declaration in op_desc.h
  // spells it the same way and both must change together.
  return inputs_.count(name) != 0;
}

std::vector<std::string> OpDesc::InputArgumentNames() const {
std::vector<std::string> retv;
for (auto &ipt : this->inputs_) {
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/framework/op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ class OpDesc {

const std::vector<std::string> &Input(const std::string &name) const;

const bool HasInput(const std::string &name) const;

std::vector<std::string> InputArgumentNames() const;

void SetInput(const std::string &param_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Node *mean_handler(Graph *graph, Node *node) {

Node *pow_handler(Graph *graph, Node *node) {
auto *op = node->Op();
if (op->HasInput("FactorTensor") && !op->Input("FactorTensor").empty()) {
if (!op->Input("FactorTensor").empty()) {
return CreateBaseOp(
graph, node, "popart_pow",
{GetInputVarNode("X", node), GetInputVarNode("FactorTensor", node)},
Expand Down Expand Up @@ -161,7 +161,7 @@ Node *scale_handler(Graph *graph, Node *node) {
static_cast<int>(framework::proto::VarType::FP32));

Node *result = nullptr;
if (op->HasInput("ScaleTensor") && !op->Input("ScaleTensor").empty()) {
if (!op->Input("ScaleTensor").empty()) {
auto scale = GetInputVarNode("ScaleTensor", node);
if (is_float_equal(bias_, 0.0)) {
result = CreateBaseOp(graph, node, "popart_mul",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Node *conv2d_handler(Graph *graph, Node *node) {
auto pads = std::vector<int64_t>{pads_.begin(), pads_.end()};
auto stride_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
auto stride = std::vector<int64_t>{stride_.begin(), stride_.end()};
if (op->HasInput("Bias") && !op->Input("Bias").empty()) {
if (!op->Input("Bias").empty()) {
return CreateConv(
graph, node,
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Node *topk_handler(Graph *graph, Node *node) {

Node *var_x = GetInputVarNode("X", node);
Node *var_k = nullptr;
if (op->HasInput("K") && !op->Input("K").empty()) {
if (!op->Input("K").empty()) {
var_k = GetInputVarNode("K", node);
} else {
auto k = BOOST_GET_CONST(int, op->GetAttr("k"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace {

Node *fill_constant_handler(Graph *graph, Node *node) {
auto *op = node->Op();
if (op->HasInput("ShapeTensor") && !op->Input("ShapeTensor").empty()) {
if (!op->Input("ShapeTensor").empty()) {
PADDLE_THROW(
platform::errors::Unimplemented("op fill_constant with ShapeTensor"));
}
Expand Down Expand Up @@ -328,7 +328,7 @@ Node *shape_handler(Graph *graph, Node *node) {
Node *slice_handler(Graph *graph, Node *node) {
auto *op = node->Op();
Node *starts = nullptr;
if (op->HasInput("StartsTensor") && !op->Input("StartsTensor").empty()) {
if (!op->Input("StartsTensor").empty()) {
starts = GetInputVarNode("StartsTensor", node);
} else {
auto starts_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("starts"));
Expand All @@ -338,7 +338,7 @@ Node *slice_handler(Graph *graph, Node *node) {
starts = starts->outputs[0];
}
Node *ends = nullptr;
if (op->HasInput("EndsTensor") && !op->Input("EndsTensor").empty()) {
if (!op->Input("EndsTensor").empty()) {
ends = GetInputVarNode("EndsTensor", node);
} else {
auto ends_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("ends"));
Expand Down Expand Up @@ -384,14 +384,13 @@ Node *slice_handler(Graph *graph, Node *node) {

Node *expand_handler(Graph *graph, Node *node) {
auto *op = node->Op();
if (op->HasInput("expand_times_tensor") &&
!op->Input("expand_times_tensor").empty()) {
if (!op->Input("expand_times_tensor").empty()) {
PADDLE_THROW(
platform::errors::Unimplemented("Expand op with expand_times_tensor"));
}

Node *expand_times = nullptr;
if (op->HasInput("ExpandTimes") && !op->Input("ExpandTimes").empty()) {
if (!op->Input("ExpandTimes").empty()) {
// cast to int64
expand_times =
CreateCast(graph, node, {GetInputVarNode("ExpandTimes", node)}, {},
Expand Down

0 comments on commit d5e5d5c

Please sign in to comment.