From 04bcbb227f4cfb4af790c2de1afbc1e194e1499d Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 4 Aug 2025 10:24:04 +0800 Subject: [PATCH 1/2] Fix --- paddle/phi/kernels/onednn/conv_kernel.cc | 5 ++++ .../kernels/onednn/conv_transpose_kernel.cc | 24 ++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/onednn/conv_kernel.cc b/paddle/phi/kernels/onednn/conv_kernel.cc index 0007c717a4d9db..1968bfe8154a03 100644 --- a/paddle/phi/kernels/onednn/conv_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_kernel.cc @@ -42,6 +42,11 @@ void ConvKernel(const Context& dev_ctx, dev_ctx.GetDnnAttr("mkldnn_data_type")) == "bfloat16" : false; + is_BFLOAT16 = dev_ctx.HasDnnAttr("onednn_data_type") + ? PADDLE_GET_CONST( + std::string, + dev_ctx.GetDnnAttr("onednn_data_type")) == "bfloat16" + : is_BFLOAT16; bool force_fp32_output = dev_ctx.HasDnnAttr("force_fp32_output") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) diff --git a/paddle/phi/kernels/onednn/conv_transpose_kernel.cc b/paddle/phi/kernels/onednn/conv_transpose_kernel.cc index 31b08c59d18f0a..305576ad168d6b 100644 --- a/paddle/phi/kernels/onednn/conv_transpose_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_transpose_kernel.cc @@ -157,7 +157,13 @@ class ConvTransposeOneDNNHandlerT dev_ctx.GetDnnAttr("mkldnn_data_type")) == "bfloat16" : false; - if (is_BFLOAT16 || std::is_same::value) { + const bool is_onednn_BFLOAT16 = + dev_ctx.HasDnnAttr("onednn_data_type") + ? PADDLE_GET_CONST(std::string, + dev_ctx.GetDnnAttr("onednn_data_type")) == + "bfloat16" + : is_BFLOAT16; + if (is_onednn_BFLOAT16 || std::is_same::value) { data_type = dnnl::memory::data_type::bf16; } @@ -494,11 +500,17 @@ void Conv2dTransposeKernel(const Context& dev_ctx, dev_ctx.GetDnnAttr("mkldnn_data_type")) == "bfloat16" : false; + const bool is_onednn_BFLOAT16 = + dev_ctx.HasDnnAttr("onednn_data_type") + ? PADDLE_GET_CONST(std::string, + dev_ctx.GetDnnAttr("onednn_data_type")) == + "bfloat16" + : is_BFLOAT16; const bool force_fp32_output = dev_ctx.HasDnnAttr("force_fp32_output") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) : false; - const bool use_bfloat16 = (!force_fp32_output && is_BFLOAT16); + const bool use_bfloat16 = (!force_fp32_output && is_onednn_BFLOAT16); if (use_bfloat16) { Execute(dev_ctx, @@ -545,11 +557,17 @@ void Conv2dTransposeBiasKernel(const Context& dev_ctx, dev_ctx.GetDnnAttr("mkldnn_data_type")) == "bfloat16" : false; + const bool is_one_BFLOAT16 = + dev_ctx.HasDnnAttr("onednn_data_type") + ? PADDLE_GET_CONST(std::string, + dev_ctx.GetDnnAttr("onednn_data_type")) == + "bfloat16" + : is_BFLOAT16; const bool force_fp32_output = dev_ctx.HasDnnAttr("force_fp32_output") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) : false; - const bool use_bfloat16 = (!force_fp32_output && is_BFLOAT16); + const bool use_bfloat16 = (!force_fp32_output && is_one_BFLOAT16); if (use_bfloat16) { Execute(dev_ctx, From 152da224c62df6238dfe308dae77f9acef8eddc2 Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 4 Aug 2025 11:55:27 +0800 Subject: [PATCH 2/2] Fix --- paddle/phi/kernels/onednn/conv_kernel.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/onednn/conv_kernel.cc b/paddle/phi/kernels/onednn/conv_kernel.cc index 1968bfe8154a03..313c9171924080 100644 --- a/paddle/phi/kernels/onednn/conv_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_kernel.cc @@ -42,11 +42,12 @@ void ConvKernel(const Context& dev_ctx, dev_ctx.GetDnnAttr("mkldnn_data_type")) == "bfloat16" : false; - is_BFLOAT16 = dev_ctx.HasDnnAttr("onednn_data_type") - ? PADDLE_GET_CONST( - std::string, - dev_ctx.GetDnnAttr("onednn_data_type")) == "bfloat16" - : is_BFLOAT16; + bool is_onednn_BFLOAT16 = + dev_ctx.HasDnnAttr("onednn_data_type") + ? PADDLE_GET_CONST(std::string, + dev_ctx.GetDnnAttr("onednn_data_type")) == + "bfloat16" + : is_BFLOAT16; bool force_fp32_output = dev_ctx.HasDnnAttr("force_fp32_output") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) @@ -64,7 +65,7 @@ void ConvKernel(const Context& dev_ctx, groups, data_format, is_test, - is_BFLOAT16, + is_onednn_BFLOAT16, "", false, force_fp32_output,