From 4428f4b0f1ff9e88d33a85bb0be1e45a8e1ef873 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 10 Jan 2024 10:52:16 +0000 Subject: [PATCH 1/5] adjust pir pass log printing --- .../fluid/inference/api/analysis_predictor.cc | 9 ++- .../pir/transforms/constant_folding_pass.cc | 31 +++++---- .../transforms/dead_code_elimination_pass.cc | 2 +- paddle/fluid/pir/transforms/inplace_pass.cc | 2 +- .../params_sync_among_devices_pass.cc | 2 +- .../replace_fetch_with_shadow_output_pass.cc | 2 +- .../pir/transforms/sub_graph_extract_pass.cc | 2 +- paddle/fluid/pybind/pir.cc | 3 +- paddle/pir/pass/pass.cc | 26 +------- paddle/pir/pass/pass.h | 29 +++++++-- paddle/pir/pass/pass_manager.h | 8 ++- paddle/pir/pass/print_statistics.cc | 64 +++++++++++++++++++ 12 files changed, 128 insertions(+), 52 deletions(-) create mode 100644 paddle/pir/pass/print_statistics.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index fc27625d60bd7..e7d109e7fcde8 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -828,8 +828,13 @@ bool AnalysisPredictor::PrepareExecutor() { gpu_pm.AddPass(::pir::CreateDeadCodeEliminationPass()); gpu_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); //----------------------------------------------------------------------------------------------// - - // gpu_pm.EnableIRPrinting(); + if (!config_.glog_info_disabled()) { + gpu_pm.EnablePrintStatistics(); + } + if (config_.ir_debug_) { + gpu_pm.EnableIRPrinting(); + gpu_pm.EnablePassTiming(); + } gpu_pm.Run(pir_program_.get()); } diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 620a7c1c2fecc..cca083f709025 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -40,6 +40,7 @@ #include "paddle/pir/core/operation.h" #include "paddle/pir/core/parameter.h" #include "paddle/pir/core/program.h" +#include "paddle/pir/core/region.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" @@ -51,7 +52,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { public: ConstantFoldingPattern( pir::IrContext* context, - size_t* counter, + size_t* suffix, const phi::Place& place, paddle::framework::Scope* scope, paddle::framework::interpreter::ExecutionConfig* exe_config, @@ -60,7 +61,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { 1 /*benefit*/, context, {} /*generated_names*/), - counter_(counter), + suffix_(suffix), place_(place), scope_(scope), exe_config_(exe_config), @@ -298,7 +299,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { .time_since_epoch() .count(); std::string output_var_name = - "constant_folding@_" + ss.str() + std::to_string((*counter_)++); + "constant_folding@_" + ss.str() + std::to_string((*suffix_)++); builder.Build(temp_op->result(i), output_var_name); output_var_names.push_back(output_var_name); @@ -308,7 +309,7 @@ class ConstantFoldingPattern : public pir::RewritePattern { } protected: - size_t* counter_; + size_t* suffix_; phi::Place place_; paddle::framework::Scope* scope_; paddle::framework::interpreter::ExecutionConfig* exe_config_; @@ -319,13 +320,13 @@ class ConstantFoldingPatternForTrain : public ConstantFoldingPattern { public: ConstantFoldingPatternForTrain( pir::IrContext* context, - size_t* counter, + size_t* suffix, const phi::Place& place, paddle::framework::Scope* scope, paddle::framework::interpreter::ExecutionConfig* exe_config, std::vector* deleted_vars) : ConstantFoldingPattern( - context, counter, place, scope, exe_config, deleted_vars) {} + context, suffix, place, scope, exe_config, deleted_vars) {} bool Match(pir::Operation* op) const override { VLOG(4) << "constant_folding_pass applys match on [" << op->name() @@ -405,26 +406,32 @@ class ConstantFoldingPass : public pir::Pass { if (Has("train_mode") && Get("train_mode")) { ps.Add(context, - &counter_, + &suffix_, phi::CPUPlace{}, scope_, &exe_config_, &deleted_vars_); } else { ps.Add( - context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); + context, &suffix_, place_, scope_, &exe_config_, &deleted_vars_); } patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); return true; } void Run(pir::Operation* op) override { - size_t op_nums = op->GetParentProgram()->block()->size(); + int64_t num_ops{0}; + for (uint32_t i = 0; i < op->num_regions(); ++i) { + auto& region = op->region(i); + for (auto& block : region) { + num_ops += block.size(); + } + } pir::GreedyRewriteConfig cfg; cfg.use_top_down_traversal = true; cfg.max_iterations = 10; - pir::ApplyPatternsGreedily(op, patterns_, cfg); - PrintStatistics(counter_, op_nums); + auto [_, num_rewrites] = pir::ApplyPatternsGreedily(op, patterns_, cfg); + AddStatistics(num_rewrites, num_ops); // delete old parameter var scope_->EraseVars(deleted_vars_); if (place_.GetType() != phi::AllocationType::CPU) { @@ -434,7 +441,7 @@ class ConstantFoldingPass : public pir::Pass { } private: - size_t counter_{0}; + size_t suffix_{0}; phi::Place place_; paddle::framework::Scope* scope_{nullptr}; paddle::framework::interpreter::ExecutionConfig exe_config_{}; diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc index 1a6433e233ed1..ff7805b3dada0 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc @@ -32,7 +32,7 @@ class DeadCodeEliminationPass : public pir::Pass { VLOG(6) << "apply dead_code_elimination_pass"; int64_t num_erasers{0}; EraseOp(*op->GetParentProgram()->block(), &num_erasers); - PrintStatistics(num_erasers); + AddStatistics(num_erasers); } private: diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/inplace_pass.cc index e7b156309ade8..c862ec9245701 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/inplace_pass.cc @@ -492,7 +492,7 @@ class InplacePass : public pir::Pass { } } } - PrintStatistics(num_rewrites_); + AddStatistics(num_rewrites_); } }; diff --git a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc b/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc index 794b5bfe29484..41e51d00ef704 100644 --- a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc +++ b/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc @@ -92,7 +92,7 @@ class ParamsSyncAmongDevicesPass : public pir::Pass { } } } - PrintStatistics(num_rewrites_); + AddStatistics(num_rewrites_); } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc index 97a30adfeb877..5e499436ec7f6 100644 --- a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc @@ -56,7 +56,7 @@ class ReplaceFetchWithShadowOutputPass : public pir::Pass { cfg.use_top_down_traversal = true; cfg.max_iterations = 10; auto [_, num_rewrites] = pir::ApplyPatternsGreedily(op, patterns_, cfg); - PrintStatistics(num_rewrites); + AddStatistics(num_rewrites); } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc b/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc index 09691f76fe62f..d4ef090f53f68 100644 --- a/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc +++ b/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc @@ -54,7 +54,7 @@ class SubGraphExtractPass : public pir::Pass { std::vector groups = ::pir::SubgraphDetector(&block, IsSplitOp)(); - PrintStatistics(groups.size()); + AddStatistics(groups.size()); for (auto& group_ops : groups) { VLOG(4) << "current group_ops.size(): " << group_ops.size(); ::pir::ReplaceWithGroupOp(&block, group_ops); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index c4c9eb0145aa5..1376f17e4f019 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -1638,7 +1638,8 @@ void BindPassManager(pybind11::module *m) { return pass_names; }) .def("run", [](PassManager &self, Program *p) { self.Run(p); }) - .def("empty", &PassManager::Empty); + .def("empty", &PassManager::empty) + .def("clear", &PassManager::clear); } void BindPir(pybind11::module *module) { diff --git a/paddle/pir/pass/pass.cc b/paddle/pir/pass/pass.cc index 37cd7f724b748..2f9cb896215dd 100644 --- a/paddle/pir/pass/pass.cc +++ b/paddle/pir/pass/pass.cc @@ -23,7 +23,6 @@ #include "paddle/pir/pass/pass_adaptor.h" #include "paddle/pir/pass/pass_instrumentation.h" #include "paddle/pir/pass/pass_manager.h" -#include "paddle/utils/string/pretty_log.h" namespace pir { @@ -34,27 +33,6 @@ Pass::~Pass() = default; bool Pass::CanApplyOn(Operation* op) const { return op->num_regions() > 0; } -void Pass::PrintStatistics(int64_t match_count) const { - if (match_count > 0) { - LOG(INFO) << "--- detected [" << match_count << "] subgraphs!"; - } -} - -void Pass::PrintStatistics(int64_t match_count, int64_t all_count) const { - IR_ENFORCE(match_count < all_count, - "match_count should smaller than all_count"); - if (match_count > 0) { - LOG(INFO) << "--- detected [" << match_count << "/" << all_count - << "] subgraphs!"; - } -} - -void Pass::PrintStatistics(const std::string& custom_log) const { - if (!custom_log.empty()) { - LOG(INFO) << custom_log; - } -} - detail::PassExecutionState& Pass::pass_state() { IR_ENFORCE(pass_state_.has_value() == true, "pass state has no value"); return *pass_state_; @@ -81,7 +59,7 @@ void PatternRewritePass::Run(Operation* op) { cfg.use_top_down_traversal = true; cfg.max_iterations = 10; auto [_, num_rewrites] = ApplyPatternsGreedily(op, patterns_, cfg); - PrintStatistics(num_rewrites); + AddStatistics(num_rewrites); } //----------------------------------------------------------------------------------------------// @@ -154,8 +132,6 @@ bool detail::PassAdaptor::RunPass(Pass* pass, adaptor->Run(op, opt_level, verify); } else { if (instrumentor) instrumentor->RunBeforePass(pass, op); - paddle::string::PrettyLogH1("--- Running PIR pass [%s]", - pass->pass_info().name); pass->Run(op); if (instrumentor) instrumentor->RunAfterPass(pass, op); } diff --git a/paddle/pir/pass/pass.h b/paddle/pir/pass/pass.h index a8a1d15345ae3..3bce02040247a 100644 --- a/paddle/pir/pass/pass.h +++ b/paddle/pir/pass/pass.h @@ -139,6 +139,8 @@ class IR_API Pass { void Set(const std::string& attr_name, AttrType* attr) { VLOG(3) << "Setting the attribute " << attr_name << " for the pass " << name(); + IR_ENFORCE( + !Has(attr_name), "Attribute %s already set in the %s.", attr_name); attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(8) << "deleting " << attr_name; @@ -150,9 +152,9 @@ class IR_API Pass { // should delete the attribute. template void SetNotOwned(const std::string& attr_name, AttrType* attr) { - IR_ENFORCE(0 == attrs_.count(attr_name), - "Attribute %s already set in the pass.", - attr_name); + VLOG(3) << "Setting the attribute " << attr_name << " for the " << name(); + IR_ENFORCE( + !Has(attr_name), "Attribute %s already set in the pass.", attr_name); attrs_[attr_name] = attr; } @@ -163,11 +165,26 @@ class IR_API Pass { virtual bool Initialize(IrContext* context) { return true; } - void PrintStatistics(int64_t match_count) const; + void AddStatistics(int64_t match_count) { + IR_ENFORCE(!Has("__match_count__"), + "Attribute __match_count__ already set in the pass."); + Set("__match_count__", new int64_t{match_count}); + } - void PrintStatistics(int64_t match_count, int64_t all_count) const; + void AddStatistics(int64_t match_count, int64_t all_count) { + IR_ENFORCE(!Has("__match_count__"), + "Attribute __match_count__ already set in the pass."); + IR_ENFORCE(!Has("__all_count__"), + "Attribute __all_count__ already set in the pass."); + Set("__match_count__", new int64_t{match_count}); + Set("__all_count__", new int64_t{all_count}); + } - void PrintStatistics(const std::string& custom_log) const; + void AddStatistics(const std::string& custom_log) { + IR_ENFORCE(!Has("__custom_log__"), + "Attribute __custom_log__ already set in the pass."); + Set("__custom_log__", new std::string{custom_log}); + } AnalysisManager analysis_manager() { return pass_state().am; } diff --git a/paddle/pir/pass/pass_manager.h b/paddle/pir/pass/pass_manager.h index 92faed24f1f5d..5ee383a57ae39 100644 --- a/paddle/pir/pass/pass_manager.h +++ b/paddle/pir/pass/pass_manager.h @@ -42,7 +42,9 @@ class IR_API PassManager { const std::vector> &passes() const { return passes_; } - bool Empty() const { return passes_.empty(); } + bool empty() const { return passes_.empty(); } + + void clear() { passes_.clear(); } IrContext *context() const { return context_; } @@ -115,6 +117,8 @@ class IR_API PassManager { void EnablePassTiming(bool print_module = true); + void EnablePrintStatistics(); + void AddInstrumentation(std::unique_ptr pi); private: @@ -129,6 +133,8 @@ class IR_API PassManager { bool verify_{true}; + bool disable_log_{false}; + std::vector> passes_; std::unique_ptr pass_adaptor_; diff --git a/paddle/pir/pass/print_statistics.cc b/paddle/pir/pass/print_statistics.cc new file mode 100644 index 0000000000000..f8f93fd4e1fad --- /dev/null +++ b/paddle/pir/pass/print_statistics.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/core/operation.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_instrumentation.h" +#include "paddle/pir/pass/pass_manager.h" +#include "paddle/utils/string/pretty_log.h" + +namespace pir { + +class PrintStatistics : public PassInstrumentation { + public: + PrintStatistics() = default; + + ~PrintStatistics() override = default; + + void RunBeforePass(Pass *pass, Operation *op) override { + paddle::string::PrettyLogH1("--- Running PIR pass [%s]", + pass->pass_info().name); + } + + void RunAfterPass(Pass *pass, Operation *op) override { + if (pass->Has("__match_count__") && pass->Has("__all_count__")) { + auto match_count = pass->Get("__match__count__"); + auto all_count = pass->Get("__all_count__"); + IR_ENFORCE(match_count < all_count, + "match_count: %d should smaller than all_count: %d", + match_count, + all_count); + if (match_count > 0) { + LOG(INFO) << "--- detected [" << match_count << "/" << all_count + << "] subgraphs!"; + } + } else if (pass->Has("__match_count__") && !pass->Has("__all_count__")) { + auto match_count = pass->Get("__match__count__"); + if (match_count > 0) { + LOG(INFO) << "--- detected [" << match_count << "] subgraphs!"; + } + } else if (pass->Has("__custom_log__")) { + auto custom_log = pass->Get("__custom_log__"); + if (!custom_log.empty()) { + LOG(INFO) << custom_log; + } + } + } +}; + +void PassManager::EnablePrintStatistics() { + AddInstrumentation(std::make_unique()); +} + +} // namespace pir From 374c59a663e86e6b9f9f81eac127731997c3cc22 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 10 Jan 2024 11:04:29 +0000 Subject: [PATCH 2/5] update --- paddle/pir/pass/print_statistics.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/pir/pass/print_statistics.cc b/paddle/pir/pass/print_statistics.cc index f8f93fd4e1fad..127da5f90ed88 100644 --- a/paddle/pir/pass/print_statistics.cc +++ b/paddle/pir/pass/print_statistics.cc @@ -33,7 +33,7 @@ class PrintStatistics : public PassInstrumentation { void RunAfterPass(Pass *pass, Operation *op) override { if (pass->Has("__match_count__") && pass->Has("__all_count__")) { - auto match_count = pass->Get("__match__count__"); + auto match_count = pass->Get("__match_count__"); auto all_count = pass->Get("__all_count__"); IR_ENFORCE(match_count < all_count, "match_count: %d should smaller than all_count: %d", @@ -44,7 +44,7 @@ class PrintStatistics : public PassInstrumentation { << "] subgraphs!"; } } else if (pass->Has("__match_count__") && !pass->Has("__all_count__")) { - auto match_count = pass->Get("__match__count__"); + auto match_count = pass->Get("__match_count__"); if (match_count > 0) { LOG(INFO) << "--- detected [" << match_count << "] subgraphs!"; } From a16726ae1e1f5ee74aef9a4077854a3d3154e7f0 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 10 Jan 2024 11:46:02 +0000 Subject: [PATCH 3/5] update --- .../fluid/inference/api/analysis_predictor.cc | 1 - .../fluid/pir/transforms/build_cinn_pass.cc | 2 +- paddle/pir/pass/ir_printing.cc | 22 ++++++++++++------- paddle/pir/pass/pass_manager.h | 16 +++++--------- paddle/pir/pass/pass_timing.cc | 5 ++++- .../pattern_rewrite/pattern_rewrite_test.cc | 18 +++++++-------- 6 files changed, 33 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index e7d109e7fcde8..00995194ecf89 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -833,7 +833,6 @@ bool AnalysisPredictor::PrepareExecutor() { } if (config_.ir_debug_) { gpu_pm.EnableIRPrinting(); - gpu_pm.EnablePassTiming(); } gpu_pm.Run(pir_program_.get()); } diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc index 7513cc806e502..3edee80df446b 100644 --- a/paddle/fluid/pir/transforms/build_cinn_pass.cc +++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc @@ -240,7 +240,7 @@ class BuildCinnPass : public pir::Pass { std::vector groups = ::pir::SubgraphDetector(&block, IsSupportCinn)(); - PrintStatistics(groups.size()); + AddStatistics(groups.size()); for (auto& group_ops : groups) { VLOG(4) << "current group_ops.size(): " << group_ops.size(); ::pir::ReplaceWithGroupOp(&block, group_ops); diff --git a/paddle/pir/pass/ir_printing.cc b/paddle/pir/pass/ir_printing.cc index 901c8bdd89da7..900f52c3a250e 100644 --- a/paddle/pir/pass/ir_printing.cc +++ b/paddle/pir/pass/ir_printing.cc @@ -13,10 +13,12 @@ // limitations under the License. #include +#include #include #include #include "paddle/pir/core/operation.h" +#include "paddle/pir/core/program.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_instrumentation.h" #include "paddle/pir/pass/pass_manager.h" @@ -48,12 +50,14 @@ class IRPrinting : public PassInstrumentation { // TODO(liuyuanle): support print on change } - option_->PrintBeforeIfEnabled(pass, op, [&](std::ostream &os) { + option_->PrintBeforeIfEnabled(pass, op, [&]() { + std::ostringstream oss; std::string header = "IRPrinting on " + op->name() + " before " + pass->name() + " pass"; - detail::PrintHeader(header, os); - PrintIR(op, option_->print_module(), os); - os << "\n\n"; + detail::PrintHeader(header, oss); + PrintIR(op, option_->print_module(), oss); + oss << "\n\n"; + std::cout << oss.str(); }); } @@ -62,12 +66,14 @@ class IRPrinting : public PassInstrumentation { // TODO(liuyuanle): support print on change } - option_->PrintAfterIfEnabled(pass, op, [&](std::ostream &os) { + option_->PrintAfterIfEnabled(pass, op, [&]() { + std::ostringstream oss; std::string header = "IRPrinting on " + op->name() + " after " + pass->name() + " pass"; - detail::PrintHeader(header, os); - PrintIR(op, option_->print_module(), os); - os << "\n\n"; + detail::PrintHeader(header, oss); + PrintIR(op, option_->print_module(), oss); + oss << "\n\n"; + std::cout << oss.str(); }); } diff --git a/paddle/pir/pass/pass_manager.h b/paddle/pir/pass/pass_manager.h index 5ee383a57ae39..361782e36de52 100644 --- a/paddle/pir/pass/pass_manager.h +++ b/paddle/pir/pass/pass_manager.h @@ -15,11 +15,9 @@ #pragma once #include -#include #include #include -#include "paddle/pir/core/program.h" #include "paddle/pir/pass/pass.h" namespace pir { @@ -56,7 +54,7 @@ class IR_API PassManager { class IRPrinterOption { public: - using PrintCallBack = std::function; + using PrintCallBack = std::function; explicit IRPrinterOption( const std::function &enable_print_before = @@ -64,13 +62,11 @@ class IR_API PassManager { const std::function &enable_print_after = [](Pass *, Operation *) { return true; }, bool print_module = true, - bool print_on_change = true, - std::ostream &os = std::cout) + bool print_on_change = true) : enable_print_before_(enable_print_before), enable_print_after_(enable_print_after), print_module_(print_module), - print_on_change_(print_on_change), - os(os) { + print_on_change_(print_on_change) { assert((enable_print_before_ || enable_print_after_) && "expected at least one valid filter function"); } @@ -81,7 +77,7 @@ class IR_API PassManager { Operation *op, const PrintCallBack &print_callback) { if (enable_print_before_ && enable_print_before_(pass, op)) { - print_callback(os); + print_callback(); } } @@ -89,7 +85,7 @@ class IR_API PassManager { Operation *op, const PrintCallBack &print_callback) { if (enable_print_after_ && enable_print_after_(pass, op)) { - print_callback(os); + print_callback(); } } @@ -107,8 +103,6 @@ class IR_API PassManager { bool print_on_change_; - std::ostream &os; - // TODO(liuyuanle): Add flags to control printing behavior. }; diff --git a/paddle/pir/pass/pass_timing.cc b/paddle/pir/pass/pass_timing.cc index 354cd2cb83590..4cc932f0f4abd 100644 --- a/paddle/pir/pass/pass_timing.cc +++ b/paddle/pir/pass/pass_timing.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -63,7 +64,9 @@ class PassTimer : public PassInstrumentation { void RunAfterPipeline(Operation* op) override { pipeline_timers_[op].Stop(); - PrintTime(op, std::cout); + std::ostringstream oss; + PrintTime(op, oss); + std::cout << oss.str(); } void RunBeforePass(Pass* pass, Operation* op) override { diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index 1a87247dab35b..59f10f241f2cd 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -412,15 +412,15 @@ TEST(pattern_rewrite, Patterns) { pm.AddPass(pir::CreateDeadCodeEliminationPass()); // pm.EnablePassTiming(); pm.EnableIRPrinting(); - // pm.EnableIRPrinting(std::make_unique( - // [](pir::Pass *pass, pir::Operation *op) { - // return pass->name() == "constant_folding_pass"; - // }, - // [](pir::Pass *pass, pir::Operation *op) { - // return pass->name() == "constant_folding_pass"; - // }, - // true, - // true)); + // pm.EnableIRPrinting(std::make_unique( + // [](pir::Pass *pass, pir::Operation *op) { + // return pass->name() == "constant_folding_pass"; + // }, + // [](pir::Pass *pass, pir::Operation *op) { + // return pass->name() == "constant_folding_pass"; + // }, + // true, + // true)); CHECK_EQ(pm.Run(&program), true); EXPECT_EQ(program.block()->size(), 17u); From 628ff043e8d2f5ca9020123435d804cd14ca142e Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 10 Jan 2024 12:07:12 +0000 Subject: [PATCH 4/5] update --- paddle/fluid/framework/executor_cache.cc | 1 + paddle/fluid/inference/api/analysis_predictor.cc | 10 +++++++++- paddle/pir/pass/pass.h | 13 +++---------- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 97e4d386ea9aa..7e476f4fc506c 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -562,6 +562,7 @@ std::unique_ptr<::pir::Program> ConstructBackwardIrProgram( pm.AddPass(::pir::CreateInplacePass()); if (VLOG_IS_ON(6)) { pm.EnableIRPrinting(); + pm.EnablePrintStatistics(); } pm.Run(res.get()); if (FLAGS_print_ir) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 00995194ecf89..8261fee1517f6 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -844,8 +844,16 @@ bool AnalysisPredictor::PrepareExecutor() { if (FLAGS_pir_apply_inplace_pass) { lowered_pm.AddPass(::pir::CreateInplacePass()); } + if (!config_.glog_info_disabled()) { + lowered_pm.EnablePrintStatistics(); + } + if (config_.ir_debug_) { + lowered_pm.EnableIRPrinting(); + } lowered_pm.Run(pir_program_.get()); + LOG(INFO) << "======= pir optimization completed ======="; + executor_->PrepareInterpreterCore( sub_scope_, *pir_program_, execution_config); } else { @@ -1867,7 +1875,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { argument_.reset(nullptr); } #endif - LOG(INFO) << "======= optimize end ======="; + LOG(INFO) << "======= ir optimization completed ======="; } template <> diff --git a/paddle/pir/pass/pass.h b/paddle/pir/pass/pass.h index 3bce02040247a..6c2c565322bf8 100644 --- a/paddle/pir/pass/pass.h +++ b/paddle/pir/pass/pass.h @@ -139,8 +139,9 @@ class IR_API Pass { void Set(const std::string& attr_name, AttrType* attr) { VLOG(3) << "Setting the attribute " << attr_name << " for the pass " << name(); - IR_ENFORCE( - !Has(attr_name), "Attribute %s already set in the %s.", attr_name); + if (Has(attr_name)) { + Erase(attr_name); + } attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { VLOG(8) << "deleting " << attr_name; @@ -166,23 +167,15 @@ class IR_API Pass { virtual bool Initialize(IrContext* context) { return true; } void AddStatistics(int64_t match_count) { - IR_ENFORCE(!Has("__match_count__"), - "Attribute __match_count__ already set in the pass."); Set("__match_count__", new int64_t{match_count}); } void AddStatistics(int64_t match_count, int64_t all_count) { - IR_ENFORCE(!Has("__match_count__"), - "Attribute __match_count__ already set in the pass."); - IR_ENFORCE(!Has("__all_count__"), - "Attribute __all_count__ already set in the pass."); Set("__match_count__", new int64_t{match_count}); Set("__all_count__", new int64_t{all_count}); } void AddStatistics(const std::string& custom_log) { - IR_ENFORCE(!Has("__custom_log__"), - "Attribute __custom_log__ already set in the pass."); Set("__custom_log__", new std::string{custom_log}); } From cbe339ede7dd78eb2f94d42605c22960459a726b Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Wed, 10 Jan 2024 14:00:35 +0000 Subject: [PATCH 5/5] fix compile --- paddle/common/macros.h | 7 +++++++ paddle/fluid/framework/details/gather_op_handle.cc | 2 +- paddle/fluid/framework/executor_cache.cc | 3 +++ paddle/fluid/framework/op_compatible_info.cc | 3 ++- paddle/fluid/memory/allocation/aligned_allocator.cc | 4 ++-- paddle/fluid/memory/allocation/best_fit_allocator.cc | 3 ++- paddle/fluid/memory/allocation/buffered_allocator.cc | 5 ++++- paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc | 3 ++- paddle/fluid/platform/init_phi.cc | 4 +++- paddle/fluid/platform/init_phi.h | 7 ------- paddle/fluid/pybind/pybind.cc | 2 +- paddle/pir/pass/ir_printing.cc | 8 ++++---- paddle/pir/pass/pass_timing.cc | 4 ++-- paddle/pir/pass/print_statistics.cc | 3 +++ test/cpp/new_executor/standalone_executor_pir_test.cc | 2 +- test/cpp/prim/test_vjp.cc | 2 +- 16 files changed, 38 insertions(+), 24 deletions(-) diff --git a/paddle/common/macros.h b/paddle/common/macros.h index 2d476c58cb6ae..683c1d3667e9f 100644 --- a/paddle/common/macros.h +++ b/paddle/common/macros.h @@ -86,4 +86,11 @@ namespace common { #endif // __FLT_MAX__ #endif // PADDLE_WITH_MUSL +#define REGISTER_FILE_SYMBOLS(name) \ + int RegisterSymbolsFor##name() { return 0; } + +#define DECLARE_FILE_SYMBOLS(name) \ + extern int RegisterSymbolsFor##name(); \ + UNUSED static int use_file_##name = RegisterSymbolsFor##name() + } // namespace common diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 3731e562549f3..53e3807e53e8f 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -14,9 +14,9 @@ #include "paddle/fluid/framework/details/gather_op_handle.h" +#include "paddle/common/macros.h" #include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/variable_visitor.h" -#include "paddle/fluid/platform/init_phi.h" REGISTER_FILE_SYMBOLS(gather_op_handle); diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 7e476f4fc506c..a3686a4841adf 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/executor_cache.h" +#include "paddle/common/macros.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" @@ -25,6 +26,8 @@ #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" +DECLARE_FILE_SYMBOLS(print_statistics); + PHI_DECLARE_bool(pir_apply_inplace_pass); PHI_DECLARE_bool(print_ir); diff --git a/paddle/fluid/framework/op_compatible_info.cc b/paddle/fluid/framework/op_compatible_info.cc index 5a3c6189472fe..ba71043771ff2 100644 --- a/paddle/fluid/framework/op_compatible_info.cc +++ b/paddle/fluid/framework/op_compatible_info.cc @@ -14,11 +14,12 @@ #include "paddle/fluid/framework/op_compatible_info.h" -#include "paddle/fluid/platform/enforce.h" +#include "paddle/common/macros.h" #include "paddle/fluid/platform/init_phi.h" #include "paddle/fluid/string/string_helper.h" REGISTER_FILE_SYMBOLS(op_compatible_info); + namespace paddle { namespace framework { diff --git a/paddle/fluid/memory/allocation/aligned_allocator.cc b/paddle/fluid/memory/allocation/aligned_allocator.cc index 0c433c74a53a6..22382a2691bd0 100644 --- a/paddle/fluid/memory/allocation/aligned_allocator.cc +++ b/paddle/fluid/memory/allocation/aligned_allocator.cc @@ -14,11 +14,11 @@ #include "paddle/fluid/memory/allocation/aligned_allocator.h" +#include "paddle/common/macros.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/init_phi.h" - REGISTER_FILE_SYMBOLS(aligned_allocator); + namespace paddle { namespace memory { namespace allocation { diff --git a/paddle/fluid/memory/allocation/best_fit_allocator.cc b/paddle/fluid/memory/allocation/best_fit_allocator.cc index 2daa31e2a1ddb..337d5f0db7b80 100644 --- a/paddle/fluid/memory/allocation/best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/best_fit_allocator.cc @@ -17,8 +17,9 @@ #include #include +#include "paddle/common/macros.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/init_phi.h" + REGISTER_FILE_SYMBOLS(best_fit_allocator); namespace paddle { diff --git a/paddle/fluid/memory/allocation/buffered_allocator.cc b/paddle/fluid/memory/allocation/buffered_allocator.cc index 57a90bd7cd112..cbc79078397ea 100644 --- a/paddle/fluid/memory/allocation/buffered_allocator.cc +++ b/paddle/fluid/memory/allocation/buffered_allocator.cc @@ -13,8 +13,11 @@ // limitations under the License. #include "paddle/fluid/memory/allocation/buffered_allocator.h" -#include "paddle/fluid/platform/init_phi.h" + +#include "paddle/common/macros.h" + REGISTER_FILE_SYMBOLS(buffered_allocator); + namespace paddle { namespace memory { namespace allocation { diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc index 63e2a83a7dbe9..75698f3e9ccbe 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.cc @@ -13,11 +13,12 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" + #include "paddle/common/ddim.h" +#include "paddle/common/macros.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" -#include "paddle/fluid/platform/init_phi.h" #include "paddle/phi/common/place.h" #include "paddle/pir/core/ir_printer.h" diff --git a/paddle/fluid/platform/init_phi.cc b/paddle/fluid/platform/init_phi.cc index 306dbf9f6e5f3..1feab75bdb288 100644 --- a/paddle/fluid/platform/init_phi.cc +++ b/paddle/fluid/platform/init_phi.cc @@ -11,8 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/platform/init_phi.h" -#include "glog/logging.h" + +#include "paddle/common/macros.h" #include "paddle/fluid/platform/init.h" REGISTER_FILE_SYMBOLS(init_phi) diff --git a/paddle/fluid/platform/init_phi.h b/paddle/fluid/platform/init_phi.h index 80c1d6242e545..a5fd3a6c3567a 100644 --- a/paddle/fluid/platform/init_phi.h +++ b/paddle/fluid/platform/init_phi.h @@ -23,11 +23,4 @@ class PADDLE_API InitPhi { InitPhi(); }; -#define REGISTER_FILE_SYMBOLS(name) \ - int RegisterSymbolsFor##name() { return 0; } - -#define DECLARE_FILE_SYMBOLS(name) \ - extern int RegisterSymbolsFor##name(); \ - UNUSED static int use_file_##name = RegisterSymbolsFor##name() - } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 32e9ffd3a5c63..35b3b613d7bfc 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -81,6 +81,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/memory/allocation/cuda_ipc_allocator.h" #endif +#include "paddle/common/macros.h" #include "paddle/fluid/memory/allocation/mmap_allocator.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/common_infer_shape_functions.h" @@ -92,7 +93,6 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/init.h" -#include "paddle/fluid/platform/init_phi.h" #include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" diff --git a/paddle/pir/pass/ir_printing.cc b/paddle/pir/pass/ir_printing.cc index 900f52c3a250e..710e627023a78 100644 --- a/paddle/pir/pass/ir_printing.cc +++ b/paddle/pir/pass/ir_printing.cc @@ -56,8 +56,8 @@ class IRPrinting : public PassInstrumentation { "IRPrinting on " + op->name() + " before " + pass->name() + " pass"; detail::PrintHeader(header, oss); PrintIR(op, option_->print_module(), oss); - oss << "\n\n"; - std::cout << oss.str(); + oss << "\n"; + std::cout << oss.str() << std::endl; }); } @@ -72,8 +72,8 @@ class IRPrinting : public PassInstrumentation { "IRPrinting on " + op->name() + " after " + pass->name() + " pass"; detail::PrintHeader(header, oss); PrintIR(op, option_->print_module(), oss); - oss << "\n\n"; - std::cout << oss.str(); + oss << "\n"; + std::cout << oss.str() << std::endl; }); } diff --git a/paddle/pir/pass/pass_timing.cc b/paddle/pir/pass/pass_timing.cc index 4cc932f0f4abd..383183eceaaa7 100644 --- a/paddle/pir/pass/pass_timing.cc +++ b/paddle/pir/pass/pass_timing.cc @@ -19,7 +19,7 @@ #include #include -#include "paddle/fluid/platform/init_phi.h" +#include "paddle/common/macros.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_instrumentation.h" @@ -66,7 +66,7 @@ class PassTimer : public PassInstrumentation { pipeline_timers_[op].Stop(); std::ostringstream oss; PrintTime(op, oss); - std::cout << oss.str(); + std::cout << oss.str() << std::endl; } void RunBeforePass(Pass* pass, Operation* op) override { diff --git a/paddle/pir/pass/print_statistics.cc b/paddle/pir/pass/print_statistics.cc index 127da5f90ed88..7569538df57e9 100644 --- a/paddle/pir/pass/print_statistics.cc +++ b/paddle/pir/pass/print_statistics.cc @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/common/macros.h" #include "paddle/pir/core/operation.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_instrumentation.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/utils/string/pretty_log.h" +REGISTER_FILE_SYMBOLS(print_statistics); + namespace pir { class PrintStatistics : public PassInstrumentation { diff --git a/test/cpp/new_executor/standalone_executor_pir_test.cc b/test/cpp/new_executor/standalone_executor_pir_test.cc index 60e589bcf4b78..9f655979be601 100644 --- a/test/cpp/new_executor/standalone_executor_pir_test.cc +++ b/test/cpp/new_executor/standalone_executor_pir_test.cc @@ -33,7 +33,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/platform/init_phi.h" +#include "paddle/common/macros.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc index f9393bf6b9f54..ecfb38c01438e 100644 --- a/test/cpp/prim/test_vjp.cc +++ b/test/cpp/prim/test_vjp.cc @@ -14,6 +14,7 @@ #include +#include "paddle/common/macros.h" #include "paddle/fluid/framework/new_executor/pir_interpreter.h" #include "paddle/fluid/framework/new_executor/standalone_executor.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" @@ -22,7 +23,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" -#include "paddle/fluid/platform/init_phi.h" #include "paddle/pir/core/block.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h"