diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 7cadf53cc5299..5a1c03327d592 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1826,7 +1826,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
             << in_def.layout;
 
     auto ins_vector = ctx.inputs.at(input_names[i]);
-    std::vector<std::shared_ptr<pt::TensorInterface>> tmp_inputs;
+    std::vector<std::shared_ptr<pt::TensorBase>> tmp_inputs;
 
     for (auto var : ins_vector) {
       auto pt_in = framework::InputVariableToPtTensor(*var, in_def);
@@ -1839,7 +1839,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
     auto out_def = output_defs.at(i);
     auto outs_vector = ctx.outputs.at(output_names[i]);
 
-    std::vector<std::shared_ptr<pt::TensorInterface>> tmp_outputs;
+    std::vector<std::shared_ptr<pt::TensorBase>> tmp_outputs;
     for (auto var : outs_vector) {
       auto pt_out = framework::OutputVariableToPtTensor(var, out_def);
       tmp_outputs.emplace_back(pt_out);
diff --git a/paddle/fluid/framework/tcmpt_utils.cc b/paddle/fluid/framework/tcmpt_utils.cc
index a39e653d0349e..fc38eb42d74c7 100644
--- a/paddle/fluid/framework/tcmpt_utils.cc
+++ b/paddle/fluid/framework/tcmpt_utils.cc
@@ -77,7 +77,7 @@ std::shared_ptr<PtTensorImplT> MakeTensorImpl(
       pt::TransToPtDataLayout(tensor.layout()));
 }
 
-std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
+std::shared_ptr<pt::TensorBase> InputVariableToPtTensor(
     const framework::Variable& variable, const pt::TensorArgDef& arg_def) {
   auto expected_place = pt::TransToFluidPlace(arg_def.backend);
 
@@ -122,7 +122,7 @@ std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
   return nullptr;
 }
 
-std::shared_ptr<pt::TensorInterface> OutputVariableToPtTensor(
+std::shared_ptr<pt::TensorBase> OutputVariableToPtTensor(
    framework::Variable* variable, const pt::TensorArgDef& arg_def) {
   // mutable_data before run kernel, to avoid share output form
   // KernelContext to original tensor
diff --git a/paddle/fluid/framework/tcmpt_utils.h b/paddle/fluid/framework/tcmpt_utils.h
index 27c2c8e9b5dec..4d08692bd9c26 100644
--- a/paddle/fluid/framework/tcmpt_utils.h
+++ b/paddle/fluid/framework/tcmpt_utils.h
@@ -49,9 +49,15 @@ std::shared_ptr<PtTensorImplT> MakeTensorImpl(const Tensor& tensor,
                                               const platform::Place& place,
                                               proto::VarType::Type type);
 
-std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
+template <typename PtTensorImplT>
+void ShareTensorImpl(PtTensorImplT* tensor_impl, LoDTensor* out);
+
+template <typename PtTensorImplT>
+void ShareTensorImpl(PtTensorImplT* tensor_impl, Tensor* out);
+
+std::shared_ptr<pt::TensorBase> InputVariableToPtTensor(
     const framework::Variable& variable, const pt::TensorArgDef& arg_def);
-std::shared_ptr<pt::TensorInterface> OutputVariableToPtTensor(
+std::shared_ptr<pt::TensorBase> OutputVariableToPtTensor(
     framework::Variable* variable, const pt::TensorArgDef& arg_def);
 
 /* Kernel Key translate */
diff --git a/paddle/fluid/framework/tcmpt_utils_test.cc b/paddle/fluid/framework/tcmpt_utils_test.cc
index f1966789c1dde..200bd5429cd46 100644
--- a/paddle/fluid/framework/tcmpt_utils_test.cc
+++ b/paddle/fluid/framework/tcmpt_utils_test.cc
@@ -38,7 +38,7 @@ TEST(TcmptUtils, MakeTensor) {
   ASSERT_EQ(dense_x->data<float>()[0], expect_value[0]);
   ASSERT_EQ(dense_x->data<float>()[1], expect_value[1]);
   ASSERT_EQ(dense_x->backend(), pt::Backend::kCPU);
-  ASSERT_EQ(dense_x->type(), pt::DataType::kFLOAT32);
+  ASSERT_EQ(dense_x->data_type(), pt::DataType::kFLOAT32);
 }
 
 TEST(TcmptUtils, VarToPtTensor) {
@@ -60,7 +60,7 @@ TEST(TcmptUtils, VarToPtTensor) {
   auto tensor_x = InputVariableToPtTensor(v, tensor_def);
   // 3. check result
   ASSERT_EQ(tensor_x->backend(), expect_backend);
-  ASSERT_EQ(tensor_x->type(), pt::DataType::kINT32);
+  ASSERT_EQ(tensor_x->data_type(), pt::DataType::kINT32);
 }
 
 }  // namespace framework
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 87e7e754e3ee8..f65b799e150fc 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -288,7 +288,7 @@ static pt::KernelContext BuildDygraphPtKernelContext(
     auto& in_def = input_defs.at(i);
     auto& ins_vector = ins.at(input_names[i]);
 
-    std::vector<std::shared_ptr<pt::TensorInterface>> tmp_inputs;
+    std::vector<std::shared_ptr<pt::TensorBase>> tmp_inputs;
     for (auto var : ins_vector) {
       const auto& variable = var->Var();
 
@@ -302,7 +302,7 @@ static pt::KernelContext BuildDygraphPtKernelContext(
     auto& out_def = output_defs.at(i);
     auto& outs_vector = outs.at(output_names[i]);
 
-    std::vector<std::shared_ptr<pt::TensorInterface>> tmp_outputs;
+    std::vector<std::shared_ptr<pt::TensorBase>> tmp_outputs;
    for (auto var : outs_vector) {
       auto* variable = var->MutableVar();
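Reviewer note: the static-graph and dygraph context builders above change in lockstep — both now collect inputs and outputs as `std::shared_ptr<pt::TensorBase>`. A minimal sketch of the shared pattern follows; the `CollectInputs` helper and its parameter names are invented for illustration, while the types and calls come from this patch:

```cpp
#include <memory>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/tcmpt_utils.h"
#include "paddle/tcmpt/core/kernel_context.h"

// Sketch: how both BuildPtKernelContext and BuildDygraphPtKernelContext
// assemble kernel inputs after this change. pt::TensorBase is the new
// abstract base, so any tensor implementation can be stored uniformly.
void CollectInputs(const std::vector<paddle::framework::Variable*>& ins_vector,
                   const pt::TensorArgDef& in_def,
                   pt::KernelContext* op_kernel_ctx) {
  std::vector<std::shared_ptr<pt::TensorBase>> tmp_inputs;
  for (auto* var : ins_vector) {
    // InputVariableToPtTensor now returns std::shared_ptr<pt::TensorBase>.
    tmp_inputs.emplace_back(
        paddle::framework::InputVariableToPtTensor(*var, in_def));
  }
  op_kernel_ctx->EmplaceBackInputs(std::move(tmp_inputs));
}
```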
diff --git a/paddle/tcmpt/common/data_type.h b/paddle/tcmpt/common/data_type.h
new file mode 100644
index 0000000000000..03881e6bda1ca
--- /dev/null
+++ b/paddle/tcmpt/common/data_type.h
@@ -0,0 +1,181 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/fluid/platform/complex.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/errors.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace experimental {
+
+using complex64 = ::paddle::platform::complex<float>;
+using complex128 = ::paddle::platform::complex<double>;
+using float16 = ::paddle::platform::float16;
+using bfloat16 = ::paddle::platform::bfloat16;
+
+enum class DataType {
+  kUndef = 0,
+  kBOOL,
+  kINT8,   // Char
+  kUINT8,  // Byte
+  kINT16,
+  kINT32,
+  kUINT32,
+  kINT64,
+  kUINT64,
+  kBFLOAT16,
+  kFLOAT16,
+  kUINT16,
+  kFLOAT32,
+  kFLOAT64,
+  kCOMPLEX64,
+  kCOMPLEX128,
+  kNumDataTypes
+};
+
+inline size_t SizeOf(DataType data_type) {
+  switch (data_type) {
+    case DataType::kBOOL:
+    case DataType::kUINT8:
+    case DataType::kINT8:
+      return 1;
+    case DataType::kFLOAT16:
+    case DataType::kINT16:
+    case DataType::kUINT16:
+      return 2;
+    case DataType::kFLOAT32:
+    case DataType::kINT32:
+    case DataType::kUINT32:
+      return 4;
+    case DataType::kFLOAT64:
+    case DataType::kINT64:
+    case DataType::kUINT64:
+      return 8;
+    case DataType::kUndef:
+    case DataType::kBFLOAT16:
+    case DataType::kCOMPLEX64:
+    case DataType::kCOMPLEX128:
+    case DataType::kNumDataTypes:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Data type %d is not supported by tensor.",
+          static_cast<int>(data_type)));
+      return 0;
+  }
+}
+
+#define PT_FOR_EACH_DATA_TYPE(_)       \
+  _(bool, DataType::kBOOL)             \
+  _(int8_t, DataType::kINT8)           \
+  _(uint8_t, DataType::kUINT8)         \
+  _(int16_t, DataType::kINT16)         \
+  _(int, DataType::kINT32)             \
+  _(int64_t, DataType::kINT64)         \
+  _(bfloat16, DataType::kBFLOAT16)     \
+  _(float16, DataType::kFLOAT16)       \
+  _(float, DataType::kFLOAT32)         \
+  _(double, DataType::kFLOAT64)        \
+  _(complex64, DataType::kCOMPLEX64)   \
+  _(complex128, DataType::kCOMPLEX128)
+
+template <DataType T>
+struct DataTypeToCppType;
+
+template <typename T>
+struct CppTypeToDataType;
+
+#define PT_SPECIALIZE_DataTypeToCppType(cpp_type, data_type) \
+  template <>                                                \
+  struct DataTypeToCppType<data_type> {                      \
+    using type = cpp_type;                                   \
+  };
+
+PT_FOR_EACH_DATA_TYPE(PT_SPECIALIZE_DataTypeToCppType)
+
+#undef PT_SPECIALIZE_DataTypeToCppType
+
+#define PT_SPECIALIZE_CppTypeToDataType(cpp_type, data_type) \
+  template <>                                                \
+  struct CppTypeToDataType<cpp_type> {                       \
+    constexpr static DataType Type() { return data_type; }   \
+  };
+
+PT_FOR_EACH_DATA_TYPE(PT_SPECIALIZE_CppTypeToDataType)
+
+#undef PT_SPECIALIZE_CppTypeToDataType
+
+inline std::ostream& operator<<(std::ostream& os, DataType dtype) {
+  switch (dtype) {
+    case DataType::kUndef:
+      os << "Undefined";
+      break;
+    case DataType::kBOOL:
+      os << "bool";
+      break;
+    case DataType::kINT8:
+      os << "int8";
+      break;
+    case DataType::kUINT8:
+      os << "uint8";
+      break;
+    case DataType::kINT16:
+      os << "int16";
+      break;
+    case DataType::kINT32:
+      os << "int32";
+      break;
+    case DataType::kINT64:
+      os << "int64";
+      break;
+    case DataType::kBFLOAT16:
+      os << "bfloat16";
+      break;
+    case DataType::kFLOAT16:
+      os << "float16";
+      break;
+    case DataType::kFLOAT32:
+      os << "float32";
+      break;
+    case DataType::kFLOAT64:
+      os << "float64";
+      break;
+    case DataType::kCOMPLEX64:
+      os << "complex64";
+      break;
+    case DataType::kCOMPLEX128:
+      os << "complex128";
+      break;
+    default:
+      // TODO(chenweihang): change to enforce later
+      throw std::runtime_error("Invalid DataType type.");
+  }
+  return os;
+}
+
+inline DataType& operator++(DataType& dtype, int) {
+  dtype =
+      DataType(static_cast<std::underlying_type<DataType>::type>(dtype) + 1);
+  return dtype;
+}
+
+}  // namespace experimental
+}  // namespace paddle
+
+namespace pt {
+using DataType = paddle::experimental::DataType;
+}
diff --git a/paddle/tcmpt/core/layout.cc b/paddle/tcmpt/common/layout.h
similarity index 75%
rename from paddle/tcmpt/core/layout.cc
rename to paddle/tcmpt/common/layout.h
index 4f4fd972516da..ae4e43a9f7197 100644
--- a/paddle/tcmpt/core/layout.cc
+++ b/paddle/tcmpt/common/layout.h
@@ -12,11 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/tcmpt/core/layout.h"
+#pragma once
 
-namespace pt {
+namespace paddle {
+namespace experimental {
+
+enum class DataLayout {
+  kUndef = 0,
+  kAny,
+  kNHWC,
+  kNCHW,
+  kMKLDNN,
+  kNumLayouts,
+};
 
-std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
+inline std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
   switch (dtype) {
     case DataLayout::kUndef:
       os << "Undefined";
@@ -40,9 +50,15 @@ std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
   return os;
 }
 
-DataLayout& operator++(DataLayout& layout, int) {
+inline DataLayout& operator++(DataLayout& layout, int) {
   layout = DataLayout(
       static_cast<std::underlying_type<DataLayout>::type>(layout) + 1);
   return layout;
 }
-}  // namespace pt
+
+}  // namespace experimental
+}  // namespace paddle
+
+namespace pt {
+using DataLayout = paddle::experimental::DataLayout;
+}
diff --git a/paddle/tcmpt/core/CMakeLists.txt b/paddle/tcmpt/core/CMakeLists.txt
index 5eadf3db39a64..88573c729c3f2 100644
--- a/paddle/tcmpt/core/CMakeLists.txt
+++ b/paddle/tcmpt/core/CMakeLists.txt
@@ -5,17 +5,15 @@ ELSE()
 ENDIF()
 
 cc_library(backend SRCS backend.cc)
-cc_library(dtype SRCS dtype.cc)
-cc_library(layout SRCS layout.cc)
 
 if(WITH_GPU)
-  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend dtype layout gpu_info)
+  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend gpu_info)
elseif(WITH_ROCM)
-  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend dtype layout gpu_info)
+  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend gpu_info)
else()
-  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend dtype layout)
+  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend)
endif()
 
 cc_library(dense_tensor SRCS dense_tensor.cc DEPS enforce data_type ddim allocator place convert_utils ${MKLDNN_CTX_DEPS})
-cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce backend dtype layout)
+cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce backend)
 cc_library(kernel_context SRCS kernel_context.cc DEPS enforce device_context)
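For orientation: `data_type.h` now carries both directions of the type mapping at compile time, plus `SizeOf` for runtime element widths. A small usage sketch — everything here is provided by the header above, except `BufferBytes`, which is a hypothetical helper:

```cpp
#include <cstddef>
#include <type_traits>

#include "paddle/tcmpt/common/data_type.h"

using paddle::experimental::CppTypeToDataType;
using paddle::experimental::DataType;
using paddle::experimental::DataTypeToCppType;
using paddle::experimental::SizeOf;

// DataType -> C++ type, resolved at compile time.
static_assert(
    std::is_same<DataTypeToCppType<DataType::kFLOAT32>::type, float>::value,
    "kFLOAT32 maps to float");

// C++ type -> DataType, also usable in constant expressions.
static_assert(CppTypeToDataType<int64_t>::Type() == DataType::kINT64,
              "int64_t maps to kINT64");

// Runtime element width, e.g. when sizing a raw buffer:
size_t BufferBytes(DataType dtype, int64_t numel) {
  return SizeOf(dtype) * static_cast<size_t>(numel);  // 4 * numel for kFLOAT32
}
```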
diff --git a/paddle/tcmpt/core/allocator.cc b/paddle/tcmpt/core/allocator.cc
new file mode 100644
index 0000000000000..da1576f81ad71
--- /dev/null
+++ b/paddle/tcmpt/core/allocator.cc
@@ -0,0 +1,19 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/tcmpt/core/allocator.h"
+
+namespace paddle {
+namespace tcmpt {}  // namespace tcmpt
+}  // namespace paddle
diff --git a/paddle/tcmpt/core/allocator.h b/paddle/tcmpt/core/allocator.h
new file mode 100644
index 0000000000000..592f7a4078f80
--- /dev/null
+++ b/paddle/tcmpt/core/allocator.h
@@ -0,0 +1,159 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cstdint>
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace tcmpt {
+
+/// \brief Encapsulates strategies for access/addressing, allocation/
+/// deallocation and construction/destruction of objects.
+class RawAllocator {
+ public:
+  /// \brief Default destructor.
+  virtual ~RawAllocator() = default;
+
+  /// \brief Allocates storage suitable for an array object of n bytes
+  /// and creates the array, but does not construct array elements.
+  /// May throw exceptions.
+  /// \param bytes_size The number of bytes to allocate.
+  /// \return The first address allocated.
+  virtual void* Allocate(size_t bytes_size) = 0;
+
+  /// \brief Deallocates storage pointed to by ptr, which must be a value
+  /// returned by a previous call to allocate that has not been
+  /// invalidated by an intervening call to deallocate. The bytes_size
+  /// must match the value previously passed to allocate.
+  /// \param ptr The first address to deallocate.
+  /// \param bytes_size The number of bytes to deallocate.
+  virtual void Deallocate(void* ptr, size_t bytes_size) = 0;
+
+  /// \brief Get the place value of the allocator and the allocation.
+  /// \return The place value of the allocator and the allocation.
+  virtual const platform::Place& place() const = 0;
+};
+
+/// \brief Fancy pointer with context. The use of this data type
+/// is to be compatible with allocators from different frameworks
+/// without significant performance loss. This class does not
+/// support being inherited.
+class Allocation final {
+ public:
+  using DeleterFnPtr = void (*)(void*);
+
+  Allocation() = default;
+  Allocation(Allocation&&) = default;
+  Allocation& operator=(Allocation&&) = default;
+
+  Allocation(void* data, const platform::Place& place)
+      : data_(data), place_(place) {}
+
+  Allocation(void* data,
+             void* ctx,
+             DeleterFnPtr ctx_deleter,
+             const platform::Place& place)
+      : data_(data), ctx_(ctx, ctx_deleter), place_(place) {}
+
+  void* operator->() const noexcept { return data_; }
+  operator bool() const noexcept { return data_ || ctx_.Get(); }
+  const platform::Place& place() const noexcept { return place_; }
+
+  void Clear() noexcept {
+    data_ = nullptr;
+    ctx_.Clear();
+  }
+
+  /// \brief Statically cast the void pointer of the context object to
+  /// the primitive type. Conversion of any pointer to void* and back
+  /// to pointer to the original cv type preserves its original value.
+  /// \param T The primitive type name of the context pointer.
+  /// \param expected_deleter The destructor passed in to enhance type
+  /// safety checking.
+  template <typename T>
+  T* CastContext(DeleterFnPtr expected_deleter) const noexcept {
+    if (ctx_.deleter() != expected_deleter) {
+      return nullptr;
+    }
+    return static_cast<T*>(ctx_.Get());
+  }
+
+ public:
+  class Context {
+   public:
+    Context() = default;
+    Context(void* ctx, DeleterFnPtr deleter) noexcept : ctx_(ctx),
+                                                        deleter_(deleter) {}
+    Context(Context&& other) noexcept {
+      // Exchange the members explicitly so that a move does not
+      // degenerate into a copy.
+      swap(*this, other);
+    }
+    Context& operator=(Context&& other) noexcept {
+      swap(*this, other);
+      return *this;
+    }
+    ~Context() {
+      if (deleter_) {
+        deleter_(ctx_);
+      }
+    }
+    void Clear() noexcept {
+      ctx_ = nullptr;
+      deleter_ = nullptr;
+    }
+    void* Get() const noexcept { return ctx_; }
+    DeleterFnPtr deleter() const noexcept { return deleter_; }
+    void* Release() noexcept {
+      deleter_ = nullptr;
+      return ctx_;
+    }
+    friend void swap(Context& a, Context& b) noexcept;
+
+   private:
+    void* ctx_{nullptr};
+    DeleterFnPtr deleter_{nullptr};
+  };
+
+ private:
+  void* data_{nullptr};
+  Context ctx_;
+  // TODO(Shixiaowei02): Enum needs to be used instead to reduce
+  // the construction overhead by more than 50%.
+  platform::Place place_;
+};
+
+inline void swap(Allocation::Context& a, Allocation::Context& b) noexcept {
+  ::std::swap(a.ctx_, b.ctx_);
+  ::std::swap(a.deleter_, b.deleter_);
+}
+
+/// \brief Context compatible allocator interface. This allocator is
+/// mainly used for general data structures such as Tensor. The raw
+/// allocator is more universal and efficient.
+class Allocator {
+ public:
+  virtual ~Allocator() = default;
+  virtual Allocation Allocate(size_t bytes_size) = 0;
+};
+
+inline Allocation Allocate(const std::shared_ptr<Allocator>& a, size_t n) {
+  CHECK(a);
+  return a->Allocate(n);
+}
+
+}  // namespace tcmpt
+}  // namespace paddle
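To make the `Allocator` contract concrete, here is roughly the smallest conforming implementation. This is a sketch only — `HostAllocator` is invented for the example, and real backends would route through fluid's memory pools — but it shows how the data pointer, the deleter context, and the place travel together inside the returned `Allocation`:

```cpp
#include <cstdlib>

#include "paddle/tcmpt/core/allocator.h"

namespace paddle {
namespace tcmpt {

// Minimal CPU allocator satisfying the Allocator interface above.
class HostAllocator : public Allocator {
 public:
  Allocation Allocate(size_t bytes_size) override {
    void* data = std::malloc(bytes_size);
    // Use the buffer itself as the deleter context so the Allocation
    // can free the memory even after this allocator is gone.
    return Allocation(
        data, data, [](void* ctx) { std::free(ctx); }, platform::CPUPlace());
  }
};

}  // namespace tcmpt
}  // namespace paddle
```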
diff --git a/paddle/tcmpt/core/convert_utils.h b/paddle/tcmpt/core/convert_utils.h
index a567775811349..011652bdc9572 100644
--- a/paddle/tcmpt/core/convert_utils.h
+++ b/paddle/tcmpt/core/convert_utils.h
@@ -14,9 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/tcmpt/common/data_type.h"
+#include "paddle/tcmpt/common/layout.h"
 #include "paddle/tcmpt/core/backend.h"
-#include "paddle/tcmpt/core/dtype.h"
-#include "paddle/tcmpt/core/layout.h"
 
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/framework/data_layout.h"
@@ -27,6 +27,9 @@ limitations under the License. */
 
 namespace pt {
 
+using DataType = paddle::experimental::DataType;
+using DataLayout = paddle::experimental::DataLayout;
+
 // TODO(chenweihang): Use the original var type as much as possible
 // to avoid transform, such as DataLayout, VarType
 Backend TransToPtBackend(const paddle::platform::Place& place);
diff --git a/paddle/tcmpt/core/dense_tensor.cc b/paddle/tcmpt/core/dense_tensor.cc
index 921f0ee8d9102..9c34b5823d590 100644
--- a/paddle/tcmpt/core/dense_tensor.cc
+++ b/paddle/tcmpt/core/dense_tensor.cc
@@ -31,7 +31,7 @@ using XPUPlace = paddle::platform::XPUPlace;
 using NPUPlace = paddle::platform::NPUPlace;
 using NPUPinnedPlace = paddle::platform::NPUPinnedPlace;
 
-Place DenseTensor::place() const {
+const paddle::platform::Place& DenseTensor::place() const {
   PADDLE_ENFORCE_NOT_NULL(
       allocation_,
       paddle::platform::errors::PreconditionNotMet(
@@ -52,7 +52,7 @@ void DenseTensor::ShareAllocation(
 }
 
 // TODO(chenweihang): Add other place branchs
-Place DenseTensor::GetPlaceByBackend() const {
+paddle::platform::Place DenseTensor::GetPlaceByBackend() const {
   switch (meta_.backend) {
     case Backend::kCPU:
       return CPUPlace();
diff --git a/paddle/tcmpt/core/dense_tensor.h b/paddle/tcmpt/core/dense_tensor.h
index d7853e7cba201..a0d195b740bed 100644
--- a/paddle/tcmpt/core/dense_tensor.h
+++ b/paddle/tcmpt/core/dense_tensor.h
@@ -16,7 +16,7 @@ limitations under the License. */
 
 #include <memory>
 
-#include "paddle/tcmpt/core/tensor_interface.h"
+#include "paddle/tcmpt/core/tensor_base.h"
 #include "paddle/tcmpt/core/tensor_meta.h"
 #include "paddle/tcmpt/core/tensor_status.h"
@@ -30,6 +30,9 @@ class Allocation;
 
 namespace pt {
 
+using TensorBase = paddle::tcmpt::TensorBase;
+using DataType = paddle::experimental::DataType;
+
 // TODO(chenweihang): Allocation still link to framework, Redesign and
 // decoupled Allocation and Allocator?
 using Allocation = paddle::memory::allocation::Allocation;
@@ -47,9 +50,9 @@ using Allocation = paddle::memory::allocation::Allocation;
  *
  * If the memory layout is different, it cannot be described based on the
  * general Allocation, and it needs to be directly inherited from
- * TensorInterface.
+ * TensorBase.
  */
-class DenseTensor : public TensorInterface {
+class DenseTensor : public TensorBase {
  public:
   // Not allowed to initialize a tensor without descriptive metadata
   DenseTensor() = delete;
@@ -71,20 +74,20 @@ class DenseTensor : public TensorInterface {
   DenseTensor(TensorMeta&& meta, TensorStatus&& status)
       : meta_(std::move(meta)), status_(std::move(status)) {}
 
-  ~DenseTensor() override {}
-
   int64_t numel() const override { return meta_.numel; }
 
-  DDim dims() const override { return meta_.dims; }
+  const paddle::framework::DDim& dims() const override { return meta_.dims; }
 
-  DataType type() const override { return meta_.type; }
+  DataType data_type() const override { return meta_.type; }
 
   DataLayout layout() const override { return meta_.layout; }
 
-  Place place() const override;
+  const paddle::platform::Place& place() const override;
 
   Backend backend() const override { return meta_.backend; }
 
+  bool valid() const override { return allocation_ != nullptr; }
+
   bool initialized() const override { return allocation_ != nullptr; }
 
   /* member methods */
@@ -130,7 +133,7 @@ class DenseTensor : public TensorInterface {
 
   void ShareAllocation(const std::shared_ptr<Allocation>& allocation);
 
-  Place GetPlaceByBackend() const;
+  paddle::platform::Place GetPlaceByBackend() const;
 
   size_t MemorySize() const;
diff --git a/paddle/tcmpt/core/dtype.cc b/paddle/tcmpt/core/dtype.cc
deleted file mode 100644
index c9fefc6a69080..0000000000000
--- a/paddle/tcmpt/core/dtype.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/tcmpt/core/dtype.h"
-
-namespace pt {
-
-std::ostream& operator<<(std::ostream& os, DataType dtype) {
-  switch (dtype) {
-    case DataType::kUndef:
-      os << "Undefined";
-      break;
-    case DataType::kBOOL:
-      os << "bool";
-      break;
-    case DataType::kINT8:
-      os << "int8";
-      break;
-    case DataType::kUINT8:
-      os << "uint8";
-      break;
-    case DataType::kINT16:
-      os << "int16";
-      break;
-    case DataType::kINT32:
-      os << "int32";
-      break;
-    case DataType::kINT64:
-      os << "int64";
-      break;
-    case DataType::kBFLOAT16:
-      os << "bfloat16";
-      break;
-    case DataType::kFLOAT16:
-      os << "float16";
-      break;
-    case DataType::kFLOAT32:
-      os << "float32";
-      break;
-    case DataType::kFLOAT64:
-      os << "float64";
-      break;
-    case DataType::kCOMPLEX64:
-      os << "complex64";
-      break;
-    case DataType::kCOMPLEX128:
-      os << "complex128";
-      break;
-    default:
-      // TODO(chenweihang): change to enforce later
-      throw std::runtime_error("Invalid DataType type.");
-  }
-  return os;
-}
-
-DataType& operator++(DataType& dtype, int) {
-  dtype =
-      DataType(static_cast<std::underlying_type<DataType>::type>(dtype) + 1);
-  return dtype;
-}
-
-}  // namespace pt
diff --git a/paddle/tcmpt/core/dtype.h b/paddle/tcmpt/core/dtype.h
deleted file mode 100644
index 1b5c1b8037a21..0000000000000
--- a/paddle/tcmpt/core/dtype.h
+++ /dev/null
@@ -1,105 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <ostream>
-
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/platform/bfloat16.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/float16.h"
-
-namespace pt {
-
-using complex64 = paddle::platform::complex<float>;
-using complex128 = paddle::platform::complex<double>;
-using float16 = paddle::platform::float16;
-using bfloat16 = paddle::platform::bfloat16;
-
-/**
- * [ Why need new data type? ]
- *
- * The Var data type design in framework.proto is confusing, maybe we need
- * polish the VarType in framework.proto.
- *
- * We need to ensure that the operator library is relatively independent
- * and does not depend on the framework. Therefore, before calling the kernel
- * in the Tensor Compute library inside the framework, the internal
- * data type needs to be converted to the data type in the Tensor Compute
- * library.
- *
- */
-enum class DataType {
-  kUndef = 0,
-  kBOOL,
-  kINT8,   // Char
-  kUINT8,  // BYte
-  kINT16,
-  kINT32,
-  kINT64,
-  kBFLOAT16,
-  kFLOAT16,
-  kFLOAT32,
-  kFLOAT64,
-  kCOMPLEX64,
-  kCOMPLEX128,
-  kNumDataTypes
-};
-
-std::ostream& operator<<(std::ostream& os, DataType dtype);
-
-DataType& operator++(DataType& dtype, int);
-
-#define PT_FOR_EACH_DATA_TYPE(_)       \
-  _(bool, DataType::kBOOL)             \
-  _(int8_t, DataType::kINT8)           \
-  _(uint8_t, DataType::kUINT8)         \
-  _(int16_t, DataType::kINT16)         \
-  _(int, DataType::kINT32)             \
-  _(int64_t, DataType::kINT64)         \
-  _(bfloat16, DataType::kBFLOAT16)     \
-  _(float16, DataType::kFLOAT16)       \
-  _(float, DataType::kFLOAT32)         \
-  _(double, DataType::kFLOAT64)        \
-  _(complex64, DataType::kCOMPLEX64)   \
-  _(complex128, DataType::kCOMPLEX128)
-
-template <DataType T>
-struct DataTypeToCppType;
-
-template <typename T>
-struct CppTypeToDataType;
-
-#define PT_SPECIALIZE_DataTypeToCppType(cpp_type, data_type) \
-  template <>                                                \
-  struct DataTypeToCppType<data_type> {                      \
-    using type = cpp_type;                                   \
-  };
-
-PT_FOR_EACH_DATA_TYPE(PT_SPECIALIZE_DataTypeToCppType)
-
-#undef PT_SPECIALIZE_DataTypeToCppType
-
-#define PT_SPECIALIZE_CppTypeToDataType(cpp_type, data_type) \
-  template <>                                                \
-  struct CppTypeToDataType<cpp_type> {                       \
-    constexpr static DataType Type() { return data_type; }   \
-  };
-
-PT_FOR_EACH_DATA_TYPE(PT_SPECIALIZE_CppTypeToDataType)
-
-#undef PT_SPECIALIZE_CppTypeToDataType
-
-}  // namespace pt
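One practical effect of replacing `TensorInterface` with `TensorBase` is that framework-side utilities can describe any tensor implementation through the base class alone; both `operator<<` overloads used below come from the new common headers. A sketch (`Describe` is a hypothetical helper):

```cpp
#include <iostream>

#include "paddle/tcmpt/core/tensor_base.h"

// Works for pt::DenseTensor or any future TensorBase implementation.
void Describe(const paddle::tcmpt::TensorBase& t) {
  std::cout << "dtype=" << t.data_type()  // prints e.g. "float32"
            << " layout=" << t.layout()   // prints e.g. "NCHW"
            << " numel=" << t.numel()
            << " initialized=" << t.initialized() << std::endl;
}
```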
diff --git a/paddle/tcmpt/core/kernel_context.h b/paddle/tcmpt/core/kernel_context.h
index 057cbc11689f1..022d8a6713155 100644
--- a/paddle/tcmpt/core/kernel_context.h
+++ b/paddle/tcmpt/core/kernel_context.h
@@ -16,7 +16,7 @@
 
 #include <utility>
 
-#include "paddle/tcmpt/core/tensor_interface.h"
+#include "paddle/tcmpt/core/tensor_base.h"
 #include "paddle/utils/any.h"
 
 // See Note [ Why still include the fluid headers? ]
@@ -26,6 +26,9 @@
 namespace pt {
 
 using DeviceContext = paddle::platform::DeviceContext;
+using TensorBase = paddle::tcmpt::TensorBase;
+using DataType = paddle::experimental::DataType;
+using DataLayout = paddle::experimental::DataLayout;
 
 /**
  * Note: KernelContext doesn't manage the life if DeviceContext and Tensor
@@ -38,8 +41,8 @@ class KernelContext {
 public:
   explicit KernelContext(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {}
   KernelContext(const DeviceContext& dev_ctx,
-                const std::vector<std::shared_ptr<TensorInterface>>& inputs,
-                const std::vector<std::shared_ptr<TensorInterface>>& outputs,
+                const std::vector<std::shared_ptr<TensorBase>>& inputs,
+                const std::vector<std::shared_ptr<TensorBase>>& outputs,
                 const std::vector<paddle::any>& attrs)
       : dev_ctx_(dev_ctx), inputs_(inputs), outputs_(outputs), attrs_(attrs) {}
 
@@ -48,14 +51,14 @@ class KernelContext {
     return static_cast<const CtxType&>(dev_ctx_);
   }
 
-  void EmplaceBackInput(std::shared_ptr<TensorInterface> input) {
+  void EmplaceBackInput(std::shared_ptr<TensorBase> input) {
     inputs_.emplace_back(input);
     // Record the start and end index of the input
     int index = inputs_.size();
     input_range_.emplace_back(std::pair<int, int>(index, index + 1));
   }
 
-  void EmplaceBackInputs(std::vector<std::shared_ptr<TensorInterface>> inputs) {
+  void EmplaceBackInputs(std::vector<std::shared_ptr<TensorBase>> inputs) {
     for (auto in : inputs) {
       inputs_.emplace_back(in);
     }
@@ -65,15 +68,14 @@ class KernelContext {
         std::pair<int, int>(index, index + inputs.size()));
   }
 
-  void EmplaceBackOutput(std::shared_ptr<TensorInterface> output) {
+  void EmplaceBackOutput(std::shared_ptr<TensorBase> output) {
     outputs_.emplace_back(output);
     // Record the start and end index of the input
     int index = outputs_.size();
     output_range_.emplace_back(std::pair<int, int>(index, index + 1));
   }
 
-  void EmplaceBackOutputs(
-      std::vector<std::shared_ptr<TensorInterface>> outputs) {
+  void EmplaceBackOutputs(std::vector<std::shared_ptr<TensorBase>> outputs) {
     for (auto out : outputs) {
       outputs_.emplace_back(out);
     }
@@ -115,8 +117,8 @@ class KernelContext {
   // TODO(chenweihang): replaced by small_vector
   // TODO(chenweihang): Tensor -> Tensor*, Tensor should by managed `scope`
   // Note: can't use API Tensor here, the inference don't use this API Tensor
-  std::vector<std::shared_ptr<TensorInterface>> inputs_{};
-  std::vector<std::shared_ptr<TensorInterface>> outputs_{};
+  std::vector<std::shared_ptr<TensorBase>> inputs_{};
+  std::vector<std::shared_ptr<TensorBase>> outputs_{};
   std::vector<paddle::any> attrs_{};
 
   // Only contains input like list[Tensor] need `range`
diff --git a/paddle/tcmpt/core/kernel_factory.h b/paddle/tcmpt/core/kernel_factory.h
index 5978264c9ef26..6e4a3fa86dfda 100644
--- a/paddle/tcmpt/core/kernel_factory.h
+++ b/paddle/tcmpt/core/kernel_factory.h
@@ -19,10 +19,10 @@
 #include <unordered_set>
 #include <utility>
 
+#include "paddle/tcmpt/common/data_type.h"
+#include "paddle/tcmpt/common/layout.h"
 #include "paddle/tcmpt/core/backend.h"
-#include "paddle/tcmpt/core/dtype.h"
 #include "paddle/tcmpt/core/kernel_def.h"
-#include "paddle/tcmpt/core/layout.h"
 
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/platform/enforce.h"
@@ -31,6 +31,9 @@
 
 namespace pt {
 
+using DataType = paddle::experimental::DataType;
+using DataLayout = paddle::experimental::DataLayout;
+
 /**
  * [ Naming considerations ]
 *
diff --git a/paddle/tcmpt/core/kernel_registry.h b/paddle/tcmpt/core/kernel_registry.h
index 661d387e9b8e2..caa42546ab054 100644
--- a/paddle/tcmpt/core/kernel_registry.h
+++ b/paddle/tcmpt/core/kernel_registry.h
@@ -336,213 +336,213 @@ struct KernelRegistrar {
 
 // clang-format on
 
-#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>));
+#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>));
 
-#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
-#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
-#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
-#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
-#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
-#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
-#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
-    func_id, \
-    registrar_id, \
-    backend, \
-    layout, \
-    args_def_fn, \
-    meta_kernel_fn, \
-    cpp_dtype, \
-    ...) \
-  static const ::pt::KernelRegistrar PT_CONCATENATE( \
-      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
-      kernel_name, \
-      BACKEND(backend), \
-      DATALAYOUT(layout), \
-      ::pt::CppTypeToDataType<cpp_dtype>::Type(), \
-      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
-      args_def_fn, \
-      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
-  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
-      func_id, \
-      PT_ID, \
-      backend, \
-      layout, \
-      args_def_fn, \
-      meta_kernel_fn, \
-      __VA_ARGS__))
+#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \
+    func_id, \
+    registrar_id, \
+    backend, \
+    layout, \
+    args_def_fn, \
+    meta_kernel_fn, \
+    cpp_dtype, \
+    ...) \
+  static const ::pt::KernelRegistrar PT_CONCATENATE( \
+      __reg_pt_op_kernel_##func_id##_, registrar_id)( \
+      kernel_name, \
+      BACKEND(backend), \
+      DATALAYOUT(layout), \
+      ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
+      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn<cpp_dtype>)>::Parse, \
+      args_def_fn, \
+      PT_KERNEL(meta_kernel_fn<cpp_dtype>)); \
+  PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \
+      func_id, \
+      PT_ID, \
+      backend, \
+      layout, \
+      args_def_fn, \
+      meta_kernel_fn, \
+      __VA_ARGS__))
 
 #define PT_REGISTER_KERNEL_STANDARD( \
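For readers new to this macro family: each `_PT_KERNEL_REGISTRAR_INIT_N` consumes one `cpp_dtype` from the variadic list and recurses into `_INIT_{N-1}`, so an N-dtype registration expands into N static `KernelRegistrar` objects. Roughly — the kernel name, function, and registrar ids below are all hypothetical (`PT_ID` generates the real id suffixes) — a two-dtype registration boils down to:

```cpp
// _PT_KERNEL_REGISTRAR_INIT_2("scale", 0, 0, CPU, NCHW,
//                             args_fn, ScaleKernel, float, double)
// expands, modulo the PT_ID-generated suffixes, to:
static const ::pt::KernelRegistrar __reg_pt_op_kernel_0_0(
    "scale", BACKEND(CPU), DATALAYOUT(NCHW),
    ::paddle::experimental::CppTypeToDataType<float>::Type(),
    ::pt::KernelArgsParseFunctor<decltype(&ScaleKernel<float>)>::Parse,
    args_fn, PT_KERNEL(ScaleKernel<float>));
static const ::pt::KernelRegistrar __reg_pt_op_kernel_0_1(
    "scale", BACKEND(CPU), DATALAYOUT(NCHW),
    ::paddle::experimental::CppTypeToDataType<double>::Type(),
    ::pt::KernelArgsParseFunctor<decltype(&ScaleKernel<double>)>::Parse,
    args_fn, PT_KERNEL(ScaleKernel<double>));
```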
diff --git a/paddle/tcmpt/core/layout.h b/paddle/tcmpt/core/layout.h
deleted file mode 100644
index 4a8a223b62f84..0000000000000
--- a/paddle/tcmpt/core/layout.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include <ostream>
-
-namespace pt {
-
-/**
- * We need to ensure that the operator library is relatively independent
- * and does not depend on the framework. Therefore, before calling the kernel
- * in the Tensor Compute library inside the framework, the internal
- * layout needs to be converted to the data type in the Tensor Compute
- * library.
- *
- * Here we also can use the DataLayout in framework, they are all enum classes.
- */
-enum class DataLayout {
-  kUndef = 0,
-  kAny,
-  kNHWC,
-  kNCHW,
-  kMKLDNN,
-  kNumLayouts,
-};
-
-std::ostream& operator<<(std::ostream& os, DataLayout dtype);
-
-DataLayout& operator++(DataLayout& layout, int);
-
-}  // namespace pt
diff --git a/paddle/tcmpt/core/spatial_tensor.h b/paddle/tcmpt/core/spatial_tensor.h
index 5e51322bb8339..0e5bdd8be50a3 100644
--- a/paddle/tcmpt/core/spatial_tensor.h
+++ b/paddle/tcmpt/core/spatial_tensor.h
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/tcmpt/core/tensor_interface.h"
+#include "paddle/tcmpt/core/tensor_base.h"
 
 namespace pt {
 
@@ -27,7 +27,7 @@ namespace pt {
  */
 
 template <typename AllocationType, typename MetaType>
-class SpatialTensor : public TensorInterface {
+class SpatialTensor : public TensorBase {
 public:
   SpatialTensor(std::shared_ptr<AllocationType> allocation,
                 std::unique_ptr<MetaType> meta,
diff --git a/paddle/tcmpt/core/storage.cc b/paddle/tcmpt/core/storage.cc
new file mode 100644
index 0000000000000..02fbea8d0b3a1
--- /dev/null
+++ b/paddle/tcmpt/core/storage.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/tcmpt/core/storage.h"
+
+namespace paddle {
+namespace tcmpt {
+
+void TensorStorage::Realloc(size_t size) {
+  data_.Clear();
+  data_ = Allocate(alloc_, size);
+  size_ = size;
+}
+
+}  // namespace tcmpt
+}  // namespace paddle
diff --git a/paddle/tcmpt/core/storage.h b/paddle/tcmpt/core/storage.h
new file mode 100644
index 0000000000000..d838d0cd1c957
--- /dev/null
+++ b/paddle/tcmpt/core/storage.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cstddef>
+
+#include "boost/intrusive_ptr.hpp"
+#include "paddle/tcmpt/core/utils/intrusive_ptr.h"
+#include "paddle/tcmpt/core/utils/intrusive_ref_counter.h"
+
+#include "paddle/fluid/platform/place.h"
+#include "paddle/tcmpt/core/allocator.h"
+
+namespace paddle {
+namespace tcmpt {
+
+/// \brief The interface of contiguous storage used for the dense tensor.
+/// It should be used in conjunction with the intrusive pointer. We prohibit
+/// all default copy operations to ensure the integrity of the package.
+class Storage : public intrusive_ref_counter<Storage> {
+ public:
+  Storage() = default;
+  Storage(const Storage&) = delete;
+
+  explicit Storage(Allocation&& data) : data_(std::move(data)) {}
+
+  virtual ~Storage() = default;
+
+  /// \brief Get the mutable data pointer of the storage.
+  /// This function is set to inline to improve performance.
+  /// \return The mutable data pointer of the storage.
+  void* data() const noexcept { return data_.operator->(); }
+
+  virtual size_t size() const = 0;
+  virtual const platform::Place& place() const = 0;
+  virtual bool OwnsMemory() const = 0;
+  virtual void Realloc(size_t n) = 0;
+
+ protected:
+  Allocation data_;
+};
+
+class TensorStorage : public Storage {
+ public:
+  explicit TensorStorage(const std::shared_ptr<Allocator>& a) : alloc_(a) {}
+  TensorStorage(const std::shared_ptr<Allocator>& a, size_t size)
+      : Storage(Allocate(a, size)), alloc_(a), size_(size) {}
+
+  ~TensorStorage() = default;
+
+  void Realloc(size_t size) override;
+
+  size_t size() const noexcept override { return size_; }
+  const platform::Place& place() const override { return data_.place(); }
+  bool OwnsMemory() const noexcept override { return true; }
+  const std::shared_ptr<Allocator>& allocator() const noexcept {
+    return alloc_;
+  }
+
+ private:
+  const std::shared_ptr<Allocator> alloc_;
+  int64_t size_{0};
+};
+
+}  // namespace tcmpt
+}  // namespace paddle
diff --git a/paddle/tcmpt/core/tensor_base.cc b/paddle/tcmpt/core/tensor_base.cc
new file mode 100644
index 0000000000000..05dba1206075d
--- /dev/null
+++ b/paddle/tcmpt/core/tensor_base.cc
@@ -0,0 +1,20 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/tcmpt/core/tensor_base.h"
+#include "paddle/tcmpt/core/utils/type_registry.h"
+
+namespace paddle {
+namespace tcmpt {}
+}
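Usage-wise, `TensorStorage` binds a buffer to the allocator that produced it, which is what makes `Realloc` possible. A sketch — `Grow` is hypothetical, and `host_alloc` could be any `std::shared_ptr<Allocator>` (e.g. the `HostAllocator` sketched earlier):

```cpp
#include <cassert>
#include <memory>

#include "paddle/tcmpt/core/storage.h"

void Grow(const std::shared_ptr<paddle::tcmpt::Allocator>& host_alloc) {
  // Allocates 256 bytes up front via Storage(Allocate(a, size)).
  paddle::tcmpt::TensorStorage storage(host_alloc, 256);
  // ... fill storage.data() ...
  storage.Realloc(1024);  // drops the old buffer, takes a fresh 1 KiB one
  assert(storage.size() == 1024);
  assert(storage.OwnsMemory());
}
```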
diff --git a/paddle/tcmpt/core/tensor_base.h b/paddle/tcmpt/core/tensor_base.h
new file mode 100644
index 0000000000000..240808e3cc492
--- /dev/null
+++ b/paddle/tcmpt/core/tensor_base.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/tcmpt/common/data_type.h"
+#include "paddle/tcmpt/common/layout.h"
+#include "paddle/tcmpt/core/storage.h"
+#include "paddle/tcmpt/core/utils/type_registry.h"
+
+#include "paddle/tcmpt/core/backend.h"
+
+namespace paddle {
+namespace tcmpt {
+
+class TensorBase {
+ public:
+  using DataType = experimental::DataType;
+  using DataLayout = experimental::DataLayout;
+
+  virtual ~TensorBase() = default;
+
+  /// \brief Returns the number of elements contained in tensor.
+  /// \return The number of elements contained in tensor.
+  virtual int64_t numel() const = 0;
+
+  /// \brief Returns the dims of the tensor.
+  /// \return The dims of the tensor.
+  virtual const paddle::framework::DDim& dims() const = 0;
+
+  /// \brief Returns the data type of the tensor.
+  /// \return The data type of the tensor.
+  virtual DataType data_type() const = 0;
+
+  /// \brief Returns the data layout of the tensor.
+  /// \return The data layout of the tensor.
+  virtual DataLayout layout() const = 0;
+
+  /// \brief Returns the data place of the tensor.
+  /// \return The data place of the tensor.
+  virtual const platform::Place& place() const = 0;
+
+  /// \brief Test whether the metadata is valid.
+  /// \return Whether the metadata is valid.
+  virtual bool valid() const = 0;
+
+  /// \brief Test whether the storage is allocated.
+  /// \return Whether the storage is allocated.
+  virtual bool initialized() const = 0;
+
+  virtual pt::Backend backend() const = 0;
+
+  /// \brief Return the type information of the derived class to support
+  /// safe downcasting in a non-RTTI environment.
+  /// \return The type information of the derived class.
+  TypeInfo<TensorBase> type_info() const { return type_info_; }
+
+ private:
+  template <typename T, typename U>
+  friend class TypeInfoTraits;
+  TypeInfo<TensorBase> type_info_{TypeInfo<TensorBase>::kUnknownType};
+};
+
+}  // namespace tcmpt
+}  // namespace paddle
diff --git a/paddle/tcmpt/core/tensor_interface.h b/paddle/tcmpt/core/tensor_interface.h
deleted file mode 100644
index 6991c0d7f7f71..0000000000000
--- a/paddle/tcmpt/core/tensor_interface.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/tcmpt/core/backend.h"
-#include "paddle/tcmpt/core/dtype.h"
-#include "paddle/tcmpt/core/layout.h"
-
-namespace paddle {
-namespace framework {
-class DDim;
-}
-namespace platform {
-class Place;
-}
-}
-
-namespace pt {
-
-// TODO(shixiaowei): replace by new DDim
-using DDim = paddle::framework::DDim;
-
-// TODO(shixiaowei): replace by new Place?
-using Place = paddle::platform::Place;
-
-/**
- * The abstract class of Tensor implemention, it needs to define its basic
- * behavior through inherited classes.
- *
- * TensorInterface allows Tensor to uniformly access various different
- * TensorImpls within the framework. It will not be used as a kernel argument,
- * but only contains the interfaces supported by various TensorImpls.
- * In extreme cases, it can be an empty base class.
- *
- * If we don't use TensorInterface, we may need to use shared_ptr
- * to unify Tensor's API.
- */
-class TensorInterface {
- public:
-  // Not allowed to initialize a tensor without descriptive metadata
-  TensorInterface() = default;
-
-  TensorInterface(const TensorInterface&) = delete;
-  TensorInterface& operator=(const TensorInterface&) = delete;
-  TensorInterface(TensorInterface&&) = delete;
-  TensorInterface& operator=(TensorInterface&&) = delete;
-
-  virtual ~TensorInterface() {}
-
-  virtual int64_t numel() const = 0;
-
-  virtual DDim dims() const = 0;
-
-  virtual DataType type() const = 0;
-
-  virtual DataLayout layout() const = 0;
-
-  virtual Place place() const = 0;
-
-  virtual Backend backend() const = 0;
-
-  virtual bool initialized() const = 0;
-};
-
-}  // namespace pt
diff --git a/paddle/tcmpt/core/tensor_meta.h b/paddle/tcmpt/core/tensor_meta.h
index de564a44de36e..3cc557e05b4c1 100644
--- a/paddle/tcmpt/core/tensor_meta.h
+++ b/paddle/tcmpt/core/tensor_meta.h
@@ -16,9 +16,9 @@ limitations under the License. */
 
 #include <cstdint>
 
+#include "paddle/tcmpt/common/data_type.h"
+#include "paddle/tcmpt/common/layout.h"
 #include "paddle/tcmpt/core/backend.h"
-#include "paddle/tcmpt/core/dtype.h"
-#include "paddle/tcmpt/core/layout.h"
 
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/framework/ddim.h"
@@ -28,6 +28,9 @@ limitations under the License. */
 
 namespace pt {
 
+using DataType = paddle::experimental::DataType;
+using DataLayout = paddle::experimental::DataLayout;
+
 // template <typename T>
 // using Vector = paddle::framework::Vector<T>;
diff --git a/paddle/tcmpt/core/tensor_status.h b/paddle/tcmpt/core/tensor_status.h
index 1328c88dd014a..1eb56397414b5 100644
--- a/paddle/tcmpt/core/tensor_status.h
+++ b/paddle/tcmpt/core/tensor_status.h
@@ -14,9 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/tcmpt/common/data_type.h"
+#include "paddle/tcmpt/common/layout.h"
 #include "paddle/tcmpt/core/backend.h"
-#include "paddle/tcmpt/core/dtype.h"
-#include "paddle/tcmpt/core/layout.h"
 
 namespace pt {
 
diff --git a/paddle/tcmpt/core/utils/CMakeLists.txt b/paddle/tcmpt/core/utils/CMakeLists.txt
new file mode 100644
index 0000000000000..e69de29bb2d1d
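To show how a concrete tensor is meant to plug into this registry: the CRTP mixin stamps `type_info_` in its constructor (via the friend declaration in `TensorBase`), and `classof` then gives `dynamic_cast`-free downcast checks. Everything named `MyTensor` below is hypothetical; the trivial overrides exist only to make the sketch compile:

```cpp
#include "paddle/tcmpt/core/tensor_base.h"

namespace paddle {
namespace tcmpt {

class MyTensor : public TensorBase,
                 public TypeInfoTraits<TensorBase, MyTensor> {
 public:
  static const char* name() { return "MyTensor"; }  // key in the registry

  int64_t numel() const override { return 0; }
  const paddle::framework::DDim& dims() const override { return dims_; }
  DataType data_type() const override { return DataType::kUndef; }
  DataLayout layout() const override { return DataLayout::kUndef; }
  const platform::Place& place() const override { return place_; }
  bool valid() const override { return false; }
  bool initialized() const override { return false; }
  pt::Backend backend() const override { return pt::Backend::kCPU; }

 private:
  paddle::framework::DDim dims_;
  platform::Place place_;
};

// RTTI-free check: compares the registered TypeInfo ids.
bool IsMyTensor(const TensorBase& t) { return MyTensor::classof(&t); }

}  // namespace tcmpt
}  // namespace paddle
```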
*/ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace tcmpt { + +template +class intrusive_ptr { + public: + using this_type = intrusive_ptr; + constexpr intrusive_ptr() noexcept = default; + + ~intrusive_ptr() { + if (px) { + intrusive_ptr_release(px); + } + } + + intrusive_ptr(intrusive_ptr&& rhs) noexcept : px(rhs.px) { rhs.px = nullptr; } + + template ::value>> + intrusive_ptr(intrusive_ptr&& rhs) noexcept : px(rhs.get()) { + rhs.reset(); + } + + void reset() { this_type().swap(*this); } + + void reset(T* rhs) { this_type(rhs).swap(*this); } + + void reset(T* rhs, bool add_ref) { this_type(rhs, add_ref).swap(*this); } + + T* get() const noexcept { return px; } + + T* detach() noexcept { + T* ret = px; + px = nullptr; + return ret; + } + + T& operator*() const { + PADDLE_ENFORCE_NOT_NULL( + px, + platform::errors::PreconditionNotMet( + "The pointer must be non-null before the dereference operation.")); + return *px; + } + + T* operator->() const { + PADDLE_ENFORCE_NOT_NULL( + px, + platform::errors::PreconditionNotMet( + "The pointer must be non-null before the dereference operation.")); + return px; + } + + void swap(intrusive_ptr& rhs) noexcept { + T* tmp = px; + px = rhs.px; + rhs.px = tmp; + } + + private: + template ::value>> + explicit intrusive_ptr(U* p, bool add_ref = true) : px(p) { + if (px && add_ref) { + intrusive_ptr_add_ref(px); + } + } + + template + friend intrusive_ptr make_intrusive(Args&&...); + template + friend intrusive_ptr copy_intrusive(const intrusive_ptr&); + + T* px{nullptr}; +}; + +template +inline bool operator==(const intrusive_ptr& a, + const intrusive_ptr& b) noexcept { + return a.get() == b.get(); +} + +template +inline bool operator!=(const intrusive_ptr& a, + const intrusive_ptr& b) noexcept { + return a.get() != b.get(); +} + +template +inline bool operator==(const intrusive_ptr& a, U* b) noexcept { + return a.get() == b; +} + +template +inline bool operator!=(const intrusive_ptr& a, U* b) noexcept { + return a.get() != b; +} + +template +inline bool operator==(T* a, const intrusive_ptr& b) noexcept { + return a == b.get(); +} + +template +inline bool operator!=(T* a, const intrusive_ptr& b) noexcept { + return a != b.get(); +} + +template +inline bool operator==(const intrusive_ptr& p, std::nullptr_t) noexcept { + return p.get() == nullptr; +} + +template +inline bool operator==(std::nullptr_t, const intrusive_ptr& p) noexcept { + return p.get() == nullptr; +} + +template +inline bool operator!=(const intrusive_ptr& p, std::nullptr_t) noexcept { + return p.get() != nullptr; +} + +template +inline bool operator!=(std::nullptr_t, const intrusive_ptr& p) noexcept { + return p.get() != nullptr; +} + +template +inline intrusive_ptr make_intrusive(Args&&... args) { + return intrusive_ptr(new T(std::forward(args)...), false); +} + +template +inline intrusive_ptr copy_intrusive(const intrusive_ptr& rhs) { + return intrusive_ptr(rhs.get(), true); +} + +} // namespace tcmpt +} // namespace paddle diff --git a/paddle/tcmpt/core/utils/intrusive_ref_counter.h b/paddle/tcmpt/core/utils/intrusive_ref_counter.h new file mode 100644 index 0000000000000..1c93bede71df1 --- /dev/null +++ b/paddle/tcmpt/core/utils/intrusive_ref_counter.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
diff --git a/paddle/tcmpt/core/utils/intrusive_ref_counter.h b/paddle/tcmpt/core/utils/intrusive_ref_counter.h
new file mode 100644
index 0000000000000..1c93bede71df1
--- /dev/null
+++ b/paddle/tcmpt/core/utils/intrusive_ref_counter.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <atomic>
+
+namespace paddle {
+namespace tcmpt {
+
+template <typename DerivedT>
+class intrusive_ref_counter;
+template <typename DerivedT>
+void intrusive_ptr_add_ref(const intrusive_ref_counter<DerivedT>* p) noexcept;
+template <typename DerivedT>
+void intrusive_ptr_release(const intrusive_ref_counter<DerivedT>* p) noexcept;
+
+template <typename DerivedT>
+class intrusive_ref_counter {
+ public:
+  constexpr intrusive_ref_counter() noexcept : ref_(1) {}
+  virtual ~intrusive_ref_counter() = default;
+
+  unsigned int use_count() const noexcept { return ref_.load(); }
+
+ protected:
+  intrusive_ref_counter(const intrusive_ref_counter&) = delete;
+  intrusive_ref_counter& operator=(const intrusive_ref_counter&) = delete;
+
+  friend void intrusive_ptr_add_ref<DerivedT>(
+      const intrusive_ref_counter<DerivedT>* p) noexcept;
+  friend void intrusive_ptr_release<DerivedT>(
+      const intrusive_ref_counter<DerivedT>* p) noexcept;
+
+ private:
+  mutable std::atomic_int_fast32_t ref_;
+};
+
+template <typename DerivedT>
+inline void intrusive_ptr_add_ref(
+    const intrusive_ref_counter<DerivedT>* p) noexcept {
+  p->ref_.fetch_add(1, std::memory_order_relaxed);
+}
+
+template <typename DerivedT>
+inline void intrusive_ptr_release(
+    const intrusive_ref_counter<DerivedT>* p) noexcept {
+  if (p->ref_.load(std::memory_order_acquire) == 0 ||
+      p->ref_.fetch_sub(1) == 0) {
+    delete static_cast<const DerivedT*>(p);
+  }
+}
+
+}  // namespace tcmpt
+}  // namespace paddle
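The `intrusive_ptr_add_ref` / `intrusive_ptr_release` pair above is the whole concurrency story: increments use a relaxed `fetch_add`, while the release path pairs an acquire load with the `fetch_sub` so the deleting thread observes all prior writes. A short sketch of the observable counts, reusing the hypothetical `Widget` from the previous sketch:

    void Counts() {
      auto p = paddle::tcmpt::make_intrusive<Widget>(1);
      CHECK_EQ(p->use_count(), 1u);    // adopted at the initial count of 1
      {
        auto q = paddle::tcmpt::copy_intrusive(p);
        CHECK_EQ(p->use_count(), 2u);  // relaxed fetch_add in add_ref
      }                                // q's destructor runs release
      CHECK_EQ(p->use_count(), 1u);
    }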
diff --git a/paddle/tcmpt/core/utils/type_info.h b/paddle/tcmpt/core/utils/type_info.h
new file mode 100644
index 0000000000000..ba5bc641b94b2
--- /dev/null
+++ b/paddle/tcmpt/core/utils/type_info.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+namespace paddle {
+namespace tcmpt {
+
+template <typename BaseT>
+class TypeRegistry;
+
+template <typename BaseT>
+class TypeInfo {
+ public:
+  const std::string& name() const;
+
+  int8_t id() const { return id_; }
+
+  bool operator==(TypeInfo other) const { return id_ == other.id(); }
+  bool operator!=(TypeInfo other) const { return id_ != other.id(); }
+
+  static const TypeInfo kUnknownType;
+
+ private:
+  friend class TypeRegistry<BaseT>;
+  explicit TypeInfo(int8_t id) : id_(id) {}
+  int8_t id_;
+};
+
+template <typename BaseT, typename DerivedT>
+class TypeInfoTraits {
+ public:
+  static const TypeInfo<BaseT> kType;
+  TypeInfoTraits() {
+    static_cast<BaseT*>(static_cast<DerivedT*>(this))->type_info_ = kType;
+  }
+  static bool classof(const BaseT* obj) { return obj->type_info() == kType; }
+};
+
+template <typename BaseT>
+TypeInfo<BaseT> RegisterStaticType(const std::string& type);
+
+template <typename BaseT, typename DerivedT>
+const TypeInfo<BaseT> TypeInfoTraits<BaseT, DerivedT>::kType =
+    RegisterStaticType<BaseT>(DerivedT::name());
+
+}  // namespace tcmpt
+}  // namespace paddle
diff --git a/paddle/tcmpt/core/utils/type_registry.h b/paddle/tcmpt/core/utils/type_registry.h
new file mode 100644
index 0000000000000..52b699a0dd413
--- /dev/null
+++ b/paddle/tcmpt/core/utils/type_registry.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <map>
+#include <mutex>
+#include <vector>
+
+#include "paddle/tcmpt/core/utils/type_info.h"
+
+namespace paddle {
+namespace tcmpt {
+
+template <typename BaseT>
+class TypeRegistry {
+ public:
+  TypeRegistry(const TypeRegistry&) = delete;
+  TypeRegistry& operator=(const TypeRegistry&) = delete;
+
+  static TypeRegistry& GetInstance();
+
+  TypeInfo<BaseT> RegisterType(const std::string& type);
+  const std::string& GetTypeName(TypeInfo<BaseT> info) const;
+
+ private:
+  TypeRegistry() = default;
+  mutable std::mutex mutex_;
+  std::vector<std::string> names_;
+  std::map<std::string, int8_t> name_to_id_;
+};
+
+template <typename BaseT>
+TypeRegistry<BaseT>& TypeRegistry<BaseT>::GetInstance() {
+  static TypeRegistry<BaseT> registry;
+  return registry;
+}
+
+template <typename BaseT>
+TypeInfo<BaseT> TypeRegistry<BaseT>::RegisterType(const std::string& type) {
+  std::lock_guard<std::mutex> guard(mutex_);
+  assert(name_to_id_.find(type) == name_to_id_.end());
+  assert(names_.size() < std::numeric_limits<int8_t>::max());
+  int8_t id = names_.size();
+  names_.emplace_back(type);
+  name_to_id_[type] = id;
+  return TypeInfo<BaseT>(id);
+}
+
+template <typename BaseT>
+const std::string& TypeRegistry<BaseT>::GetTypeName(
+    TypeInfo<BaseT> info) const {
+  std::lock_guard<std::mutex> guard(mutex_);
+  int8_t id = info.id();
+  assert(id >= 0);
+  assert(static_cast<size_t>(id) < names_.size());
+  return names_[id];
+}
+
+template <typename BaseT>
+TypeInfo<BaseT> RegisterStaticType(const std::string& type) {
+  return TypeRegistry<BaseT>::GetInstance().RegisterType(type);
+}
+
+template <typename BaseT>
+const std::string& TypeInfo<BaseT>::name() const {
+  return TypeRegistry<BaseT>::GetInstance().GetTypeName(*this);
+}
+
+template <typename BaseT>
+const TypeInfo<BaseT> TypeInfo<BaseT>::kUnknownType =
+    RegisterStaticType<BaseT>("Unknown");
+
+}  // namespace tcmpt
+}  // namespace paddle
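`TypeInfoTraits` is a CRTP hook: constructing a derived object stamps that class's registered `TypeInfo` into the base, after which `classof()` gives cheap LLVM-style type tests without RTTI. A hedged sketch with hypothetical `Base`/`Derived` stand-ins (the intended consumer in this series is `TensorBase` and its subclasses); `type_info_` is left public only to keep the sketch short:

    #include <string>
    #include "paddle/tcmpt/core/utils/type_registry.h"

    class Base {
     public:
      virtual ~Base() = default;
      paddle::tcmpt::TypeInfo<Base> type_info() const { return type_info_; }
      // Written by TypeInfoTraits' constructor; public only for brevity here.
      paddle::tcmpt::TypeInfo<Base> type_info_{
          paddle::tcmpt::TypeInfo<Base>::kUnknownType};
    };

    class Derived : public Base,
                    public paddle::tcmpt::TypeInfoTraits<Base, Derived> {
     public:
      static std::string name() { return "Derived"; }  // registry key
    };

    bool IsDerived(const Base* b) { return Derived::classof(b); }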
diff --git a/paddle/tcmpt/hapi/include/creation.h b/paddle/tcmpt/hapi/include/creation.h
index f502adb2e2472..d2d68e3bb7e61 100644
--- a/paddle/tcmpt/hapi/include/creation.h
+++ b/paddle/tcmpt/hapi/include/creation.h
@@ -14,7 +14,7 @@
 
 #pragma once
 
-#include "paddle/tcmpt/core/dtype.h"
+#include "paddle/tcmpt/common/data_type.h"
 #include "paddle/tcmpt/core/scalar.h"
 #include "paddle/tcmpt/hapi/include/tensor.h"
diff --git a/paddle/tcmpt/hapi/include/tensor.h b/paddle/tcmpt/hapi/include/tensor.h
index eb64d66435c90..ccca911cf8c86 100644
--- a/paddle/tcmpt/hapi/include/tensor.h
+++ b/paddle/tcmpt/hapi/include/tensor.h
@@ -18,7 +18,7 @@ limitations under the License. */
 #include <memory>
 #include <utility>
 
-#include "paddle/tcmpt/core/tensor_interface.h"
+#include "paddle/tcmpt/core/tensor_base.h"
 
 /**
  * [ Why still include the fluid headers? ]
@@ -73,7 +73,7 @@
  * letters and underscores.
  *
  * Note: Tensor cannot be inherited. The heterogeneous Tensor implementation
- * can be achieved by inheriting the underlying TensorInterface.
+ * can be achieved by inheriting the underlying TensorBase.
 *
  * Note: This Tensor API is suitable for training and custom operators,
  * another simple Tensor design may be required for inference.
@@ -88,10 +88,10 @@ class Tensor final {
 
   /**
    * @description: Use a TensorImpl pointer to construct a Tensor
-   * @param {shared_ptr<pt::TensorInterface>} tensor_impl
+   * @param {shared_ptr<pt::TensorBase>} tensor_impl
    * @return {Tensor}
    */
-  explicit Tensor(std::shared_ptr<pt::TensorInterface> tensor_impl)
+  explicit Tensor(std::shared_ptr<pt::TensorBase> tensor_impl)
       : impl_(std::move(tensor_impl)) {
     if (impl_.get() == nullptr) {
       throw std::runtime_error("TensorImpl with nullptr is not supported");
@@ -111,14 +111,14 @@ class Tensor final {
    * @param None
    * @return {DDim}
    */
-  pt::DDim shape() const { return impl_->dims(); }
+  paddle::framework::DDim shape() const { return impl_->dims(); }
 
   /**
    * @description: Return the data type of current Tensor.
    * @param None
    * @return {DataType}
    */
-  pt::DataType type() const { return impl_->type(); }
+  pt::DataType type() const { return impl_->data_type(); }
 
   /**
    * @description: Return the layout of current Tensor.
@@ -133,7 +133,7 @@ class Tensor final {
    * @param None
    * @return {Place}
    */
-  pt::Place place() const { return impl_->place(); }
+  paddle::platform::Place place() const { return impl_->place(); }
 
   /**
    * Backend judgment APIs, shield the concept of Backend.
@@ -163,16 +163,16 @@ class Tensor final {
   /**
    * @description: Return the implemention of current Tensor.
    * @param None
-   * @return {std::shared_ptr<pt::TensorInterface>}
+   * @return {std::shared_ptr<pt::TensorBase>}
    */
-  std::shared_ptr<pt::TensorInterface> impl() const { return impl_; }
+  std::shared_ptr<pt::TensorBase> impl() const { return impl_; }
 
   /**
    * @description: Set the implemention of current Tensor.
-   * @param {std::shared_ptr<pt::TensorInterface>}
+   * @param {std::shared_ptr<pt::TensorBase>}
    * @return None
    */
-  void set_impl(const std::shared_ptr<pt::TensorInterface>& impl) {
+  void set_impl(const std::shared_ptr<pt::TensorBase>& impl) {
     impl_ = impl;
   }
 
@@ -245,7 +245,7 @@ class Tensor final {
    * heterogeneous Tensor implementation, so that the API level can be unified
    * to one `Tensor`.
    */
-  std::shared_ptr<pt::TensorInterface> impl_;
+  std::shared_ptr<pt::TensorBase> impl_;
 
   /**
    * [ Why need abstract AutogradMetaInterface here? ]
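With `impl_` retyped to `std::shared_ptr<pt::TensorBase>`, any concrete tensor that inherits `TensorBase` can back the API-level `Tensor`. A hedged sketch of the resulting flow (illustrative only; `Tensor` refers to the hapi class above, and construction of the concrete impl is out of scope here):

    // How a concrete impl flows into the API object after this change.
    Tensor WrapImpl(std::shared_ptr<pt::TensorBase> impl) {
      Tensor t(std::move(impl));              // throws on nullptr, as guarded
      paddle::framework::DDim d = t.shape();  // forwards to impl_->dims()
      pt::DataType dt = t.type();             // forwards to impl_->data_type()
      (void)d;
      (void)dt;
      return t;
    }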
diff --git a/paddle/tcmpt/kernels/cpu/utils.cc b/paddle/tcmpt/kernels/cpu/utils.cc
index 7550934d70be4..a50cfad481693 100644
--- a/paddle/tcmpt/kernels/cpu/utils.cc
+++ b/paddle/tcmpt/kernels/cpu/utils.cc
@@ -14,8 +14,8 @@ limitations under the License. */
 
 #include "paddle/tcmpt/kernels/cpu/utils.h"
 #include "paddle/fluid/memory/memcpy.h"
+#include "paddle/tcmpt/common/data_type.h"
 #include "paddle/tcmpt/core/convert_utils.h"
-#include "paddle/tcmpt/core/dtype.h"
 
 namespace pt {
@@ -37,8 +37,8 @@ void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst) {
           << dst_place;
   dst->Resize(src.dims());
   dst->mutable_meta()->layout = src.meta().layout;
-  auto size = src.numel() *
-              paddle::framework::SizeOfType(TransToProtoVarType(src.type()));
+  auto size = src.numel() * paddle::framework::SizeOfType(
+                                TransToProtoVarType(src.data_type()));
 
   if (paddle::platform::is_cpu_place(src_place) &&
       paddle::platform::is_cpu_place(dst_place)) {
diff --git a/paddle/tcmpt/kernels/cuda/math.cu b/paddle/tcmpt/kernels/cuda/math.cu
index f0d76744f68bd..113971126a71f 100644
--- a/paddle/tcmpt/kernels/cuda/math.cu
+++ b/paddle/tcmpt/kernels/cuda/math.cu
@@ -78,7 +78,7 @@ void Mean(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out) {
       TensorMeta(paddle::framework::make_ddim(
                      {static_cast<int64_t>(temp_storage_bytes)}),
                  pt::TransToPtBackend(dev_ctx.GetPlace()),
-                 x.type(),
+                 x.data_type(),
                  x.layout()),
       TensorStatus());
   auto* temp_storage = tmp.mutable_data<uint8_t>();
diff --git a/paddle/tcmpt/kernels/cuda/utils.cu b/paddle/tcmpt/kernels/cuda/utils.cu
index b8483d17cfc24..00b32e2fbb10a 100644
--- a/paddle/tcmpt/kernels/cuda/utils.cu
+++ b/paddle/tcmpt/kernels/cuda/utils.cu
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/memory/memcpy.h"
+#include "paddle/tcmpt/common/data_type.h"
 #include "paddle/tcmpt/core/convert_utils.h"
-#include "paddle/tcmpt/core/dtype.h"
 #include "paddle/tcmpt/core/kernel_registry.h"
 #include "paddle/tcmpt/kernels/cuda/utils.h"
@@ -40,8 +40,8 @@ void Copy(const CUDAContext& dev_ctx,
           << dst_place;
   dst->Resize(src.dims());
   dst->mutable_meta()->layout = src.meta().layout;
-  auto size = src.numel() *
-              paddle::framework::SizeOfType(TransToProtoVarType(src.type()));
+  auto size = src.numel() * paddle::framework::SizeOfType(
+                                TransToProtoVarType(src.data_type()));
 
   if (paddle::platform::is_cuda_pinned_place(src_place) &&  // NOLINT
       paddle::platform::is_cuda_pinned_place(dst_place)) {
diff --git a/paddle/tcmpt/tests/dense_tensor_test.cc b/paddle/tcmpt/tests/dense_tensor_test.cc
index 633e787159444..138ef1e30e76e 100644
--- a/paddle/tcmpt/tests/dense_tensor_test.cc
+++ b/paddle/tcmpt/tests/dense_tensor_test.cc
@@ -28,7 +28,7 @@ TEST(DenseTensor, Constructor) {
                    pt::TensorStatus());
   ASSERT_EQ(tensor.dims().size(), 2);
   ASSERT_EQ(tensor.backend(), pt::Backend::kCPU);
-  ASSERT_EQ(tensor.type(), pt::DataType::kFLOAT32);
+  ASSERT_EQ(tensor.data_type(), pt::DataType::kFLOAT32);
   ASSERT_EQ(tensor.layout(), pt::DataLayout::kNCHW);
 }
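For downstream code the visible change in these hunks is mechanical: the dtype accessor is now spelled `data_type()`, so every `src.type()` call site migrates as sketched below (hypothetical caller; the conversion and sizing helpers shown are the ones already used in this patch):

    size_t SizeInBytes(const pt::DenseTensor& src) {
      // type() was renamed to data_type(); TransToProtoVarType and
      // SizeOfType are unchanged.
      return src.numel() * paddle::framework::SizeOfType(
                               pt::TransToProtoVarType(src.data_type()));
    }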