diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc
index c86a06bedef8d..1181a81266976 100644
--- a/paddle/phi/core/dense_tensor.cc
+++ b/paddle/phi/core/dense_tensor.cc
@@ -59,10 +59,6 @@ DenseTensor::DenseTensor(const DenseTensor& other) {
   storage_properties_ =
       std::move(CopyStorageProperties(other.storage_properties_));
   inplace_version_counter_ = other.inplace_version_counter_;
-
-#ifdef PADDLE_WITH_DNNL
-  mem_desc_ = other.mem_desc_;
-#endif
 }
 
 DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
@@ -74,9 +70,6 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
   storage_properties_ =
       std::move(CopyStorageProperties(other.storage_properties_));
   inplace_version_counter_ = other.inplace_version_counter_;
-#ifdef PADDLE_WITH_DNNL
-  mem_desc_ = other.mem_desc_;
-#endif
   return *this;
 }
 
@@ -85,9 +78,6 @@ DenseTensor& DenseTensor::operator=(DenseTensor&& other) noexcept {
   std::swap(holder_, other.holder_);
   storage_properties_ = std::move(other.storage_properties_);
   std::swap(inplace_version_counter_, other.inplace_version_counter_);
-#ifdef PADDLE_WITH_DNNL
-  mem_desc_ = other.mem_desc_;
-#endif
   return *this;
 }
 
diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h
index bcc2b07a89e3a..b78cec1483272 100644
--- a/paddle/phi/core/dense_tensor.h
+++ b/paddle/phi/core/dense_tensor.h
@@ -22,12 +22,6 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/utils/test_macros.h"
 
-/* @jim19930609: Move to MKLDNN_Tensor in the future
- */
-#ifdef PADDLE_WITH_DNNL
-#include "dnnl.hpp"  // NOLINT
-#endif
-
 namespace phi {
 
 class DenseTensorUtils;
@@ -290,18 +284,6 @@ class TEST_API DenseTensor : public TensorBase,
   std::shared_ptr<TensorInplaceVersion> inplace_version_counter_ =
       std::make_shared<TensorInplaceVersion>();
 
-/* @jim19930609: This is a hack
-In general, it is badly designed to fuse MKLDNN-specific objects into a
-generic Tensor.
-We temporarily leave them here to unblock Tensor Unification progress.
-In the final state, we should come up with a MKLDNN_Tensor and move the
-following codes there.
-*/
-#ifdef PADDLE_WITH_DNNL
-  /// \brief memory descriptor of tensor which have layout set as kMKLDNN
-  dnnl::memory::desc mem_desc_;
-#endif
-
 #ifndef PADDLE_WITH_CUSTOM_KERNEL
 #include "paddle/phi/core/dense_tensor.inl"
 #endif
diff --git a/paddle/phi/core/dense_tensor.inl b/paddle/phi/core/dense_tensor.inl
index 19101e7093f74..a8672b2171143 100644
--- a/paddle/phi/core/dense_tensor.inl
+++ b/paddle/phi/core/dense_tensor.inl
@@ -97,22 +97,12 @@
 std::vector<DenseTensor> Split(int64_t split_size, int64_t axis) const;
 
 std::vector<DenseTensor> Chunk(int64_t chunks, int64_t axis) const;
 
-/* @jim19930609: This is a hack
-In general, it is badly designed to fuse MKLDNN-specific objects into a
-generic Tensor.
-We temporarily leave them here to unblock Tensor Unification progress.
-In the final state, we should come up with a MKLDNN_Tensor and move the
-following codes there.
-*/
 #ifdef PADDLE_WITH_DNNL
 
 public:
 const dnnl::memory::desc& mem_desc() const;
 
-inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
-  mem_desc_ = mem_desc;
-  meta_.layout = DataLayout::ONEDNN;
-}
+void set_mem_desc(const dnnl::memory::desc& mem_desc);
 
 #endif
diff --git a/paddle/phi/core/dense_tensor_impl.cc b/paddle/phi/core/dense_tensor_impl.cc
index 5fa43647da19c..39efb048e7432 100644
--- a/paddle/phi/core/dense_tensor_impl.cc
+++ b/paddle/phi/core/dense_tensor_impl.cc
@@ -377,7 +377,30 @@ std::vector<DenseTensor> DenseTensor::Chunk(int64_t chunks,
 }
 
 #ifdef PADDLE_WITH_DNNL
-const dnnl::memory::desc& DenseTensor::mem_desc() const { return mem_desc_; }
+const dnnl::memory::desc& DenseTensor::mem_desc() const {
+  if (storage_properties_ == nullptr) {
+    static dnnl::memory::desc undef_desc = dnnl::memory::desc();
+    return undef_desc;
+  }
+  return this->storage_properties<OneDNNStorageProperties>().mem_desc;
+}
+
+void DenseTensor::set_mem_desc(const dnnl::memory::desc& mem_desc) {
+  if (storage_properties_ == nullptr) {
+    storage_properties_ = std::make_unique<OneDNNStorageProperties>();
+    static_cast<OneDNNStorageProperties*>(storage_properties_.get())->mem_desc =
+        mem_desc;
+    meta_.layout = DataLayout::ONEDNN;
+  } else if (OneDNNStorageProperties::classof(storage_properties_.get())) {
+    static_cast<OneDNNStorageProperties*>(storage_properties_.get())->mem_desc =
+        mem_desc;
+    meta_.layout = DataLayout::ONEDNN;
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "The actual type of storage_properties is inconsistent with the type "
+        "of the template parameter passed in."));
+  }
+}
 #endif
 
 // NOTE: For historical reasons, this interface has a special behavior,
@@ -394,9 +417,6 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) {
   meta_.strides = src.meta_.strides;
   storage_properties_ =
       std::move(CopyStorageProperties(src.storage_properties_));
-#ifdef PADDLE_WITH_DNNL
-  mem_desc_ = src.mem_desc_;
-#endif
   return *this;
 }
 