Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More precise mkldnn kernel rules in GetExpectedKernelType #29840

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1040,21 +1040,23 @@ static void CheckTensorNANOrInf(const std::string& op_type,
op_type, name));
}

bool OperatorWithKernel::SupportsMKLDNN() const {
bool OperatorWithKernel::SupportsMKLDNN(
const proto::VarType::Type data_type) const {
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
[data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_cpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ ==
LibraryType::kMKLDNN;
LibraryType::kMKLDNN &&
kern_pair.first.data_type_ == data_type;
});
}

bool OperatorWithKernel::CanMKLDNNBeUsed(
const framework::ExecutionContext& ctx) const {
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
bool use_mkldnn_ctx =
ctx.Attr<bool>("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace());
return use_mkldnn_ctx && this->SupportsMKLDNN();
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
}

void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ class OperatorBase {

virtual bool SupportGPU() const { return false; }

virtual bool SupportsMKLDNN() const { return false; }

const std::string& Type() const { return type_; }

bool HasAttr(const std::string& name) const { return attrs_.count(name); }
Expand Down Expand Up @@ -492,9 +490,10 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
bool SupportsMKLDNN() const override;
bool SupportsMKLDNN(proto::VarType::Type data_type) const;

bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;

virtual void InferShape(InferShapeContext* ctx) const = 0;

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = oper.IndicateVarDataType(ctx, name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK the resulting log from the benchmark machine has checks about GPU ops, in the log there is no oneDNN verbose info, so it looks that oneDNN kernels are not run, which would eventually be correlated with the PR.
Could you advise on that, please?

Although op-benchmark-ci checks about GPU ops, does line 96 have additional time cost? How about move it into line 109?

if (library == framework::LibraryType::kPlain && it != oper.Attrs().end()) {
  auto data_type = oper.IndicateVarDataType(ctx, name);
  if (oper.CanMKLDNNBeUsed(ctx, data_type)) {
    xxxx
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luotao1 Thank you for the comment. The data_type variable is used also in line 119:
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);

So it will be needed/calculated at the end of the function, whether the condition you mentioned is True or not (despite mkldnn is to be used or not).

In the way it's implemented in PR, the variable value is calculated only once per function as it was prior to my changes, without additional time cost.
It applies to every occurrence in other op files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wojtuss Thank you for the comment. Having in mind baby sitting the PR-CI-OP-benchmark for more than a week on the same code, I prefer to not refactor the code in the PR. Of course if it's OK with you because I respect you opinion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the way it's implemented in PR, the variable value is calculated only once per function as it was prior to my changes, without additional time cost.

Got it.

I prefer to not refactor the code in the PR. Of course if it's OK with you because I respect you opinion.

@wojtuss What's your opinion?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally understandable.

// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
Expand All @@ -106,13 +107,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
oper.CanMKLDNNBeUsed(ctx)) {
oper.CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
ctx.GetPlace(), layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

class ActivationOp : public framework::OperatorWithKernel {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/addmm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;

Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
Expand Down Expand Up @@ -524,17 +525,17 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/concat_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel {
"All Inputs of Concat OP are Empty!"));
}
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
customized_type_value =
Expand Down Expand Up @@ -556,6 +557,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");

#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
Expand All @@ -564,17 +566,16 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
const std::string data_format = ctx.Attr<std::string>("data_format");
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
customized_type_value = kConvMKLDNNFP32;
}
#endif

auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value);
auto type = framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
library_, customized_type_value);
return type;
}

Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/conv_transpose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Expand All @@ -193,15 +194,13 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
}

framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
Expand Down
9 changes: 4 additions & 5 deletions paddle/fluid/operators/data_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
Expand Down Expand Up @@ -483,18 +483,17 @@ class DataNormGradOp : public framework::OperatorWithKernel {
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detection/prior_box_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, input_input_type)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_div_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/elementwise/elementwise_mul_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ElementwiseMulOp : public ElementwiseOp {
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down Expand Up @@ -280,8 +280,9 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
return (ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims());
};

if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
(ctx.Type() != "elementwise_add_grad" ||
CanMKLDNNElementwiseAddGradBeUsed())) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down Expand Up @@ -331,7 +332,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down Expand Up @@ -384,7 +385,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
}

#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/fused/fusion_gru_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,14 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx)) {
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

void FusionGRUOpMaker::Make() {
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));

#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx)) {
this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif

return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.device_context(), layout, library);
return framework::OpKernelType(data_type, ctx.device_context(), layout,
library);
}

framework::OpKernelType GetKernelTypeForVar(
Expand Down
14 changes: 6 additions & 8 deletions paddle/fluid/operators/gelu_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,16 @@ class GeluOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};

Expand Down Expand Up @@ -86,17 +85,16 @@ class GeluGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto it = this->Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain &&
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};

Expand Down
7 changes: 3 additions & 4 deletions paddle/fluid/operators/interpolate_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,19 @@ class InterpolateOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
framework::LibraryType library = framework::LibraryType::kPlain;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx) &&
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
layout = framework::DataLayout::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
}
#endif

return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

framework::OpKernelType GetKernelTypeForVar(
Expand Down
Loading