diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ab69170322ce3..01536fd36ff83 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -50,6 +50,7 @@ if (WITH_TESTING) endif(WITH_TESTING) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS}) +cc_library(op_compat_sensible_pass SRCS op_compat_sensible_pass.cc DEPS graph_pattern_detector) cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) @@ -139,6 +140,7 @@ cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) +cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass) cc_test(test_fc_fuse_pass_cc SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_lstm_fuse_pass_cc SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto) cc_test(test_fc_gru_fuse_pass_cc SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto) diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc new file mode 100644 index 0000000000000..f7312ca555531 --- /dev/null +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -0,0 +1,178 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +AttrCompat& AttrCompat::IsStringIn(const std::set& candidates) { + conditions_.emplace_back([candidates](const Attribute& attr) -> bool { + std::string value = BOOST_GET_CONST(std::string, attr); + for (auto& str : candidates) { + if (str == value) { + return true; + } + } + return false; + }); + return *this; +} + +AttrCompat& AttrCompat::IsStringMatch( + const std::function& func) { + conditions_.emplace_back([func](const Attribute& attr) -> bool { + std::string value = BOOST_GET_CONST(std::string, attr); + return func(value); + }); + return *this; +} + +AttrCompat& AttrCompat::IsIntIn(const std::set& candidates) { + conditions_.emplace_back([candidates](const Attribute& attr) -> bool { + int value = BOOST_GET_CONST(int, attr); + return candidates.find(value) != candidates.end(); + }); + return *this; +} + +//! Todo: append the definition. +AttrCompat& AttrCompat::IsLeftDefault() { return *this; } + +bool AttrCompat::operator()(const OpDesc& op_desc) { + if (!op_desc.HasAttr(attr_name_)) { + return false; + } + const Attribute attr = op_desc.GetAttr(attr_name_); + for (auto& func : conditions_) { + if (!func(attr)) { + return false; + } + } + return true; +} + +AttrCompat& AttrCompat::IsBoolEQ(bool v) { + conditions_.emplace_back([v](const Attribute& attr) -> bool { + bool value = BOOST_GET_CONST(bool, attr); + return value == v; + }); + return *this; +} + +InputOrOutputCompat& InputOrOutputCompat::IsTensor() { + conditions_.emplace_back([](const std::vector& input) -> bool { + return input.size() == 1u; + }); + return *this; +} + +InputOrOutputCompat& InputOrOutputCompat::IsOptional() { + optional_ = true; + return *this; +} + +bool InputOrOutputCompat::operator()( + const std::vector& input) const { + if (input.empty()) return false; + for (auto& func : conditions_) { + if (!func(input)) { + return false; + } + } + return true; +} + +AttrCompat& OpCompat::AddAttr(const std::string& attr_name) { + attr_compats_.emplace_back(attr_name, this); + return attr_compats_.back(); +} + +InputOrOutputCompat& OpCompat::AddInput(const std::string& name) { + PADDLE_ENFORCE_EQ(input_compats_.find(name), input_compats_.end(), + platform::errors::InvalidArgument( + "The input with the same name has been added")); + input_compats_.emplace(name, InputOrOutputCompat(name, this)); + return input_compats_.at(name); +} + +InputOrOutputCompat& OpCompat::AddOutput(const std::string& name) { + PADDLE_ENFORCE_EQ(output_compats_.find(name), output_compats_.end(), + platform::errors::InvalidArgument( + "The output with the same name has been added")); + output_compats_.emplace(name, InputOrOutputCompat(name, this)); + return output_compats_.at(name); +} + +bool OpCompat::Judge(const OpDesc& op_desc) { + for (auto& attr_compat : attr_compats_) { + if (!attr_compat(op_desc)) { + return false; + } + } + + const VariableNameMap& inputs_map = op_desc.Inputs(); + for (auto& input_desc : inputs_map) { + if (input_compats_.find(input_desc.first) == input_compats_.end()) { + if (!input_desc.second.empty()) { + return false; + } + } + } + for (auto& input_val : input_compats_) { + if (inputs_map.find(input_val.first) == inputs_map.end()) { + if (!input_val.second.Optional()) { + return false; + } + } else { + if (!input_val.second(inputs_map.at(input_val.first))) { + return false; + } + } + } + + const VariableNameMap& outputs_map = op_desc.Outputs(); + for (auto& output_desc : outputs_map) { + if (output_compats_.find(output_desc.first) == output_compats_.end()) { + if (!output_desc.second.empty()) { + return false; + } + } + } + for (auto& output_val : output_compats_) { + if (outputs_map.find(output_val.first) == outputs_map.end()) { + if (!output_val.second.Optional()) { + return false; + } + } else { + if (!output_val.second(outputs_map.at(output_val.first))) { + return false; + } + } + } + return true; +} + +OpCompat& OpCompatSensiblePass::AddOpCompat(OpCompat&& op_compat) { + std::string name = op_compat.Name(); + op_compat_judgers_[name].reset(new OpCompat(std::move(op_compat))); + return *(op_compat_judgers_[name]); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.h b/paddle/fluid/framework/ir/op_compat_sensible_pass.h new file mode 100644 index 0000000000000..6c0860549fbfe --- /dev/null +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.h @@ -0,0 +1,294 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class OpCompat; + +class AttrCompat { + public: + AttrCompat(const std::string& attr_name, OpCompat* op_compat) + : attr_name_(attr_name), op_compat_(op_compat) {} + + // @{ String-related methods + //! Assert the attribute is an string in the `candidates` domain. + AttrCompat& IsStringIn(const std::set& candidates); + //! Assert the attribute is a string and match a custom judging function. + AttrCompat& IsStringMatch( + const std::function& func); + // @} + + //! Assert the attribute is an integer in the `candidates` domain. + AttrCompat& IsIntIn(const std::set& candidates); + + // @{ Number-releated methods + //! Assert the attribute is a number and > `v`. + template + AttrCompat& IsNumGT(T v); + //! Assert the attribute is a number and >= `v`. + template + AttrCompat& IsNumGE(T v); + //! Assert the attribute is a number and < `v`. + template + AttrCompat& IsNumLT(T v); + //! Assert the attribute is a number and <= `v`. + template + AttrCompat& IsNumLE(T v); + //! Assert the attribute is a number and == `v`. + template + AttrCompat& IsNumEQ(T v); + //! Assert the attribute is a number and matches a customized judging + //! function. + template + AttrCompat& IsNumMatch(bool (*func)(T)); + // @} + + //! Assert the attribute is a boolean value equals `v`. + AttrCompat& IsBoolEQ(bool v); + + //! Tell whether this attribute is left as default value. + AttrCompat& IsLeftDefault(); + + //! Jump back to retrieve OpCompat instance. + OpCompat& End() { return *op_compat_; } + + bool operator()(const OpDesc& op_desc); + + private: + std::string attr_name_; + OpCompat* op_compat_; + std::vector> conditions_; +}; + +class InputOrOutputCompat { + public: + InputOrOutputCompat(const std::string& name, OpCompat* op_compat) + : optional_(false), name_(name), op_compat_(op_compat) {} + + InputOrOutputCompat& IsTensor(); + InputOrOutputCompat& IsOptional(); + bool Optional() const { return optional_; } + bool operator()(const std::vector& input) const; + + //! Jump back to retrieve OpCompat instance. + OpCompat& End() { return *op_compat_; } + + private: + bool optional_; + std::string name_; + OpCompat* op_compat_; + std::vector&)>> conditions_; +}; + +/** + * OpCompat is a helper class to help define the compatible Op definition. + * + * Usage: + * OpCompat compat("FC"); + * compat.AddAttr("in_num_col_dims").IsNumLE(1).End() + * .AddAttr("activation_type").IsStringIn({"tanh", "sigmoid"}).End() + * .AddInput("Input").IsTensor().End() + * .AddInput("W").IsTensor().End() + * .AddInput("Bias").IsTensor().IsOptional().End() + * .AddOutput("Out").IsTensor().End() + * + * All the inference-aware Op defition is as above, all the other attributes not + * contained in the definition should be set default value or it would be judged + * incompatible. + */ +class OpCompat { + public: + explicit OpCompat(const std::string& op_name) : op_name_(op_name) {} + explicit OpCompat(std::string&& op_name) : op_name_(std::move(op_name)) {} + explicit OpCompat(const OpCompat&) = default; + explicit OpCompat(OpCompat&&) = default; + + AttrCompat& AddAttr(const std::string& attr_name); + InputOrOutputCompat& AddInput(const std::string& name); + InputOrOutputCompat& AddOutput(const std::string& name); + + //! Judge whether an OpDesc match the defined Op compatibility. + bool Judge(const OpDesc& op_desc); + const std::string& Name() const { return op_name_; } + + private: + std::string op_name_; + std::vector attr_compats_; + std::unordered_map input_compats_; + std::unordered_map output_compats_; +}; + +/** + * OpCompatSensiblePass is a base class for all the passes thouse is sensitive + * to Op update. + * There are two methods to help tell the compability of an Op + * bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, Graph* g); + * bool IsCompat(const OpDesc& op_desc); + * + * One can register the related Op compabilities using + * void AddOpCompat(OpCompat&& judger); + * + * Most of the Passes are used for fusing ops, so we define a method for such + * scenerios. + * void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph, + Graph* g); + * It will check the Op compatibility automatically. + * For other scenirios, one should call `IsCompat` by himself. + * + * A FC fuse pass example: + * class FcFusePass : public OpCompatSensiblePass { + * public: + * FcFusePass() { + * // define Mul op compatiblity. + * AddOpCompat(OpCompat("Mul")) + * .AddInput("Input").IsTensor().End() + * .AddAttr("in_num_col_dims").IsNumGE(1); + * AddOpCompat(OpCompat("Add")). ...; + * // There are multiple activation implemention. + * AddOpCompat(OpCompat("Tanh")). ...; + * AddOpCompat(OpCompat("Sigmoid")). ...; + * } + * + * // override the subgraph access method + * virtual bool AccessSubgraphImpl( + * const GraphPatternDetector::subgraph_t& subgraph, + * Graph* g) override { ... } + * + * // Call the AccessSubgraph method in main procedure of this Pass. + * }; + */ +class OpCompatSensiblePass : public Pass { + public: + //! Access the subgraph and pattern. + void AccessSubgraph(const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + if (IsCompat(subgraph, g)) { + AccessSubgraphImpl(subgraph, g); + } + } + + protected: + /** + * Developer should push the compatibility `teller` for each kind of Op in the + * subgraph. + * NOTE One should add all the related op compatiblity in the construct so + * that all the following methods are valid. + */ + OpCompat& AddOpCompat(OpCompat&& op_compat); + + //! Modify the subgraph. + virtual bool AccessSubgraphImpl( + const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const { + return true; + } + + //! Tell the Op compability of a subgraph. + bool IsCompat(const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) const { + CHECK(!op_compat_judgers_.empty()) + << "At least one OpCompat instance should be added in the " + "OpCompatSensiblePass."; + // Check the all the ops in the subgraph are contained in the + // op_compat. + for (auto& node_pair : subgraph) { + if (!node_pair.first->IsOp()) continue; + auto op_type = node_pair.second->Op()->Type(); + if (!op_compat_judgers_.count(op_type)) { + return false; + } + auto& judger = *op_compat_judgers_.at(op_type); + if (!judger.Judge(*(node_pair.second->Op()))) { + return false; + } + } + return true; + } + + //! Tell the op compatibility of a single Op. + bool IsCompat(const OpDesc& op_desc) const { + if (!op_compat_judgers_.count(op_desc.Type())) return false; + return op_compat_judgers_.at(op_desc.Type())->Judge(op_desc); + } + + private: + std::map> op_compat_judgers_; +}; + +template +AttrCompat& AttrCompat::IsNumGT(T v) { + conditions_.emplace_back([v](const Attribute& attr) -> bool { + T value = BOOST_GET_CONST(T, attr); + return value > v; + }); + return *this; +} + +template +AttrCompat& AttrCompat::IsNumGE(T v) { + conditions_.emplace_back([v](const Attribute& attr) -> bool { + T value = BOOST_GET_CONST(T, attr); + return value >= v; + }); + return *this; +} + +template +AttrCompat& AttrCompat::IsNumLT(T v) { + conditions_.emplace_back([v](const Attribute& attr) -> bool { + T value = BOOST_GET_CONST(T, attr); + return value < v; + }); + return *this; +} + +template +AttrCompat& AttrCompat::IsNumLE(T v) { + conditions_.emplace_back([v](const Attribute& attr) -> bool { + T value = BOOST_GET_CONST(T, attr); + return value <= v; + }); + return *this; +} + +template +AttrCompat& AttrCompat::IsNumEQ(T v) { + conditions_.emplace_back([v](const Attribute& attr) -> bool { + T value = BOOST_GET_CONST(T, attr); + return value == v; + }); + return *this; +} + +template +AttrCompat& AttrCompat::IsNumMatch(bool (*func)(T)) { + conditions_.emplace_back([func](const Attribute& attr) -> bool { + T value = BOOST_GET_CONST(T, attr); + return func(value); + }); + return *this; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc new file mode 100644 index 0000000000000..3d0863a6d12d9 --- /dev/null +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass_tester.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/op_compat_sensible_pass.h" + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/program_desc.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(OpCompatSensiblePass, compatOp) { + auto lambda = [](const std::string& str) { return str == "tanh"; }; + OpCompat compat("FC"); + compat.AddAttr("in_num_col_dims") + .IsIntIn({1, 2}) + .IsNumLE(1) + .IsLeftDefault() + .End() + .AddAttr("activation_type") + .IsStringIn({"tanh", "sigmoid"}) + .IsStringMatch(lambda) + .End() + .AddAttr("test_attr") + .IsBoolEQ(true) + .End() + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("Test") + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End(); + + OpDesc fc_op; + + std::unordered_map attr_map; + attr_map["in_num_col_dims"] = 1; + attr_map["activation_type"] = std::string("tanh"); + attr_map["test_attr"] = true; + + fc_op.SetAttrMap(attr_map); + + fc_op.SetInput("Input", std::vector{"test_input"}); + fc_op.SetInput("W", std::vector{"test_input_0"}); + fc_op.SetInput("Bias", std::vector{"test_input_1"}); + fc_op.SetOutput("Out", std::vector{"test_output"}); + + EXPECT_STREQ(compat.Name().c_str(), "FC"); + EXPECT_TRUE(compat.Judge(fc_op)); +} + +class OpCompatSensiblePassTest : public OpCompatSensiblePass { + public: + OpCompatSensiblePassTest(); + bool TestIsCompat(const OpDesc& op_desc) { return IsCompat(op_desc); } +}; + +OpCompatSensiblePassTest::OpCompatSensiblePassTest() { + AddOpCompat(OpCompat("FC")) + .AddAttr("in_num_col_dims") + .IsNumLE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"tanh", "sigmoid"}) + .End() + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor(); +} + +TEST(OpCompatSensiblePass, IsCompat) { + OpCompatSensiblePassTest test; + OpDesc fc_op; + fc_op.SetType("FC"); + std::unordered_map attr_map; + attr_map["in_num_col_dims"] = 1; + attr_map["activation_type"] = std::string("tanh"); + + fc_op.SetAttrMap(attr_map); + fc_op.SetInput("Input", std::vector{"test_input"}); + fc_op.SetInput("W", std::vector{"test_input_0"}); + fc_op.SetInput("Bias", std::vector{"test_input_1"}); + fc_op.SetOutput("Out", std::vector{"test_output"}); + + EXPECT_TRUE(test.TestIsCompat(fc_op)); + + ProgramDesc prog; + std::unique_ptr g(new Graph(prog)); + Node* o1 = g->CreateOpNode(&fc_op); + + GraphPatternDetector detector; + PDNode* op2 = + detector.mutable_pattern()->NewNode([](Node* x) { return true; }); + GraphPatternDetector::subgraph_t subgraph; + subgraph[op2] = o1; + + test.AccessSubgraph(subgraph, g.get()); +} + +} // namespace ir +} // namespace framework +} // namespace paddle