merge bugfix #48364
Silv3S committed Nov 25, 2022
1 parent fe384a1 commit 6328e85
Showing 5 changed files with 32 additions and 6 deletions.
27 changes: 21 additions & 6 deletions paddle/fluid/framework/operator.cc
@@ -3230,15 +3230,30 @@ void OperatorWithKernel::BuildPhiKernelContext(
   }
   VLOG(4) << "Done attributes";
 
+  // Clear All old attrs before add new attrs,
+  // because sometimes old attrs may be misused.
 #if defined(PADDLE_WITH_MKLDNN)
-  phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
-  std::string empty_string = "";
-  one_dnn_ctx->SetDnnAttr("fuse_activation", empty_string);
-  one_dnn_ctx->SetDnnAttr("fuse_alpha", 0.0f);
-  one_dnn_ctx->SetDnnAttr("fuse_beta", 0.0f);
-  one_dnn_ctx->SetDnnAttr("fused_output_scale", 1.0f);
+  if (phi::OneDNNContext::classof(dev_ctx)) {
+    phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
+    one_dnn_ctx->ClearDnnAttr();
+  }
 #endif
 
+  // Note(YuanRisheng): Now, we can't open code below.
+  // Because some unittest run OLD dygraph and ExtraAttr is not supported in OLD
+  // dygraph. So, here we use trick that dev_ctx is a global object. We can
+  // store ExtraAttr in static graph and when unittest run OLD dygraph, it can
+  // obtain these ExtraAttr. We can open this code when OLD dygraph is no longer
+  // used.
+  /*
+#if defined(PADDLE_WITH_CUDA)
+  if(phi::GPUContext::classof(dev_ctx)) {
+    phi::GPUContext* gpu_dnn_ctx = static_cast<phi::GPUContext*>(dev_ctx);
+    gpu_dnn_ctx->ClearDnnAttr();
+  }
+#endif
+  */
+
   // For compatible with Op with extra attrs for specific backend
 #if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
   auto& runtime_attrs = RuntimeAttrs();
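The substance of the fix is in the hunk above: rather than resetting a hard-coded subset of oneDNN attrs ("fuse_activation", "fuse_alpha", ...) to default values, BuildPhiKernelContext now drops every previously stored DNN attr, and only does so when dev_ctx really is a OneDNNContext (checked via classof). Below is a minimal, self-contained C++ sketch of that pattern; the Toy* names and the "scale_x" attr are illustrative stand-ins, not Paddle/phi classes or the commit's code.

// Standalone sketch (not part of the commit): a toy model of the new
// operator.cc path -- classof()-guarded dispatch plus clearing the whole
// attribute map instead of resetting a hard-coded subset.
#include <iostream>
#include <map>
#include <string>
#include <variant>

using Attribute = std::variant<std::string, float>;

enum class ContextKind { CPU, OneDNN };

struct ToyDeviceContext {
  explicit ToyDeviceContext(ContextKind kind) : kind_(kind) {}
  virtual ~ToyDeviceContext() = default;
  ContextKind kind_;
};

struct ToyOneDNNContext : ToyDeviceContext {
  ToyOneDNNContext() : ToyDeviceContext(ContextKind::OneDNN) {}

  // LLVM-style RTTI check, mirroring phi::OneDNNContext::classof(dev_ctx).
  static bool classof(const ToyDeviceContext* ctx) {
    return ctx->kind_ == ContextKind::OneDNN;
  }

  void SetDnnAttr(const std::string& name, Attribute attr) {
    dnn_attrs_[name] = std::move(attr);
  }
  bool HasDnnAttr(const std::string& name) const {
    return dnn_attrs_.count(name) != 0;
  }
  // Mirrors the new ClearDnnAttr(): drop every attr, not just a known subset.
  void ClearDnnAttr() { dnn_attrs_.clear(); }

 private:
  std::map<std::string, Attribute> dnn_attrs_;
};

// Toy stand-in for the touched part of BuildPhiKernelContext().
void BuildToyKernelContext(ToyDeviceContext* dev_ctx) {
  if (ToyOneDNNContext::classof(dev_ctx)) {
    static_cast<ToyOneDNNContext*>(dev_ctx)->ClearDnnAttr();
  }
}

int main() {
  ToyOneDNNContext onednn_ctx;
  // Attrs left over from a previously built fused op; an attr outside the old
  // hard-coded reset list (e.g. the example name "scale_x") would have leaked
  // into the next op under the old per-attr reset.
  onednn_ctx.SetDnnAttr("fuse_activation", std::string("relu"));
  onednn_ctx.SetDnnAttr("scale_x", 0.5f);

  BuildToyKernelContext(&onednn_ctx);
  std::cout << std::boolalpha << onednn_ctx.HasDnnAttr("scale_x") << "\n";  // false

  ToyDeviceContext cpu_ctx(ContextKind::CPU);
  BuildToyKernelContext(&cpu_ctx);  // no-op: classof() rejects non-oneDNN contexts
}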
4 changes: 4 additions & 0 deletions paddle/phi/backends/gpu/gpu_context.cc
@@ -740,6 +740,8 @@ struct GPUContext::Impl {
     dnn_attrs_[attr_name] = attr;
   }
 
+  void ClearDnnAttr() { dnn_attrs_.clear(); }
+
   // use one flag for all handles?
   // they should be accessed consistently
   bool owned_{false};
@@ -1042,4 +1044,6 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
   return impl_->SetDnnAttr(attr_name, std::move(attr));
 }
 
+void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
+
 } // namespace phi
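Both context classes follow the same split visible in this file: dnn_attrs_ lives in a private Impl struct and the public ClearDnnAttr() only forwards to it. A small self-contained sketch of that forwarding shape; ToyContext is a stand-in, not the real phi::GPUContext, and the Impl is nested inline here only so the sketch compiles in one file (in gpu_context.cc it is defined in the .cc).

// Standalone sketch of the Impl/forwarding split used by GPUContext and
// OneDNNContext; ToyContext and the float attr type are toy simplifications.
#include <map>
#include <memory>
#include <string>

class ToyContext {
 public:
  ToyContext() : impl_(std::make_unique<Impl>()) {}

  void SetDnnAttr(const std::string& name, float attr) {
    impl_->SetDnnAttr(name, attr);
  }
  // The method added by the commit: the public class only forwards,
  // the Impl owns and clears the actual attribute map.
  void ClearDnnAttr() { impl_->ClearDnnAttr(); }

 private:
  struct Impl {
    void SetDnnAttr(const std::string& name, float attr) {
      dnn_attrs_[name] = attr;
    }
    void ClearDnnAttr() { dnn_attrs_.clear(); }
    std::map<std::string, float> dnn_attrs_;
  };
  std::unique_ptr<Impl> impl_;
};

int main() {
  ToyContext ctx;
  ctx.SetDnnAttr("fused_output_scale", 1.0f);
  ctx.ClearDnnAttr();  // empties the map held by the Impl
}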
1 change: 1 addition & 0 deletions paddle/phi/backends/gpu/gpu_context.h
@@ -172,6 +172,7 @@ class PADDLE_API GPUContext : public DeviceContext,
   bool HasDnnAttr(const std::string& attr_name) const;
   const Attribute& GetDnnAttr(const std::string& attr_name) const;
   void SetDnnAttr(const std::string& attr_name, Attribute attr);
+  void ClearDnnAttr();
 
   static const char* name() { return "GPUContext"; }
4 changes: 4 additions & 0 deletions paddle/phi/backends/onednn/onednn_context.cc
@@ -301,6 +301,8 @@ struct OneDNNContext::Impl {
     dnn_attrs_[attr_name] = attr;
   }
 
+  void ClearDnnAttr() { dnn_attrs_.clear(); }
+
   bool HasDnnInput(const std::string& input_name) const {
     return dnn_inputs_.count(input_name) != 0UL;
   }
@@ -429,6 +431,8 @@ bool OneDNNContext::HasDnnInput(const std::string& input_name) const {
   return impl_->HasDnnInput(input_name);
 }
 
+void OneDNNContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
+
 const DenseTensor* OneDNNContext::GetDnnInput(
     const std::string& input_name) const {
   return impl_->GetDnnInput(input_name);
2 changes: 2 additions & 0 deletions paddle/phi/backends/onednn/onednn_context.h
@@ -146,6 +146,8 @@ class OneDNNContext : public CPUContext {
   const DenseTensor* GetDnnInput(const std::string& input_name) const;
   void SetDnnInput(const std::string& input_name, const DenseTensor* input);
 
+  void ClearDnnAttr();
+
   void SetInputsName(const TensorNameMap& inputs_name);
 
   void SetOutputsName(const TensorNameMap& outputs_name);
