Skip to content

Commit

Permalink
Fix for undefined format for 6 dim tensor (#38553)
Browse files Browse the repository at this point in the history
* 6 dims fix

* removed limitations of max dims
  • Loading branch information
jakpiase authored Dec 31, 2021
1 parent 31efec5 commit 730ccd9
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions paddle/fluid/platform/mkldnn_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,13 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
if (inner_nblks == 0) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::ncdhw;
} else {
return dnnl::memory::format_tag::ndhwc;
return dnnl::memory::format_tag::abcde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4]) {
return dnnl::memory::format_tag::acbde;
} else if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return dnnl::memory::format_tag::acdeb;
}
} else if (inner_nblks == 1) {
if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
Expand Down Expand Up @@ -310,6 +314,10 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
strides[2] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::abcdef;
} else if (strides[0] >= strides[2] && strides[2] >= strides[1] &&
strides[1] >= strides[3] && strides[3] >= strides[4] &&
strides[4] >= strides[5]) {
return dnnl::memory::format_tag::acbdef;
}
}
}
Expand Down Expand Up @@ -397,7 +405,9 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
return MKLDNNMemoryFormat::ndhwc;
}
} else if (dims_size == 6) {
return MKLDNNMemoryFormat::abcdef;
if (data_format == MKLDNNMemoryFormat::nchw) {
return MKLDNNMemoryFormat::abcdef;
}
}
return data_format;
}
Expand Down

0 comments on commit 730ccd9

Please sign in to comment.