Skip to content

Commit

Permalink
Merge pull request #14709 from yihuaxu/develop_4f71a6ee2_conv3d_bias_…
Browse files Browse the repository at this point in the history
…fusion_mkldnn_impl

Implement the fusion of convolution 3D and bias for mkldnn
  • Loading branch information
luotao1 authored Dec 7, 2018
2 parents add98c9 + 3821fc3 commit c83d5b7
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 17 deletions.
10 changes: 7 additions & 3 deletions paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
auto* scope = param_scope();
PADDLE_ENFORCE(scope);

std::string type = is_conv3d() ? "conv3d" : "conv2d";

GraphPatternDetector gpd;
auto* conv_input =
gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput()
->assert_is_op_input("conv2d", "Input");
->assert_is_op_input(type, "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input);
conv_bias_pattern(conv_input, is_conv3d());
int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
Expand Down Expand Up @@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d");
desc.SetType(type);

for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second);
Expand All @@ -135,3 +137,5 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
} // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
paddle::framework::ir::ConvBiasFusePass);
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);
8 changes: 8 additions & 0 deletions paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,19 @@ namespace ir {
class ConvBiasFusePass : public FusePassBase {
public:
virtual ~ConvBiasFusePass() {}
virtual bool is_conv3d() const { return false; }

protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"};
};
/*
* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp.
*/
class Conv3DBiasFusePass : public ConvBiasFusePass {
public:
bool is_conv3d() const override { return true; }
};
} // namespace ir
} // namespace framework
} // namespace paddle
11 changes: 6 additions & 5 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1030,22 +1030,23 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
}

PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input) {
paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
std::string type = is_conv3d ? "conv3d" : "conv2d";
// Create Operators
conv_input->assert_is_op_input("conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
conv_input->assert_is_op_input(type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables
// Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter");
->assert_is_op_input(type, "Filter");
// intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op("conv2d")
->assert_is_only_output_of_op(type)
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
struct ConvBias : public PatternBase {
ConvBias(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bias") {}
PDNode* operator()(PDNode* conv_input);
PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
// declare operator node's name
PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(eltwise);
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/inference/api/paddle_pass_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ class CpuPassStrategy : public PassStrategy {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");

for (auto &pass :
std::vector<std::string>({"depthwise_conv_mkldnn_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
std::vector<std::string>({"depthwise_conv_mkldnn_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv3d_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass"})) {
passes_.push_back(pass);
}
Expand Down
23 changes: 20 additions & 3 deletions paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,14 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
}

// Easy for profiling independently.
TEST(Analyzer_dam, profile) {
void profile(bool use_mkldnn = false) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);

if (use_mkldnn) {
cfg.EnableMKLDNN();
}

std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
Expand All @@ -209,6 +213,11 @@ TEST(Analyzer_dam, profile) {
}
}

TEST(Analyzer_dam, profile) { profile(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_dam, profile_mkldnn) { profile(true /* use_mkldnn */); }
#endif

// Check the fuse status
TEST(Analyzer_dam, fuse_statis) {
contrib::AnalysisConfig cfg;
Expand All @@ -222,9 +231,12 @@ TEST(Analyzer_dam, fuse_statis) {
}

// Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_dam, compare) {
contrib::AnalysisConfig cfg;
void compare(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}

std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
Expand All @@ -233,5 +245,10 @@ TEST(Analyzer_dam, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}

TEST(Analyzer_dam, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif

} // namespace inference
} // namespace paddle
5 changes: 3 additions & 2 deletions paddle/fluid/operators/activation_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
const T *x_data = x->data<T>();
T *y_data = y->mutable_data<T>(ctx.GetPlace());

PADDLE_ENFORCE(x->dims().size() == 2 || x->dims().size() == 4,
"Input dim must be with 2 or 4");
PADDLE_ENFORCE(
x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
"Input dim must be with 2, 3 or 4");

std::vector<int> src_tz = framework::vectorize2int(x->dims());

Expand Down

0 comments on commit c83d5b7

Please sign in to comment.