Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions paddle/fluid/inference/api/analysis_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,32 +665,48 @@ void AnalysisConfig::EnableCUDNN() {
Update();
}

// Deprecated alias kept for backward compatibility after the MKLDNN ->
// ONEDNN rename: logs a one-time rename warning, then delegates to
// EnableONEDNN().
void AnalysisConfig::EnableMKLDNN() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableONEDNN);
EnableONEDNN();
}
// Turns on OneDNN acceleration for CPU inference.
// When Paddle was compiled without PADDLE_WITH_DNNL this is a no-op apart
// from an error log, and the flag is forced off.
void AnalysisConfig::EnableONEDNN() {
#ifdef PADDLE_WITH_DNNL
use_onednn_ = true;
#else
LOG(ERROR) << "Please compile with ONEDNN first to use ONEDNN";
use_onednn_ = false;
#endif

// Re-derive dependent configuration (e.g. the pass strategy) from the flag.
Update();
}

// Deprecated alias kept for backward compatibility after the MKLDNN ->
// ONEDNN rename: logs a rename warning, then delegates to DisableONEDNN().
void AnalysisConfig::DisableMKLDNN() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(DisableONEDNN);
DisableONEDNN();
}
// Turns off OneDNN acceleration and refreshes dependent configuration
// state via Update().
void AnalysisConfig::DisableONEDNN() {
use_onednn_ = false;
Update();
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// SetOnednnCacheCapacity().
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(SetOnednnCacheCapacity);
SetOnednnCacheCapacity(capacity);
}
void AnalysisConfig::SetOnednnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_DNNL
onednn_cache_capacity_ = capacity;
#else
LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
LOG(ERROR) << "Please compile with ONEDNN first to set ONEDNN Thread Id";
onednn_cache_capacity_ = 0;
#endif
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// EnableOnednnBfloat16().
void AnalysisConfig::EnableMkldnnBfloat16() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableOnednnBfloat16);
EnableOnednnBfloat16();
}
void AnalysisConfig::EnableOnednnBfloat16() {
#ifdef PADDLE_WITH_DNNL
if (phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512_core)) {
use_onednn_bfloat16_ = true;
Expand All @@ -704,32 +720,41 @@ void AnalysisConfig::EnableMkldnnBfloat16() {
use_onednn_bfloat16_ = false;
}
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
LOG(ERROR) << "Please compile with ONEDNN first to use OnednnBfloat16";
use_onednn_bfloat16_ = false;
#endif

Update();
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// DisableOnednnFcPasses().
void AnalysisConfig::DisableMkldnnFcPasses() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(DisableOnednnFcPasses);
DisableOnednnFcPasses();
}
// Disables the OneDNN FC fuse passes.
// Without PADDLE_WITH_DNNL the flag stays false and an error is logged.
void AnalysisConfig::DisableOnednnFcPasses() {
#ifdef PADDLE_WITH_DNNL
disable_onednn_fc_passes_ = true;
#else
LOG(ERROR) << "Please compile with ONEDNN first to use DisableOnednnFcPasses";
disable_onednn_fc_passes_ = false;
#endif
// Re-derive dependent configuration (e.g. the pass strategy) from the flag.
Update();
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// EnableOnednnInt8() with the same operator-type list.
void AnalysisConfig::EnableMkldnnInt8(
const std::unordered_set<std::string> &op_list) {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableOnednnInt8);
EnableOnednnInt8(op_list);
}
void AnalysisConfig::EnableOnednnInt8(
const std::unordered_set<std::string> &op_list) {
#ifdef PADDLE_WITH_DNNL
use_onednn_int8_ = true;
use_fc_padding_ = false;
if (!op_list.empty())
quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
LOG(ERROR) << "Please compile with ONEDNN first to use OnednnInt8";
use_onednn_int8_ = false;
#endif

Expand Down Expand Up @@ -970,7 +995,7 @@ void AnalysisConfig::Update() {
// Since EnableONEDNN is default, the pass_builder has created in the first
// time.
// Case1: User manually disable onednn after pass_builder
// create.(config.disable_mkldnn())
// create.(config.disable_onednn())
// Case2: User device is gpu/ipu/xpu, use
// EnableXpu(), EnableCUDNN(), PassStrategy has been reset in the above code
// block
Expand Down Expand Up @@ -1040,20 +1065,20 @@ void AnalysisConfig::Update() {

if (use_onednn_bfloat16_) {
#ifdef PADDLE_WITH_DNNL
pass_builder()->EnableMkldnnBfloat16();
pass_builder()->EnableOnednnBfloat16();
#endif
}

if (use_onednn_int8_) {
#ifdef PADDLE_WITH_DNNL
if (!enable_ir_optim_) {
LOG(ERROR) << "EnableMkldnnInt8() only works when IR optimization "
LOG(ERROR) << "EnableOnednnInt8() only works when IR optimization "
"is enabled.";
} else if (!use_onednn_) {
LOG(ERROR) << "EnableMkldnnInt8() only works when MKLDNN "
LOG(ERROR) << "EnableOnednnInt8() only works when ONEDNN "
"is enabled.";
} else {
pass_builder()->EnableMkldnnInt8();
pass_builder()->EnableOnednnInt8();
}
#endif
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
pass_pm.AddPass(pir::PassRegistry::Instance().Get(mkldnn_pass));
}
}
if (config_.mkldnn_bfloat16_enabled()) {
if (config_.onednn_bfloat16_enabled()) {
for (const auto &mkldnn_pass : kPirMkldnnBf16Passes) {
if (std::find(config_.deleted_passes_.begin(),
config_.deleted_passes_.end(),
Expand Down Expand Up @@ -2110,7 +2110,7 @@ void AnalysisPredictor::PrepareArgument() {
}

#ifdef PADDLE_WITH_DNNL
if (config_.mkldnn_bfloat16_enabled()) {
if (config_.onednn_bfloat16_enabled()) {
LOG(INFO) << "Bfloat16 is enabled";
argument_->SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
}
Expand Down
58 changes: 52 additions & 6 deletions paddle/fluid/inference/api/paddle_analysis_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ struct PD_INFER_DECL AnalysisConfig {
///
/// \return bool Whether to use the OneDNN.
///
bool mkldnn_enabled() const { return use_onednn_; }
bool mkldnn_enabled() const { return use_onednn_; } // deprecated

///
/// \brief Turn on OneDNN.
Expand All @@ -916,6 +916,13 @@ struct PD_INFER_DECL AnalysisConfig {
///
void SetOnednnCacheCapacity(int capacity);

///
/// \brief A boolean state telling whether to use the OneDNN.
///
/// \return bool Whether to use the OneDNN.
///
bool onednn_enabled() const { return use_onednn_; }

///
/// \brief Set the number of cpu math library threads.
///
Expand Down Expand Up @@ -961,26 +968,47 @@ struct PD_INFER_DECL AnalysisConfig {
///
/// \param op_list The operator type list.
///
void EnableMkldnnInt8(const std::unordered_set<std::string>& op_list = {});
void EnableMkldnnInt8(
const std::unordered_set<std::string>& op_list = {}); // deprecated

///
/// \brief A boolean state telling whether to use the OneDNN Int8.
///
/// \return bool Whether to use the OneDNN Int8.
///
bool mkldnn_int8_enabled() const { return use_onednn_int8_; }
bool mkldnn_int8_enabled() const { return use_onednn_int8_; } // deprecated

///
/// \brief Turn on OneDNN bfloat16.
///
///
void EnableMkldnnBfloat16();
void EnableMkldnnBfloat16(); // deprecated

///
/// \brief Turn off OneDNN fc passes.
///
void DisableMkldnnFcPasses(); // deprecated

///
/// \brief Turn on OneDNN int8.
///
/// \param op_list The operator type list.
///
void EnableOnednnInt8(const std::unordered_set<std::string>& op_list = {});

///
/// \brief A boolean state telling whether to use the OneDNN Int8.
///
/// \return bool Whether to use the OneDNN Int8.
///
bool onednn_int8_enabled() const { return use_onednn_int8_; }

///
/// \brief Turn on OneDNN bfloat16.
///
///
void EnableOnednnBfloat16();

///
/// \brief Turn off OneDNN fc passes.
///
Expand All @@ -991,14 +1019,32 @@ struct PD_INFER_DECL AnalysisConfig {
///
/// \return bool Whether to disable the OneDNN Fc passes.
///
bool mkldnn_fc_passes_disabled() const { return disable_onednn_fc_passes_; }
bool mkldnn_fc_passes_disabled() const {
return disable_onednn_fc_passes_;
} // deprecated

///
/// \brief A boolean state telling whether to use the OneDNN Bfloat16.
///
/// \return bool Whether to use the OneDNN Bfloat16.
///
bool mkldnn_bfloat16_enabled() const {
return use_onednn_bfloat16_;
} // deprecated

///
/// \brief A boolean state telling whether to disable the OneDNN Fc passes.
///
/// \return bool Whether to disable the OneDNN Fc passes.
///
bool onednn_fc_passes_disabled() const { return disable_onednn_fc_passes_; }

///
/// \brief A boolean state telling whether to use the OneDNN Bfloat16.
///
/// \return bool Whether to use the OneDNN Bfloat16.
///
bool mkldnn_bfloat16_enabled() const { return use_onednn_bfloat16_; }
bool onednn_bfloat16_enabled() const { return use_onednn_bfloat16_; }

/// \brief Specify the operator type list to use Bfloat16 acceleration.
///
Expand Down
22 changes: 21 additions & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,18 @@ void GpuPassStrategy::EnableONEDNN() {
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// EnableOnednnBfloat16().
void GpuPassStrategy::EnableMkldnnBfloat16() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableOnednnBfloat16);
EnableOnednnBfloat16();
}
void GpuPassStrategy::EnableOnednnBfloat16() {
LOG(ERROR) << "GPU not support MKL-DNN bfloat16";
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// EnableOnednnInt8().
void GpuPassStrategy::EnableMkldnnInt8() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableOnednnInt8);
EnableOnednnInt8();
}
void GpuPassStrategy::EnableOnednnInt8() {
LOG(ERROR) << "GPU not support MKL-DNN int8";
}

Expand Down Expand Up @@ -411,6 +419,10 @@ void CpuPassStrategy::DisableONEDNN() {
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// EnableOnednnBfloat16().
void CpuPassStrategy::EnableMkldnnBfloat16() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableOnednnBfloat16);
EnableOnednnBfloat16();
}
void CpuPassStrategy::EnableOnednnBfloat16() {
#ifdef PADDLE_WITH_DNNL
if (!use_onednn_bfloat16_) {
passes_.emplace_back("fc_onednn_pass");
Expand All @@ -427,6 +439,10 @@ void CpuPassStrategy::EnableMkldnnBfloat16() {
}

// Deprecated alias kept for backward compatibility after the Mkldnn ->
// Onednn rename: logs a rename warning, then delegates to
// EnableOnednnInt8().
void CpuPassStrategy::EnableMkldnnInt8() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EnableOnednnInt8);
EnableOnednnInt8();
}
void CpuPassStrategy::EnableOnednnInt8() {
#ifdef PADDLE_WITH_DNNL
if (!use_onednn_int8_) {
passes_.clear();
Expand Down Expand Up @@ -498,7 +514,7 @@ void CpuPassStrategy::DisableMkldnnFcPasses() {
void CpuPassStrategy::DisableOnednnFcPasses() {
#ifdef PADDLE_WITH_DNNL
if (!disable_onednn_fc_passes_) {
EraseFcMkldnnPasses();
EraseFcOnednnPasses();
}
disable_onednn_fc_passes_ = true;
#else
Expand All @@ -507,6 +523,10 @@ void CpuPassStrategy::DisableOnednnFcPasses() {
}

void CpuPassStrategy::EraseFcMkldnnPasses() {
LOG(WARNING) << ONEDNN_UPDATE_WARNING(EraseFcMkldnnPasses);
EraseFcMkldnnPasses();
}
void CpuPassStrategy::EraseFcOnednnPasses() {
std::vector<std::string> fc_passes_to_erase(
{"fc_onednn_pass", "fc_act_onednn_fuse_pass"});
for (const auto &pass : fc_passes_to_erase) {
Expand Down
Loading