Skip to content

Commit

Permalink
Merge pull request #5543 from tensor-tang/ds2
Browse files Browse the repository at this point in the history
add resize of MKLDNNMatrix
  • Loading branch information
luotao1 authored Nov 12, 2017
2 parents 23b9bc0 + e1b8f5f commit 1c31bb9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
7 changes: 1 addition & 6 deletions paddle/math/MKLDNNMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,7 @@ void MKLDNNMatrix::downSpatial() {
}
memory::desc md = memory::desc(dstDims, getDtype(), dstFmt);
memory::primitive_desc pd = memory::primitive_desc(md, getEngine());
mkldnn_primitive_t result;
mkldnn::error::wrap_c_api(
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive");
reset(result);
set_data_handle(data_);
resetMKLDNNMemory(pd, data_);
}

} // namespace paddle
32 changes: 32 additions & 0 deletions paddle/math/MKLDNNMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
m_.reset();
}

/**
* override the CpuMatrix::resize
*/
void resize(size_t newHeight, size_t newWidth) override {
m_->resize(newHeight, newWidth);
if (data_ == m_->getData() && elementCnt_ == newHeight * newWidth) {
return;
}
CpuMatrix::setData(data_);
height_ = newHeight;
width_ = newWidth;
elementCnt_ = newHeight * newWidth;
stride_ = width_;
auto pd = mkldnn::memory::primitive_desc(
mkldnn::memory::desc({(int)newHeight, (int)newWidth},
getDtype(),
mkldnn::memory::format::nc),
getEngine());
resetMKLDNNMemory(pd, data_);
}

/**
* override Matrix::getData
* check data before return
Expand Down Expand Up @@ -215,6 +236,17 @@ class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
memory::format srcFmt,
memory::format dstFmt,
memory::dims dm);
/**
* reset this MKLDNN Memory from primitve desc
*/
void resetMKLDNNMemory(memory::primitive_desc pd, real* data) {
mkldnn_primitive_t result;
mkldnn::error::wrap_c_api(
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive");
reset(result);
set_data_handle(data);
}

private:
// save the CpuMatrixPtr in case the buffer released outside
Expand Down

0 comments on commit 1c31bb9

Please sign in to comment.