Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] Support 0-D tensor for some oneDNN unary kernels #51687

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions paddle/fluid/framework/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ void TransformData(const phi::KernelKey &expected_kernel_type,
if (lin != DataLayout::ONEDNN && lout == DataLayout::ONEDNN) {
// Case1 - transform from Non-ONEDNN OPKernel to ONEDNN OPKernel
// Just set layout/format. No real transform occur

auto out_format = phi::funcs::OneDNNFormatForSize(
in.dims().size(), phi::funcs::ToOneDNNFormat(lin));
out.ShareDataWith(input_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
Expand All @@ -72,10 +69,9 @@ void TransformData(const phi::KernelKey &expected_kernel_type,
// NHWC or NCHW
phi::OneDNNContext::tls().set_cur_paddle_data_layout(lin);
}
dnnl::memory::desc out_mem_desc(
vectorize(out.dims()),
phi::funcs::ToOneDNNDataType(in.dtype()),
out_format);

dnnl::memory::desc out_mem_desc =
phi::funcs::make_memory_desc(out, lin);
out.set_mem_desc(out_mem_desc);
} else {
// Case2 - transform from ONEDNN OPKernel to Non-ONEDNN OPKernel
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/backends/onednn/onednn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ void* to_void_cast(const Type* t) {

inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size,
OneDNNMemoryFormat data_format) {
if (dims_size == 1) {
if (dims_size == 0) {
return OneDNNMemoryFormat::x;
} else if (dims_size == 1) {
return OneDNNMemoryFormat::x;
} else if (dims_size == 2) {
return OneDNNMemoryFormat::nc;
Expand Down
15 changes: 11 additions & 4 deletions paddle/phi/backends/onednn/onednn_reuse.h
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,8 @@ class SoftmaxOneDNNHandler
errors::InvalidArgument(
"The shape of input and output tensor must be identical."));

const int canonical_axis = funcs::CanonicalAxis(axis, x->dims().size());
int rank = x->dims().size() != 0 ? x->dims().size() : 1;
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_scoring, x->mem_desc(), canonical_axis);
}
Expand All @@ -790,8 +791,8 @@ class SoftmaxOneDNNHandler
dnnl::softmax_forward,
dnnl::softmax_backward>(onednn_engine,
cpu_place) {
const int canonical_axis =
funcs::CanonicalAxis(axis, out_grad->dims().size());
int rank = out_grad->dims().size() != 0 ? out_grad->dims().size() : 1;
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_scoring, out->mem_desc(), canonical_axis);
this->AcquireBackwardPrimitiveDescriptor(
Expand Down Expand Up @@ -1646,7 +1647,13 @@ class SoftplusOneDNNHandler : public OneDNNHandlerNoCachingT<T, dnnl::binary> {
dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops);

auto x_tz = phi::vectorize(x->dims());
// if x is a 0-D tensor, then:
// x->dims() is [] and x->mem_desc().dims() is [1]; we should use
// the latter shape since oneDNN doesn't support 0-D shapes.
// else, then:
// x->dims() == x->mem_desc().dims()
// so, we can directly use x->mem_desc().dims() here
auto x_tz = x->mem_desc().dims();
auto beta_tz = std::vector<int64_t>(x_tz.size(), 1);
auto beta_md = dnnl::memory::desc(
beta_tz, OneDNNGetDataType<T>(), GetPlainOneDNNFormat(x_tz.size()));
Expand Down
37 changes: 24 additions & 13 deletions paddle/phi/kernels/funcs/data_layout_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ void* GetDataFromTensor(const DenseTensor& tensor,
}
}

// Constructs a oneDNN memory descriptor from a reference dense tensor and a
// target layout. For the 0-D tensor case, we construct a 1-D memory
// descriptor with shape [1], since oneDNN does not support 0-D shapes yet.
dnnl::memory::desc make_memory_desc(const phi::DenseTensor& ref_tensor,
                                    phi::DataLayout target_layout) {
  auto ref_dims = vectorize<int64_t>(ref_tensor.dims());
  const auto ref_type = ToOneDNNDataType(ref_tensor.dtype());
  // Reject dtypes that have no oneDNN counterpart before building the
  // descriptor, so callers get a clear error instead of a dnnl failure.
  PADDLE_ENFORCE_NE(ref_type,
                    OneDNNDataType::undef,
                    errors::InvalidArgument(
                        "Ref tensor type (%s) is not supported by oneDNN.",
                        ref_tensor.dtype()));

  // oneDNN cannot express rank-0 tensors, so promote shape [] to [1].
  const auto md_dims =
      ref_dims.empty() ? std::vector<int64_t>{1} : std::move(ref_dims);
  const auto md_format =
      OneDNNFormatForSize(md_dims.size(), ToOneDNNFormat(target_layout));
  return dnnl::memory::desc(md_dims, ref_type, md_format);
}

void TransDataLayoutFromOneDNN(DataLayout in_layout,
DataLayout out_layout,
const DenseTensor& in,
Expand All @@ -64,19 +85,7 @@ void TransDataLayoutFromOneDNN(DataLayout in_layout,
auto* dev_ctx = dynamic_cast<OneDNNContext*>(pool.Get(place));
auto& cpu_engine = dev_ctx->GetEngine();

auto in_tz = vectorize<int64_t>(in.dims());
auto out_tz = in_tz;

auto in_type = ToOneDNNDataType(in.dtype());
PADDLE_ENFORCE_NE(
in_type,
OneDNNDataType::undef,
errors::InvalidArgument("Input tensor type (%s) is not supported.",
in.dtype()));

auto out_format =
OneDNNFormatForSize(in_tz.size(), ToOneDNNFormat(out_layout));
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
dnnl::memory::desc out_mem_desc = make_memory_desc(in, out_layout);

// output tensor has the same dims as input. Reorder don't change dims
out->set_mem_desc(out_mem_desc);
Expand All @@ -85,6 +94,8 @@ void TransDataLayoutFromOneDNN(DataLayout in_layout,
// Note(0x45f): Using initialized() to support slice Tensors
// with shapes like [0, 0, 0].
if (in.initialized() && ((in.mem_desc() != out->mem_desc()) || always_copy)) {
auto in_tz = vectorize<int64_t>(in.dims());
auto in_type = ToOneDNNDataType(in.dtype());
void* in_data = GetDataFromTensor(in, in_type);

ReorderOneDNNHandler handler(in_tz, in.dtype(), in_type, cpu_engine);
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/funcs/data_layout_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ void TransDataLayoutFromOneDNN(DataLayout in_layout,
bool always_copy = false);
void* GetDataFromTensor(const DenseTensor& tensor, OneDNNDataType type);

dnnl::memory::desc make_memory_desc(const phi::DenseTensor& ref_tensor,
phi::DataLayout target_layout);

#endif

} // namespace funcs
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/onednn/log_softmax_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class LogSoftmaxOneDNNHandler
const int axis)
: funcs::OneDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
onednn_engine, cpu_place) {
const int rank = x.dims().size() != 0 ? x.dims().size() : 1;
const int canonical_axis = funcs::CanonicalAxis(axis, rank);
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, x.mem_desc(), axis);
dnnl::prop_kind::forward_inference, x.mem_desc(), canonical_axis);
}
};

Expand All @@ -43,7 +45,6 @@ void LogSoftmaxKernel(const Context& dev_ctx,
int axis,
DenseTensor* out) {
const auto& onednn_engine = dev_ctx.GetEngine();
axis = axis >= 0 ? axis : x.dims().size() + axis;

LogSoftmaxOneDNNHandler<T> handler(
onednn_engine, dev_ctx.GetPlace(), x, axis);
Expand Down
7 changes: 1 addition & 6 deletions paddle/phi/kernels/transfer_layout_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
if (src_layout != DataLayout::ONEDNN && dst_layout == DataLayout::ONEDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
auto out_format = funcs::OneDNNFormatForSize(
x.dims().size(), funcs::ToOneDNNFormat(src_layout));

out->ShareDataWith(x);
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
Expand All @@ -148,9 +145,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
OneDNNContext::tls().set_cur_paddle_data_layout(src_layout);
}

dnnl::memory::desc out_mem_desc(vectorize<int64_t>(out->dims()),
funcs::ToOneDNNDataType(x.dtype()),
out_format);
dnnl::memory::desc out_mem_desc = funcs::make_memory_desc(*out, src_layout);
out->set_mem_desc(out_mem_desc);
} else if (src_layout == DataLayout::ONEDNN &&
dst_layout != DataLayout::ONEDNN) {
Expand Down
Loading