From 466ce03d3e718e81a066f85e97e789d94c14d636 Mon Sep 17 00:00:00 2001 From: zyfncg <1370305206@qq.com> Date: Wed, 20 Oct 2021 10:31:22 +0800 Subject: [PATCH] Rename tcmpt to pten (#23) * rename tcmpt to pten * update omitted files for rename to pten * update omitted file for rename to pten --- cmake/generic.cmake | 22 ++-- cmake/{tcmpt.cmake => pten.cmake} | 10 +- paddle/CMakeLists.txt | 2 +- paddle/fluid/framework/CMakeLists.txt | 8 +- paddle/fluid/framework/operator.cc | 24 ++-- paddle/fluid/framework/operator.h | 12 +- .../{tcmpt_utils.cc => pten_utils.cc} | 107 +++++++++-------- .../framework/{tcmpt_utils.h => pten_utils.h} | 22 ++-- ...tcmpt_utils_test.cc => pten_utils_test.cc} | 19 +-- paddle/fluid/imperative/CMakeLists.txt | 4 +- paddle/fluid/imperative/prepared_operator.cc | 29 ++--- paddle/fluid/imperative/prepared_operator.h | 8 +- paddle/fluid/inference/CMakeLists.txt | 8 +- paddle/fluid/operators/CMakeLists.txt | 4 +- paddle/fluid/operators/dot_op.h | 18 +-- paddle/fluid/operators/fill_any_like_op.h | 16 +-- paddle/fluid/operators/mean_op.h | 14 +-- paddle/fluid/operators/scale_op.h | 16 +-- paddle/fluid/operators/sign_op.h | 16 +-- paddle/fluid/platform/CMakeLists.txt | 2 +- paddle/fluid/pybind/op_function_generator.cc | 4 +- paddle/pten/CMakeLists.txt | 15 +++ paddle/pten/api/CMakeLists.txt | 21 ++++ paddle/{tcmpt => pten}/api/all.cc | 4 +- paddle/{tcmpt => pten}/api/all.h | 12 +- paddle/{tcmpt => pten}/api/include/core.h | 10 +- paddle/{tcmpt => pten}/api/include/creation.h | 4 +- .../{tcmpt => pten}/api/include/infershape.h | 4 +- paddle/{tcmpt => pten}/api/include/linalg.h | 4 +- .../api/include/manipulation.h | 4 +- paddle/{tcmpt => pten}/api/include/math.h | 4 +- paddle/{tcmpt => pten}/api/include/symbols.h | 2 +- paddle/{tcmpt => pten}/common/data_type.h | 2 +- paddle/{tcmpt => pten}/common/layout.h | 2 +- paddle/{tcmpt => pten}/core/CMakeLists.txt | 0 paddle/{tcmpt => pten}/core/allocator.cc | 6 +- paddle/{tcmpt => pten}/core/allocator.h | 16 ++- paddle/{tcmpt => pten}/core/backend.cc | 6 +- paddle/{tcmpt => pten}/core/backend.h | 4 +- paddle/{tcmpt => pten}/core/convert_utils.cc | 22 ++-- paddle/{tcmpt => pten}/core/convert_utils.h | 10 +- paddle/{tcmpt => pten}/core/dense_tensor.cc | 10 +- paddle/{tcmpt => pten}/core/dense_tensor.h | 25 ++-- paddle/{tcmpt => pten}/core/kernel_context.cc | 4 +- paddle/{tcmpt => pten}/core/kernel_context.h | 7 +- paddle/{tcmpt => pten}/core/kernel_def.h | 4 +- paddle/{tcmpt => pten}/core/kernel_factory.cc | 14 ++- paddle/{tcmpt => pten}/core/kernel_factory.h | 12 +- paddle/{tcmpt => pten}/core/kernel_registry.h | 109 +++++++++--------- paddle/{tcmpt => pten}/core/kernel_utils.h | 16 +-- paddle/{tcmpt => pten}/core/scalar.h | 4 +- paddle/{tcmpt => pten}/core/spatial_tensor.h | 6 +- paddle/{tcmpt => pten}/core/storage.cc | 8 +- paddle/{tcmpt => pten}/core/storage.h | 18 +-- paddle/{tcmpt => pten}/core/tensor_base.cc | 8 +- paddle/{tcmpt => pten}/core/tensor_base.h | 24 ++-- paddle/{tcmpt => pten}/core/tensor_meta.h | 10 +- paddle/{tcmpt => pten}/core/tensor_status.h | 10 +- .../{tcmpt => pten}/core/utils/CMakeLists.txt | 0 .../core/utils/intrusive_ptr.h | 10 +- .../core/utils/intrusive_ref_counter.h | 6 +- paddle/{tcmpt => pten}/core/utils/type_info.h | 6 +- .../core/utils/type_registry.h | 8 +- paddle/pten/hapi/CMakeLists.txt | 3 + paddle/{tcmpt => pten}/hapi/all.cc | 2 +- paddle/{tcmpt => pten}/hapi/all.h | 8 +- .../{tcmpt => pten}/hapi/include/creation.h | 19 +-- paddle/{tcmpt => pten}/hapi/include/linalg.h | 2 +- .../hapi/include/manipulation.h | 2 +- paddle/{tcmpt => pten}/hapi/include/math.h | 2 +- paddle/{tcmpt => pten}/hapi/include/tensor.h | 24 ++-- paddle/pten/hapi/lib/CMakeLists.txt | 4 + paddle/{tcmpt => pten}/hapi/lib/creation.cc | 28 ++--- .../hapi/lib/kernel_generate.h | 24 ++-- paddle/{tcmpt => pten}/hapi/lib/linalg.cc | 28 ++--- .../{tcmpt => pten}/hapi/lib/manipulation.cc | 18 +-- paddle/{tcmpt => pten}/hapi/lib/math.cc | 20 ++-- .../{tcmpt => pten}/infershape/CMakeLists.txt | 0 paddle/{tcmpt => pten}/infershape/binary.cc | 6 +- paddle/{tcmpt => pten}/infershape/binary.h | 6 +- paddle/{tcmpt => pten}/infershape/unary.cc | 6 +- paddle/{tcmpt => pten}/infershape/unary.h | 6 +- paddle/{tcmpt => pten}/kernels/CMakeLists.txt | 2 +- .../kernels/common/eigen/CMakeLists.txt | 0 .../kernels/common/eigen/common.h | 31 ++--- .../kernels/common/eigen/dot.h | 20 ++-- .../kernels/common/eigen/fill.h | 10 +- .../kernels/common/eigen/mean.h | 12 +- .../kernels/common/eigen/scale.h | 12 +- .../kernels/common/eigen/sign.h | 12 +- .../kernels/cpu/CMakeLists.txt | 2 +- .../{tcmpt => pten}/kernels/cpu/creation.cc | 12 +- paddle/{tcmpt => pten}/kernels/cpu/creation.h | 8 +- paddle/{tcmpt => pten}/kernels/cpu/linalg.cc | 10 +- paddle/{tcmpt => pten}/kernels/cpu/linalg.h | 6 +- .../kernels/cpu/manipulation.cc | 16 +-- .../kernels/cpu/manipulation.h | 8 +- paddle/{tcmpt => pten}/kernels/cpu/math.cc | 22 ++-- paddle/{tcmpt => pten}/kernels/cpu/math.h | 8 +- paddle/{tcmpt => pten}/kernels/cpu/utils.cc | 12 +- paddle/{tcmpt => pten}/kernels/cpu/utils.h | 8 +- .../kernels/cuda/CMakeLists.txt | 2 +- .../{tcmpt => pten}/kernels/cuda/creation.cu | 12 +- .../{tcmpt => pten}/kernels/cuda/creation.h | 8 +- paddle/{tcmpt => pten}/kernels/cuda/linalg.cu | 12 +- paddle/{tcmpt => pten}/kernels/cuda/linalg.h | 6 +- .../kernels/cuda/manipulation.cu | 16 +-- .../kernels/cuda/manipulation.h | 6 +- paddle/{tcmpt => pten}/kernels/cuda/math.cu | 30 ++--- paddle/{tcmpt => pten}/kernels/cuda/math.h | 6 +- paddle/{tcmpt => pten}/kernels/cuda/utils.cu | 14 +-- paddle/{tcmpt => pten}/kernels/cuda/utils.h | 8 +- .../kernels/mkldnn/CMakeLists.txt | 0 .../kernels/npu/CMakeLists.txt | 0 .../kernels/xpu/CMakeLists.txt | 0 paddle/{tcmpt => pten}/module/CMakeLists.txt | 0 paddle/{tcmpt => pten}/tests/CMakeLists.txt | 0 paddle/{tcmpt => pten}/tests/backend_test.cc | 2 +- .../tests/dense_tensor_test.cc | 21 ++-- paddle/{tcmpt => pten}/tests/dtype_test.cc | 0 .../tests/kernel_factory_test.cc | 7 +- paddle/{tcmpt => pten}/tests/layout_test.cc | 0 paddle/{tcmpt => pten}/tests/test_copy_api.cc | 32 ++--- paddle/{tcmpt => pten}/tests/test_dot_api.cc | 36 +++--- paddle/{tcmpt => pten}/tests/test_fill_api.cc | 69 +++++------ .../{tcmpt => pten}/tests/test_flatten_api.cc | 24 ++-- paddle/{tcmpt => pten}/tests/test_mean_api.cc | 24 ++-- paddle/tcmpt/CMakeLists.txt | 15 --- paddle/tcmpt/api/CMakeLists.txt | 21 ---- paddle/tcmpt/hapi/CMakeLists.txt | 3 - paddle/tcmpt/hapi/lib/CMakeLists.txt | 4 - 131 files changed, 820 insertions(+), 813 deletions(-) rename cmake/{tcmpt.cmake => pten.cmake} (84%) rename paddle/fluid/framework/{tcmpt_utils.cc => pten_utils.cc} (68%) rename paddle/fluid/framework/{tcmpt_utils.h => pten_utils.h} (83%) rename paddle/fluid/framework/{tcmpt_utils_test.cc => pten_utils_test.cc} (73%) create mode 100644 paddle/pten/CMakeLists.txt create mode 100644 paddle/pten/api/CMakeLists.txt rename paddle/{tcmpt => pten}/api/all.cc (89%) rename paddle/{tcmpt => pten}/api/all.h (69%) rename paddle/{tcmpt => pten}/api/include/core.h (75%) rename paddle/{tcmpt => pten}/api/include/creation.h (87%) rename paddle/{tcmpt => pten}/api/include/infershape.h (88%) rename paddle/{tcmpt => pten}/api/include/linalg.h (88%) rename paddle/{tcmpt => pten}/api/include/manipulation.h (87%) rename paddle/{tcmpt => pten}/api/include/math.h (88%) rename paddle/{tcmpt => pten}/api/include/symbols.h (94%) rename paddle/{tcmpt => pten}/common/data_type.h (99%) rename paddle/{tcmpt => pten}/common/layout.h (98%) rename paddle/{tcmpt => pten}/core/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/core/allocator.cc (82%) rename paddle/{tcmpt => pten}/core/allocator.h (93%) rename paddle/{tcmpt => pten}/core/backend.cc (94%) rename paddle/{tcmpt => pten}/core/backend.h (97%) rename paddle/{tcmpt => pten}/core/convert_utils.cc (94%) rename paddle/{tcmpt => pten}/core/convert_utils.h (90%) rename paddle/{tcmpt => pten}/core/dense_tensor.cc (95%) rename paddle/{tcmpt => pten}/core/dense_tensor.h (88%) rename paddle/{tcmpt => pten}/core/kernel_context.cc (88%) rename paddle/{tcmpt => pten}/core/kernel_context.h (97%) rename paddle/{tcmpt => pten}/core/kernel_def.h (97%) rename paddle/{tcmpt => pten}/core/kernel_factory.cc (91%) rename paddle/{tcmpt => pten}/core/kernel_factory.h (97%) rename paddle/{tcmpt => pten}/core/kernel_registry.h (91%) rename paddle/{tcmpt => pten}/core/kernel_utils.h (96%) rename paddle/{tcmpt => pten}/core/scalar.h (97%) rename paddle/{tcmpt => pten}/core/spatial_tensor.h (95%) rename paddle/{tcmpt => pten}/core/storage.cc (85%) rename paddle/{tcmpt => pten}/core/storage.h (85%) rename paddle/{tcmpt => pten}/core/tensor_base.cc (81%) rename paddle/{tcmpt => pten}/core/tensor_base.h (81%) rename paddle/{tcmpt => pten}/core/tensor_meta.h (96%) rename paddle/{tcmpt => pten}/core/tensor_status.h (92%) rename paddle/{tcmpt => pten}/core/utils/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/core/utils/intrusive_ptr.h (95%) rename paddle/{tcmpt => pten}/core/utils/intrusive_ref_counter.h (96%) rename paddle/{tcmpt => pten}/core/utils/type_info.h (95%) rename paddle/{tcmpt => pten}/core/utils/type_registry.h (94%) create mode 100644 paddle/pten/hapi/CMakeLists.txt rename paddle/{tcmpt => pten}/hapi/all.cc (95%) rename paddle/{tcmpt => pten}/hapi/all.h (77%) rename paddle/{tcmpt => pten}/hapi/include/creation.h (56%) rename paddle/{tcmpt => pten}/hapi/include/linalg.h (95%) rename paddle/{tcmpt => pten}/hapi/include/manipulation.h (94%) rename paddle/{tcmpt => pten}/hapi/include/math.h (94%) rename paddle/{tcmpt => pten}/hapi/include/tensor.h (91%) create mode 100644 paddle/pten/hapi/lib/CMakeLists.txt rename paddle/{tcmpt => pten}/hapi/lib/creation.cc (65%) rename paddle/{tcmpt => pten}/hapi/lib/kernel_generate.h (86%) rename paddle/{tcmpt => pten}/hapi/lib/linalg.cc (69%) rename paddle/{tcmpt => pten}/hapi/lib/manipulation.cc (77%) rename paddle/{tcmpt => pten}/hapi/lib/math.cc (75%) rename paddle/{tcmpt => pten}/infershape/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/infershape/binary.cc (96%) rename paddle/{tcmpt => pten}/infershape/binary.h (94%) rename paddle/{tcmpt => pten}/infershape/unary.cc (96%) rename paddle/{tcmpt => pten}/infershape/unary.h (94%) rename paddle/{tcmpt => pten}/kernels/CMakeLists.txt (94%) rename paddle/{tcmpt => pten}/kernels/common/eigen/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/kernels/common/eigen/common.h (86%) rename paddle/{tcmpt => pten}/kernels/common/eigen/dot.h (72%) rename paddle/{tcmpt => pten}/kernels/common/eigen/fill.h (91%) rename paddle/{tcmpt => pten}/kernels/common/eigen/mean.h (82%) rename paddle/{tcmpt => pten}/kernels/common/eigen/scale.h (85%) rename paddle/{tcmpt => pten}/kernels/common/eigen/sign.h (84%) rename paddle/{tcmpt => pten}/kernels/cpu/CMakeLists.txt (89%) rename paddle/{tcmpt => pten}/kernels/cpu/creation.cc (84%) rename paddle/{tcmpt => pten}/kernels/cpu/creation.h (88%) rename paddle/{tcmpt => pten}/kernels/cpu/linalg.cc (92%) rename paddle/{tcmpt => pten}/kernels/cpu/linalg.h (93%) rename paddle/{tcmpt => pten}/kernels/cpu/manipulation.cc (89%) rename paddle/{tcmpt => pten}/kernels/cpu/manipulation.h (88%) rename paddle/{tcmpt => pten}/kernels/cpu/math.cc (85%) rename paddle/{tcmpt => pten}/kernels/cpu/math.h (91%) rename paddle/{tcmpt => pten}/kernels/cpu/utils.cc (89%) rename paddle/{tcmpt => pten}/kernels/cpu/utils.h (87%) rename paddle/{tcmpt => pten}/kernels/cuda/CMakeLists.txt (94%) rename paddle/{tcmpt => pten}/kernels/cuda/creation.cu (84%) rename paddle/{tcmpt => pten}/kernels/cuda/creation.h (89%) rename paddle/{tcmpt => pten}/kernels/cuda/linalg.cu (86%) rename paddle/{tcmpt => pten}/kernels/cuda/linalg.h (92%) rename paddle/{tcmpt => pten}/kernels/cuda/manipulation.cu (90%) rename paddle/{tcmpt => pten}/kernels/cuda/manipulation.h (93%) rename paddle/{tcmpt => pten}/kernels/cuda/math.cu (85%) rename paddle/{tcmpt => pten}/kernels/cuda/math.h (94%) rename paddle/{tcmpt => pten}/kernels/cuda/utils.cu (97%) rename paddle/{tcmpt => pten}/kernels/cuda/utils.h (87%) rename paddle/{tcmpt => pten}/kernels/mkldnn/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/kernels/npu/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/kernels/xpu/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/module/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/tests/CMakeLists.txt (100%) rename paddle/{tcmpt => pten}/tests/backend_test.cc (94%) rename paddle/{tcmpt => pten}/tests/dense_tensor_test.cc (62%) rename paddle/{tcmpt => pten}/tests/dtype_test.cc (100%) rename paddle/{tcmpt => pten}/tests/kernel_factory_test.cc (75%) rename paddle/{tcmpt => pten}/tests/layout_test.cc (100%) rename paddle/{tcmpt => pten}/tests/test_copy_api.cc (64%) rename paddle/{tcmpt => pten}/tests/test_dot_api.cc (67%) rename paddle/{tcmpt => pten}/tests/test_fill_api.cc (54%) rename paddle/{tcmpt => pten}/tests/test_flatten_api.cc (72%) rename paddle/{tcmpt => pten}/tests/test_mean_api.cc (69%) delete mode 100644 paddle/tcmpt/CMakeLists.txt delete mode 100644 paddle/tcmpt/api/CMakeLists.txt delete mode 100644 paddle/tcmpt/hapi/CMakeLists.txt delete mode 100644 paddle/tcmpt/hapi/lib/CMakeLists.txt diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 12b4530a77a4c..2004abcbfa1f2 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -116,19 +116,19 @@ function(find_fluid_modules TARGET_NAME) endif() endfunction(find_fluid_modules) -set_property(GLOBAL PROPERTY TCMPT_MODULES "") -# find all tcmpt modules is used for paddle static library +set_property(GLOBAL PROPERTY PTEN_MODULES "") +# find all pten modules is used for paddle static library # for building inference libs -function(find_tcmpt_modules TARGET_NAME) +function(find_pten_modules TARGET_NAME) get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE) string(REGEX REPLACE "^${PADDLE_SOURCE_DIR}/" "" __target_path ${__target_path}) - string(FIND "${__target_path}" "tcmpt" pos) + string(FIND "${__target_path}" "pten" pos) if(pos GREATER 1) - get_property(tcmpt_modules GLOBAL PROPERTY TCMPT_MODULES) - set(tcmpt_modules ${tcmpt_modules} ${TARGET_NAME}) - set_property(GLOBAL PROPERTY TCMPT_MODULES "${tcmpt_modules}") + get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES) + set(pten_modules ${pten_modules} ${TARGET_NAME}) + set_property(GLOBAL PROPERTY PTEN_MODULES "${pten_modules}") endif() -endfunction(find_tcmpt_modules) +endfunction(find_pten_modules) function(common_link TARGET_NAME) if (WITH_PROFILER) @@ -324,7 +324,7 @@ function(cc_library TARGET_NAME) else() add_library(${TARGET_NAME} STATIC ${cc_library_SRCS}) find_fluid_modules(${TARGET_NAME}) - find_tcmpt_modules(${TARGET_NAME}) + find_pten_modules(${TARGET_NAME}) endif() if(cc_library_DEPS) # Don't need link libwarpctc.so @@ -497,7 +497,7 @@ function(nv_library TARGET_NAME) else() add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) find_fluid_modules(${TARGET_NAME}) - find_tcmpt_modules(${TARGET_NAME}) + find_pten_modules(${TARGET_NAME}) endif() if (nv_library_DEPS) add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) @@ -588,7 +588,7 @@ function(hip_library TARGET_NAME) else() hip_add_library(${TARGET_NAME} STATIC ${hip_library_SRCS}) find_fluid_modules(${TARGET_NAME}) - find_tcmpt_modules(${TARGET_NAME}) + find_pten_modules(${TARGET_NAME}) endif() if (hip_library_DEPS) add_dependencies(${TARGET_NAME} ${hip_library_DEPS}) diff --git a/cmake/tcmpt.cmake b/cmake/pten.cmake similarity index 84% rename from cmake/tcmpt.cmake rename to cmake/pten.cmake index 819cd42287974..bfe75475edcc0 100644 --- a/cmake/tcmpt.cmake +++ b/cmake/pten.cmake @@ -29,13 +29,13 @@ function(kernel_instantiate TARGET) string(REGEX MATCH "[A-Z][A-Za-z0-9]+\\(" func_name ${signature}) string(REPLACE "(" "" func_name ${func_name}) # message(STATUS "FUNC NAME: ${func_name}") - string(REGEX REPLACE "${func_name}" "pt::${func_name}<${dtype}>" inst_signature ${signature}) + string(REGEX REPLACE "${func_name}" "pten::${func_name}<${dtype}>" inst_signature ${signature}) # append namespace - string(REPLACE "CPUContext" "pt::CPUContext" inst_signature ${inst_signature}) - string(REPLACE "CUDAContext" "pt::CUDAContext" inst_signature ${inst_signature}) - string(REPLACE "DenseTensor" "pt::DenseTensor" inst_signature ${inst_signature}) + string(REPLACE "CPUContext" "pten::CPUContext" inst_signature ${inst_signature}) + string(REPLACE "CUDAContext" "pten::CUDAContext" inst_signature ${inst_signature}) + string(REPLACE "DenseTensor" "pten::DenseTensor" inst_signature ${inst_signature}) # TODO(chenweihang): adapt SelectedRows after adding it - # string(REPLACE "SelectedRowsTensor" "pt::SelectedRowsTensor" inst_signature ${inst_signature}) + # string(REPLACE "SelectedRowsTensor" "pten::SelectedRowsTensor" inst_signature ${inst_signature}) # message(STATUS "INST FUNC: ${inst_signature}") string(APPEND instantiate_context "template ${inst_signature};\n") endforeach() diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index ce3f6973e7a68..b3a1b2e8c9587 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -1,5 +1,5 @@ add_subdirectory(scripts) add_subdirectory(testing) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") -add_subdirectory(tcmpt) +add_subdirectory(pten) add_subdirectory(fluid) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 27f83a266ec9c..b1f23e50d31d2 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -193,10 +193,10 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va IF(WITH_XPU) cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto - shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils tcmpt tcmpt_utils) + shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils pten pten_utils) ELSE() cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto - shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils tcmpt tcmpt_utils) + shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils pten pten_utils) ENDIF() cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) @@ -390,7 +390,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_library(generator SRCS generator.cc DEPS enforce place) -cc_library(tcmpt_utils SRCS tcmpt_utils.cc DEPS lod_tensor selected_rows place tcmpt var_type_traits) +cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits) # Get the current working branch execute_process( @@ -454,4 +454,4 @@ if(WITH_TESTING AND TEST selected_rows_test) endif() cc_test(scope_guard_test SRCS scope_guard_test.cc) -cc_test(tcmpt_utils_test SRCS tcmpt_utils_test.cc DEPS tcmpt_utils) +cc_test(pten_utils_test SRCS pten_utils_test.cc DEPS pten_utils) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 5a1c03327d592..d2704f046cb36 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -23,8 +23,8 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/op_call_stack.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/shape_inference.h" -#include "paddle/fluid/framework/tcmpt_utils.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/var_type.h" @@ -1140,7 +1140,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // and RCOM backend, the XPU, NPU and MKLDNN will be supported in the second // phase if (FLAGS_run_pt_kernel && - pt::KernelFactory::Instance().ContainsKernel(type_.c_str())) { + pten::KernelFactory::Instance().ContainsKernel(type_.c_str())) { if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) { ChoosePtKernel(exe_ctx); } @@ -1286,10 +1286,11 @@ void OperatorWithKernel::ChoosePtKernel(const ExecutionContext& ctx) const { kernel_type_.reset(new OpKernelType(InnerGetExpectedKernelType(ctx))); - auto pt_kernel_name = pt::KernelName(pt_kernel_signature_->first); + auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->first); auto pt_kernel_key = TransOpKernelTypeToPtKernelKey(*kernel_type_.get()); - pt_kernel_.reset(new pt::Kernel(pt::KernelFactory::Instance().SelectKernel( - pt_kernel_name, pt_kernel_key))); + pt_kernel_.reset( + new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key))); if (pt_kernel_->IsValid()) { VLOG(1) << "Static mode ChoosePtKernel - kernel name: " << pt_kernel_name @@ -1781,7 +1782,7 @@ KernelSignature OperatorWithKernel::GetExpectedPtKernelArgs( } } -pt::KernelContext OperatorWithKernel::BuildPtKernelContext( +pten::KernelContext OperatorWithKernel::BuildPtKernelContext( const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const { VLOG(1) << RuntimeContextDebugString(ctx); @@ -1792,7 +1793,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext( // 3. needless attributes remove // 4. use pt Tensor directly // 5. kernel input is not DenseTensor - pt::KernelContext op_kernel_ctx(dev_ctx); + pten::KernelContext op_kernel_ctx(dev_ctx); auto& input_names = std::get<0>(pt_kernel_signature_->second); auto& attr_names = std::get<1>(pt_kernel_signature_->second); @@ -1826,7 +1827,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext( << in_def.layout; auto ins_vector = ctx.inputs.at(input_names[i]); - std::vector> tmp_inputs; + std::vector> tmp_inputs; for (auto var : ins_vector) { auto pt_in = framework::InputVariableToPtTensor(*var, in_def); @@ -1839,7 +1840,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext( auto out_def = output_defs.at(i); auto outs_vector = ctx.outputs.at(output_names[i]); - std::vector> tmp_outputs; + std::vector> tmp_outputs; for (auto var : outs_vector) { auto pt_out = framework::OutputVariableToPtTensor(var, out_def); tmp_outputs.emplace_back(pt_out); @@ -1849,12 +1850,13 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext( for (size_t i = 0; i < attr_names.size(); ++i) { auto& attr = Attrs().at(attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(pt::Scalar))) { + if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) { // TODO(chenweihang): support other attrs later // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs if (std::type_index(attr.type()) == std::type_index(typeid(float))) { - op_kernel_ctx.EmplaceBackAttr(pt::Scalar(BOOST_GET_CONST(float, attr))); + op_kernel_ctx.EmplaceBackAttr( + pten::Scalar(BOOST_GET_CONST(float, attr))); } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` to Scalar when construct " diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 7581b65e3b68b..29c60877b8116 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -39,7 +39,7 @@ limitations under the License. */ #include "paddle/fluid/platform/variant.h" #include "paddle/utils/flat_hash_map.h" -#include "paddle/tcmpt/api/include/core.h" +#include "paddle/pten/api/include/core.h" namespace paddle { namespace framework { @@ -531,7 +531,7 @@ class OperatorWithKernel : public OperatorBase { return kernel_type_->place_; } - /* member functions for adapting to tcmpt lib */ + /* member functions for adapting to pten lib */ /** In the Tensor calculation library, the new Kernel adopts a clearer and * more streamlined design. The arguments of the Kernel and the input and * output arguments registered in the original OpMaker do not match in some @@ -582,10 +582,10 @@ class OperatorWithKernel : public OperatorBase { Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx, const std::string& name) const; - /* member functions for adapting to tcmpt lib */ + /* member functions for adapting to pten lib */ void ChoosePtKernel(const ExecutionContext& ctx) const; - pt::KernelContext BuildPtKernelContext( + pten::KernelContext BuildPtKernelContext( const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const; protected: @@ -599,11 +599,11 @@ class OperatorWithKernel : public OperatorBase { mutable std::mutex cache_update_mutex_; mutable bool enable_cache_transfer_scope_ = false; // NOTE(chenweihang): Similar op members are used to adapt to - // new tcmpt kernel, if there is a better design in the future, + // new pten kernel, if there is a better design in the future, // we may polish the implementation here mutable bool run_pt_kernel_ = false; mutable std::unique_ptr pt_kernel_signature_; - mutable std::unique_ptr pt_kernel_; + mutable std::unique_ptr pt_kernel_; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/fluid/framework/tcmpt_utils.cc b/paddle/fluid/framework/pten_utils.cc similarity index 68% rename from paddle/fluid/framework/tcmpt_utils.cc rename to paddle/fluid/framework/pten_utils.cc index fc38eb42d74c7..22d07e0d38fdb 100644 --- a/paddle/fluid/framework/tcmpt_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/selected_rows.h" @@ -26,13 +26,14 @@ namespace framework { // TODO(chenweihang, shixiaowei): adapt SelectedRows template <> -std::shared_ptr MakeTensorImpl( - const LoDTensor& tensor, pt::Backend backend, pt::DataType dtype, - pt::DataLayout layout) { +std::shared_ptr MakeTensorImpl( + const LoDTensor& tensor, pten::Backend backend, + paddle::experimental::DataType dtype, + paddle::experimental::DataLayout layout) { auto holder = tensor.Holder(); - auto tensor_impl = std::make_shared( - pt::TensorMeta(tensor.dims(), backend, dtype, layout, tensor.offset()), - pt::TensorStatus()); + auto tensor_impl = std::make_shared( + pten::TensorMeta(tensor.dims(), backend, dtype, layout, tensor.offset()), + pten::TensorStatus()); if (holder != nullptr) { tensor_impl->ShareAllocation(tensor.Holder()); @@ -43,13 +44,14 @@ std::shared_ptr MakeTensorImpl( } template <> -std::shared_ptr MakeTensorImpl( - const Tensor& tensor, pt::Backend backend, pt::DataType dtype, - pt::DataLayout layout) { +std::shared_ptr MakeTensorImpl( + const Tensor& tensor, pten::Backend backend, + paddle::experimental::DataType dtype, + paddle::experimental::DataLayout layout) { auto holder = tensor.Holder(); - auto tensor_impl = std::make_shared( - pt::TensorMeta(tensor.dims(), backend, dtype, layout, tensor.offset()), - pt::TensorStatus()); + auto tensor_impl = std::make_shared( + pten::TensorMeta(tensor.dims(), backend, dtype, layout, tensor.offset()), + pten::TensorStatus()); if (holder != nullptr) { tensor_impl->ShareAllocation(tensor.Holder()); @@ -60,26 +62,26 @@ std::shared_ptr MakeTensorImpl( } template <> -std::shared_ptr MakeTensorImpl( +std::shared_ptr MakeTensorImpl( const LoDTensor& tensor, const platform::Place& place, proto::VarType::Type type) { - return MakeTensorImpl( - tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type), - pt::TransToPtDataLayout(tensor.layout())); + return MakeTensorImpl( + tensor, pten::TransToPtBackend(place), pten::TransToPtDataType(type), + pten::TransToPtDataLayout(tensor.layout())); } template <> -std::shared_ptr MakeTensorImpl( +std::shared_ptr MakeTensorImpl( const Tensor& tensor, const platform::Place& place, proto::VarType::Type type) { - return MakeTensorImpl( - tensor, pt::TransToPtBackend(place), pt::TransToPtDataType(type), - pt::TransToPtDataLayout(tensor.layout())); + return MakeTensorImpl( + tensor, pten::TransToPtBackend(place), pten::TransToPtDataType(type), + pten::TransToPtDataLayout(tensor.layout())); } -std::shared_ptr InputVariableToPtTensor( - const framework::Variable& variable, const pt::TensorArgDef& arg_def) { - auto expected_place = pt::TransToFluidPlace(arg_def.backend); +std::shared_ptr InputVariableToPtTensor( + const framework::Variable& variable, const pten::TensorArgDef& arg_def) { + auto expected_place = pten::TransToFluidPlace(arg_def.backend); if (variable.template IsType()) { const auto& tensor = variable.template Get(); @@ -87,12 +89,12 @@ std::shared_ptr InputVariableToPtTensor( framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); auto pt_in = - framework::MakeTensorImpl( + framework::MakeTensorImpl( tmp_tensor, arg_def.backend, arg_def.dtype, arg_def.layout); return pt_in; } else { auto pt_in = - framework::MakeTensorImpl( + framework::MakeTensorImpl( tensor, arg_def.backend, arg_def.dtype, arg_def.layout); return pt_in; } @@ -105,12 +107,12 @@ std::shared_ptr InputVariableToPtTensor( TensorCopySync(tensor.value(), expected_place, &tmp_tensor); // TODO(chenweihang): adapt SelectedRows by xiaowei's design auto pt_in = - framework::MakeTensorImpl( + framework::MakeTensorImpl( tmp_tensor, arg_def.backend, arg_def.dtype, arg_def.layout); return pt_in; } else { auto pt_in = - framework::MakeTensorImpl( + framework::MakeTensorImpl( tensor.value(), arg_def.backend, arg_def.dtype, arg_def.layout); return pt_in; } @@ -122,27 +124,28 @@ std::shared_ptr InputVariableToPtTensor( return nullptr; } -std::shared_ptr OutputVariableToPtTensor( - framework::Variable* variable, const pt::TensorArgDef& arg_def) { +std::shared_ptr OutputVariableToPtTensor( + framework::Variable* variable, const pten::TensorArgDef& arg_def) { // mutable_data before run kernel, to avoid share output form // KernelContext to original tensor if (variable->template IsType()) { auto* tensor = variable->template GetMutable(); - tensor->mutable_data(pt::TransToFluidPlace(arg_def.backend), - pt::TransToProtoVarType(arg_def.dtype)); + tensor->mutable_data(pten::TransToFluidPlace(arg_def.backend), + pten::TransToProtoVarType(arg_def.dtype)); auto pt_out = - framework::MakeTensorImpl( + framework::MakeTensorImpl( *tensor, arg_def.backend, arg_def.dtype, arg_def.layout); return pt_out; } else if (variable->template IsType()) { auto* tensor = variable->template GetMutable(); tensor->mutable_value()->mutable_data( - pt::TransToFluidPlace(arg_def.backend), - pt::TransToProtoVarType(arg_def.dtype)); + pten::TransToFluidPlace(arg_def.backend), + pten::TransToProtoVarType(arg_def.dtype)); // TODO(chenweihang): adapt SelectedRows by xiaowei's design, // here the row and height will lost in output! - auto pt_out = framework::MakeTensorImpl( - tensor->value(), arg_def.backend, arg_def.dtype, arg_def.layout); + auto pt_out = + framework::MakeTensorImpl( + tensor->value(), arg_def.backend, arg_def.dtype, arg_def.layout); return pt_out; } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -153,14 +156,15 @@ std::shared_ptr OutputVariableToPtTensor( return nullptr; } -OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key) { - proto::VarType::Type data_type = pt::TransToProtoVarType(kernel_key.dtype()); - platform::Place place = pt::TransToFluidPlace(kernel_key.backend()); - DataLayout data_layout = pt::TransToFluidDataLayout(kernel_key.layout()); +OpKernelType TransPtKernelKeyToOpKernelType(const pten::KernelKey& kernel_key) { + proto::VarType::Type data_type = + pten::TransToProtoVarType(kernel_key.dtype()); + platform::Place place = pten::TransToFluidPlace(kernel_key.backend()); + DataLayout data_layout = pten::TransToFluidDataLayout(kernel_key.layout()); LibraryType library_type = LibraryType::kPlain; - if (kernel_key.backend() == pt::Backend::kMKLDNN) { + if (kernel_key.backend() == pten::Backend::kMKLDNN) { library_type = LibraryType::kMKLDNN; - } else if (kernel_key.backend() == pt::Backend::kCUDNN) { + } else if (kernel_key.backend() == pten::Backend::kCUDNN) { library_type = LibraryType::kCUDNN; } else { // do nothing @@ -169,18 +173,21 @@ OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key) { return OpKernelType(data_type, place, data_layout, library_type); } -pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type) { - pt::Backend backend = pt::TransToPtBackend(kernel_type.place_); +pten::KernelKey TransOpKernelTypeToPtKernelKey( + const OpKernelType& kernel_type) { + pten::Backend backend = pten::TransToPtBackend(kernel_type.place_); if (kernel_type.library_type_ == LibraryType::kMKLDNN) { - backend = pt::Backend::kMKLDNN; + backend = pten::Backend::kMKLDNN; } else if (kernel_type.library_type_ == LibraryType::kCUDNN) { - backend = pt::Backend::kCUDNN; + backend = pten::Backend::kCUDNN; } else { // do } - pt::DataLayout layout = pt::TransToPtDataLayout(kernel_type.data_layout_); - pt::DataType dtype = pt::TransToPtDataType(kernel_type.data_type_); - return pt::KernelKey(backend, layout, dtype); + paddle::experimental::DataLayout layout = + pten::TransToPtDataLayout(kernel_type.data_layout_); + paddle::experimental::DataType dtype = + pten::TransToPtDataType(kernel_type.data_type_); + return pten::KernelKey(backend, layout, dtype); } KernelSignatureMap& KernelSignatureMap::Instance() { diff --git a/paddle/fluid/framework/tcmpt_utils.h b/paddle/fluid/framework/pten_utils.h similarity index 83% rename from paddle/fluid/framework/tcmpt_utils.h rename to paddle/fluid/framework/pten_utils.h index 4d08692bd9c26..14dbe933195be 100644 --- a/paddle/fluid/framework/tcmpt_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -24,7 +24,7 @@ limitations under the License. */ #include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" -#include "paddle/tcmpt/api/include/core.h" +#include "paddle/pten/api/include/core.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" @@ -34,10 +34,10 @@ namespace framework { /* tensor translate */ template -std::shared_ptr MakeTensorImpl(const VariableT& tensor, - pt::Backend backend, - pt::DataType dtype, - pt::DataLayout layout); +std::shared_ptr MakeTensorImpl( + const VariableT& tensor, pten::Backend backend, + paddle::experimental::DataType dtype, + paddle::experimental::DataLayout layout); template std::shared_ptr MakeTensorImpl(const LoDTensor& tensor, @@ -55,15 +55,15 @@ void ShareTensorImpl(PtTensorImplT* tensor_impl, LoDTensor* out); template void ShareTensorImpl(PtTensorImplT* tensor_impl, Tensor* out); -std::shared_ptr InputVariableToPtTensor( - const framework::Variable& variable, const pt::TensorArgDef& arg_def); -std::shared_ptr OutputVariableToPtTensor( - framework::Variable* variable, const pt::TensorArgDef& arg_def); +std::shared_ptr InputVariableToPtTensor( + const framework::Variable& variable, const pten::TensorArgDef& arg_def); +std::shared_ptr OutputVariableToPtTensor( + framework::Variable* variable, const pten::TensorArgDef& arg_def); /* Kernel Key translate */ -OpKernelType TransPtKernelKeyToOpKernelType(const pt::KernelKey& kernel_key); -pt::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type); +OpKernelType TransPtKernelKeyToOpKernelType(const pten::KernelKey& kernel_key); +pten::KernelKey TransOpKernelTypeToPtKernelKey(const OpKernelType& kernel_type); /* Kernel Args parse */ diff --git a/paddle/fluid/framework/tcmpt_utils_test.cc b/paddle/fluid/framework/pten_utils_test.cc similarity index 73% rename from paddle/fluid/framework/tcmpt_utils_test.cc rename to paddle/fluid/framework/pten_utils_test.cc index 200bd5429cd46..96f75ac0c1121 100644 --- a/paddle/fluid/framework/tcmpt_utils_test.cc +++ b/paddle/fluid/framework/pten_utils_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/selected_rows.h" @@ -31,14 +31,14 @@ TEST(TcmptUtils, MakeTensor) { x.data()[1] = 0.5; // 2. test API - auto dense_x = MakeTensorImpl(x, x.place(), x.type()); + auto dense_x = MakeTensorImpl(x, x.place(), x.type()); // 3. check result std::vector expect_value = {0.2, 0.5}; ASSERT_EQ(dense_x->data()[0], expect_value[0]); ASSERT_EQ(dense_x->data()[1], expect_value[1]); - ASSERT_EQ(dense_x->backend(), pt::Backend::kCPU); - ASSERT_EQ(dense_x->data_type(), pt::DataType::kFLOAT32); + ASSERT_EQ(dense_x->backend(), pten::Backend::kCPU); + ASSERT_EQ(dense_x->data_type(), paddle::experimental::DataType::kFLOAT32); } TEST(TcmptUtils, VarToPtTensor) { @@ -49,18 +49,19 @@ TEST(TcmptUtils, VarToPtTensor) { auto* data = value->mutable_data(make_ddim({1, 1}), paddle::platform::CPUPlace()); data[0] = 123; - pt::Backend expect_backend = pt::Backend::kCPU; + pten::Backend expect_backend = pten::Backend::kCPU; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - expect_backend = pt::Backend::kCUDA; + expect_backend = pten::Backend::kCUDA; #endif - auto tensor_def = pt::TensorArgDef(expect_backend, pt::DataLayout::kNCHW, - pt::DataType::kINT32); + auto tensor_def = pten::TensorArgDef(expect_backend, + paddle::experimental::DataLayout::kNCHW, + paddle::experimental::DataType::kINT32); // 2. test API auto tensor_x = InputVariableToPtTensor(v, tensor_def); // 3. check result ASSERT_EQ(tensor_x->backend(), expect_backend); - ASSERT_EQ(tensor_x->data_type(), pt::DataType::kINT32); + ASSERT_EQ(tensor_x->data_type(), paddle::experimental::DataType::kINT32); } } // namespace framework diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 617825870301b..c45f92496b3e8 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -1,9 +1,9 @@ cc_library(imperative_flag SRCS flags.cc DEPS gflags flags) IF(WITH_XPU) -cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils tcmpt_utils) +cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils) ELSE() -cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils tcmpt_utils) +cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils) ENDIF() cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry) add_subdirectory(jit) diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index f65b799e150fc..97d893babae18 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/imperative/infer_shape_context.h" #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/xpu/xpu_op_list.h" @@ -109,7 +109,7 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op, const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::KernelSignature& kernel_signature, - const pt::Kernel& pt_kernel, + const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx) : op_(op), ctx_(ctx), @@ -152,15 +152,15 @@ PreparedOp PrepareImpl(const NameVarMap& ins, VLOG(3) << "expected_kernel_key:" << expected_kernel_key; if (FLAGS_run_pt_kernel && - pt::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) { + pten::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) { auto pt_kernel_signature = op.GetExpectedPtKernelArgs(dygraph_exe_ctx); VLOG(1) << framework::KernelSignatureToString(pt_kernel_signature); - auto pt_kernel_name = pt::KernelName(pt_kernel_signature.first); + auto pt_kernel_name = pten::KernelName(pt_kernel_signature.first); auto pt_kernel_key = TransOpKernelTypeToPtKernelKey(expected_kernel_key); - auto pt_kernel = pt::KernelFactory::Instance().SelectKernel(pt_kernel_name, - pt_kernel_key); + auto pt_kernel = pten::KernelFactory::Instance().SelectKernel( + pt_kernel_name, pt_kernel_key); if (pt_kernel.IsValid()) { VLOG(1) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name @@ -243,9 +243,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, } template -static pt::KernelContext BuildDygraphPtKernelContext( +static pten::KernelContext BuildDygraphPtKernelContext( const framework::KernelSignature& pt_kernel_signature, - const pt::Kernel& pt_kernel, const NameVarMap& ins, + const pten::Kernel& pt_kernel, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs, const platform::DeviceContext& dev_ctx) { @@ -256,7 +256,7 @@ static pt::KernelContext BuildDygraphPtKernelContext( // 3. needless attributes remove // 4. use pt Tensor directly // 5. kernel input is not DenseTensor - pt::KernelContext op_kernel_ctx(dev_ctx); + pten::KernelContext op_kernel_ctx(dev_ctx); auto& input_names = std::get<0>(pt_kernel_signature.second); auto& attr_names = std::get<1>(pt_kernel_signature.second); @@ -288,7 +288,7 @@ static pt::KernelContext BuildDygraphPtKernelContext( auto& in_def = input_defs.at(i); auto& ins_vector = ins.at(input_names[i]); - std::vector> tmp_inputs; + std::vector> tmp_inputs; for (auto var : ins_vector) { const auto& variable = var->Var(); @@ -302,7 +302,7 @@ static pt::KernelContext BuildDygraphPtKernelContext( auto& out_def = output_defs.at(i); auto& outs_vector = outs.at(output_names[i]); - std::vector> tmp_outputs; + std::vector> tmp_outputs; for (auto var : outs_vector) { auto* variable = var->MutableVar(); @@ -314,12 +314,13 @@ static pt::KernelContext BuildDygraphPtKernelContext( for (size_t i = 0; i < attr_names.size(); ++i) { auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(pt::Scalar))) { + if (attr_defs[i].type_index == std::type_index(typeid(pten::Scalar))) { // TODO(chenweihang): support other attrs later // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs if (std::type_index(attr.type()) == std::type_index(typeid(float))) { - op_kernel_ctx.EmplaceBackAttr(pt::Scalar(BOOST_GET_CONST(float, attr))); + op_kernel_ctx.EmplaceBackAttr( + pten::Scalar(BOOST_GET_CONST(float, attr))); } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` to Scalar when construct " @@ -391,7 +392,7 @@ template static void PreparedOpRunPtImpl( const framework::OperatorBase& op, const framework::KernelSignature& pt_kernel_signature, - const pt::Kernel& pt_kernel, platform::DeviceContext* dev_ctx, + const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index d1a47117f389b..42bd581b9f24a 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -26,7 +26,7 @@ #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/type_defs.h" -#include "paddle/tcmpt/api/include/core.h" +#include "paddle/pten/api/include/core.h" DECLARE_bool(use_mkldnn); @@ -154,7 +154,7 @@ class PreparedOp { const framework::RuntimeContext& ctx, const framework::OpKernelType& kernel_type, const framework::KernelSignature& kernel_signature, - const pt::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); + const pten::Kernel& pt_kernel, platform::DeviceContext* dev_ctx); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, @@ -188,11 +188,11 @@ class PreparedOp { framework::OperatorWithKernel::OpKernelFunc func_; platform::DeviceContext* dev_ctx_; // NOTE(chenweihang): Similar op members are used to adapt to - // new tcmpt kernel, if there is a better design in the future, + // new pten kernel, if there is a better design in the future, // we may polish the implementation here bool run_pt_kernel_{false}; framework::KernelSignature pt_kernel_signature_; - pt::Kernel pt_kernel_; + pten::Kernel pt_kernel_; }; } // namespace imperative diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 3357625b74c22..09c72cb13b803 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -35,7 +35,7 @@ endif() # fluid_modules exclude API-interface of inference/api and inference/capi_exp get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) -get_property(tcmpt_modules GLOBAL PROPERTY TCMPT_MODULES) +get_property(pten_modules GLOBAL PROPERTY PTEN_MODULES) # Adapt to custom op mechanism: Include the header files related to the data type # to avoid exposing the path of the underlying file @@ -51,9 +51,9 @@ set(STATIC_INFERENCE_API paddle_inference_api analysis_predictor analysis_config paddle_pass_builder activation_functions ${mkldnn_quantizer_cfg}) #TODO(wilber, T8T9): Do we still need to support windows gpu static library? if(WIN32 AND WITH_GPU) - cc_library(paddle_inference DEPS ${fluid_modules} ${tcmpt_modules} ${STATIC_INFERENCE_API}) + cc_library(paddle_inference DEPS ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API}) else() - create_static_lib(paddle_inference ${fluid_modules} ${tcmpt_modules} ${STATIC_INFERENCE_API}) + create_static_lib(paddle_inference ${fluid_modules} ${pten_modules} ${STATIC_INFERENCE_API}) endif() if(NOT APPLE) @@ -83,7 +83,7 @@ set(SHARED_INFERENCE_SRCS ${PADDLE_CUSTOM_OP_SRCS}) # shared inference library deps -set(SHARED_INFERENCE_DEPS ${fluid_modules} ${tcmpt_modules} analysis_predictor) +set(SHARED_INFERENCE_DEPS ${fluid_modules} ${pten_modules} analysis_predictor) if (WITH_CRYPTO) set(SHARED_INFERENCE_DEPS ${SHARED_INFERENCE_DEPS} paddle_crypto) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 1ce7fd8d0f91b..bfeb2db6d885b 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -78,8 +78,8 @@ if(WITH_UNITY_BUILD) include(unity_build_rule.cmake) endif() -set(OP_HEADER_DEPS ${OP_HEADER_DEPS} tcmpt) -set(OP_HEADER_DEPS ${OP_HEADER_DEPS} tcmpt_utils) +set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten) +set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten_utils) register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op sparse_attention_op lstm_op run_program_op eye_op recurrent_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h index a427da4f40f9f..641b0d653d5b0 100644 --- a/paddle/fluid/operators/dot_op.h +++ b/paddle/fluid/operators/dot_op.h @@ -16,13 +16,13 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/platform/for_range.h" -// only can include the headers in paddle/tcmpt/api dirs -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/linalg.h" +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/linalg.h" namespace paddle { namespace operators { @@ -245,14 +245,14 @@ class DotKernel : public framework::OpKernel { out->mutable_data(x->place()); auto pt_x = - framework::MakeTensorImpl(*x, x->place(), x->type()); + framework::MakeTensorImpl(*x, x->place(), x->type()); auto pt_y = - framework::MakeTensorImpl(*y, y->place(), y->type()); - auto pt_out = - framework::MakeTensorImpl(*out, x->place(), x->type()); + framework::MakeTensorImpl(*y, y->place(), y->type()); + auto pt_out = framework::MakeTensorImpl(*out, x->place(), + x->type()); // call new kernel - pt::Dot(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); + pten::Dot(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get()); } }; diff --git a/paddle/fluid/operators/fill_any_like_op.h b/paddle/fluid/operators/fill_any_like_op.h index c1c7152581ce5..73170c6e2e277 100644 --- a/paddle/fluid/operators/fill_any_like_op.h +++ b/paddle/fluid/operators/fill_any_like_op.h @@ -17,10 +17,10 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/creation.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/creation.h" namespace paddle { namespace operators { @@ -62,14 +62,14 @@ class FillAnyLikeKernel : public framework::OpKernel { std::isnan(value), false, platform::errors::InvalidArgument("The filled value is NaN.")); - auto pt_x = framework::MakeTensorImpl(*in, in->place(), - in->type()); - auto pt_out = framework::MakeTensorImpl(*out, out->place(), - out->type()); + auto pt_x = framework::MakeTensorImpl(*in, in->place(), + in->type()); + auto pt_out = framework::MakeTensorImpl( + *out, out->place(), out->type()); const auto& dev_ctx = context.template device_context(); // call new kernel - pt::FillAnyLike(dev_ctx, *pt_x, value, pt_out.get()); + pten::FillAnyLike(dev_ctx, *pt_x, value, pt_out.get()); } }; diff --git a/paddle/fluid/operators/mean_op.h b/paddle/fluid/operators/mean_op.h index 1ae6f453a873e..661ff41f10f85 100644 --- a/paddle/fluid/operators/mean_op.h +++ b/paddle/fluid/operators/mean_op.h @@ -15,11 +15,11 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" // only can include the headers in paddle/top/api dirs -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/math.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/math.h" namespace paddle { namespace operators { @@ -62,13 +62,13 @@ class MeanKernel : public framework::OpKernel { out->mutable_data(x->place()); auto pt_x = - framework::MakeTensorImpl(*x, x->place(), x->type()); - auto pt_out = - framework::MakeTensorImpl(*out, x->place(), x->type()); + framework::MakeTensorImpl(*x, x->place(), x->type()); + auto pt_out = framework::MakeTensorImpl(*out, x->place(), + x->type()); // call new kernel VLOG(1) << "chenweihang: call original mean kernel compute."; - pt::Mean(dev_ctx, *pt_x.get(), pt_out.get()); + pten::Mean(dev_ctx, *pt_x.get(), pt_out.get()); } }; diff --git a/paddle/fluid/operators/scale_op.h b/paddle/fluid/operators/scale_op.h index ffc2a49232cd8..9a043361678b2 100644 --- a/paddle/fluid/operators/scale_op.h +++ b/paddle/fluid/operators/scale_op.h @@ -15,11 +15,11 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" // only can include the headers in paddle/top/api dirs -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/math.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/math.h" namespace paddle { namespace operators { @@ -66,14 +66,14 @@ class ScaleKernel : public framework::OpKernel { out->mutable_data(in->place()); auto& dev_ctx = ctx.device_context(); - auto pt_x = framework::MakeTensorImpl(*in, in->place(), - in->type()); - auto pt_out = framework::MakeTensorImpl(*out, in->place(), + auto pt_x = framework::MakeTensorImpl(*in, in->place(), in->type()); + auto pt_out = framework::MakeTensorImpl( + *out, in->place(), in->type()); // call new kernel - pt::Scale(dev_ctx, *pt_x.get(), scale, bias, bias_after_scale, - pt_out.get()); + pten::Scale(dev_ctx, *pt_x.get(), scale, bias, bias_after_scale, + pt_out.get()); } }; diff --git a/paddle/fluid/operators/sign_op.h b/paddle/fluid/operators/sign_op.h index bb439839bd330..f3083f4937875 100644 --- a/paddle/fluid/operators/sign_op.h +++ b/paddle/fluid/operators/sign_op.h @@ -16,12 +16,12 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tcmpt_utils.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -// only can include the headers in paddle/tcmpt/api dirs -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/math.h" +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/math.h" namespace paddle { namespace operators { @@ -37,12 +37,12 @@ class SignKernel : public framework::OpKernel { out->mutable_data(x->place()); auto pt_x = - framework::MakeTensorImpl(*x, x->place(), x->type()); - auto pt_out = - framework::MakeTensorImpl(*out, x->place(), x->type()); + framework::MakeTensorImpl(*x, x->place(), x->type()); + auto pt_out = framework::MakeTensorImpl(*out, x->place(), + x->type()); // call new kernel - pt::Sign(dev_ctx, *pt_x.get(), pt_out.get()); + pten::Sign(dev_ctx, *pt_x.get(), pt_out.get()); } }; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 96bcbe7d0238e..54e73c5c1d9fa 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -169,7 +169,7 @@ if(WITH_GPU) nv_test(device_event_test SRCS device_event_test.cc DEPS device_event_gpu) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) - nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda tcmpt) + nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda pten) nv_test(cudnn_desc_test SRCS cudnn_desc_test.cc DEPS dynload_cuda) nv_test(transform_test SRCS transform_test.cu DEPS memory place device_context) endif() diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index c92173b230ae6..b8b0f65eaa1ce 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -554,9 +554,9 @@ GenerateOpFunctions() { auto& op_type = op_proto->type(); // Skip ooerator which is not inherit form OperatorWithKernel, like while, // since only OperatorWithKernel can run in dygraph mode. - // if the tcmpt lib contains op kernel, we still generate ops method + // if the pten lib contains op kernel, we still generate ops method if (!all_kernels.count(op_type) && - !pt::KernelFactory::Instance().ContainsKernel(op_type.c_str())) { + !pten::KernelFactory::Instance().ContainsKernel(op_type.c_str())) { continue; } diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt new file mode 100644 index 0000000000000..3bf1e6759b35a --- /dev/null +++ b/paddle/pten/CMakeLists.txt @@ -0,0 +1,15 @@ +include(pten) +# pten api +add_subdirectory(api) +# pten high level api +add_subdirectory(hapi) +# pten core components +add_subdirectory(core) +# pten kernels for diff device +add_subdirectory(kernels) +# pten infershape +add_subdirectory(infershape) +# TODO(xingfeng): pten inner module API designed by a high-performance team +add_subdirectory(module) +# pten tests +add_subdirectory(tests) diff --git a/paddle/pten/api/CMakeLists.txt b/paddle/pten/api/CMakeLists.txt new file mode 100644 index 0000000000000..aabef9185f6c1 --- /dev/null +++ b/paddle/pten/api/CMakeLists.txt @@ -0,0 +1,21 @@ +# set(declare_file ${PADDLE_BINARY_DIR}/paddle/pten/api/symbols.h.tmp CACHE INTERNAL "symbols.h file") +# set(declare_file_final ${PADDLE_BINARY_DIR}/paddle/pten/api/symbols.h) +# file(WRITE ${declare_file} "// Generated by the paddle/pten/api/CMakeLists.txt. DO NOT EDIT!\n\n") + +# function(declare_module TARGTE) +# file(APPEND ${declare_file} "extern int RegisterSymbolsFor${TARGET}();\n") +# message(STATUS "") +# endfunction() + +# TODO(chenweihang): unify decclare into **_library +# declare_module(MathCPU) +# declare_module(MathCUDA) + +set(PTEN_DEPS convert_utils dense_tensor kernel_factory kernel_context) +set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu) +set(PTEN_DEPS ${PTEN_DEPS} unary binary) +if(WITH_GPU OR WITH_ROCM) + set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda creation_cuda manipulation_cuda) +endif() + +cc_library(pten SRCS all.cc DEPS ${PTEN_DEPS}) diff --git a/paddle/tcmpt/api/all.cc b/paddle/pten/api/all.cc similarity index 89% rename from paddle/tcmpt/api/all.cc rename to paddle/pten/api/all.cc index 05922e02c4998..0704d6c516fa6 100644 --- a/paddle/tcmpt/api/all.cc +++ b/paddle/pten/api/all.cc @@ -12,6 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/api/all.h" +#include "paddle/pten/api/all.h" -namespace pt {} // namespace pt +namespace pten {} // namespace pten diff --git a/paddle/tcmpt/api/all.h b/paddle/pten/api/all.h similarity index 69% rename from paddle/tcmpt/api/all.h rename to paddle/pten/api/all.h index 0f47f75f8a7fc..c760960967d95 100644 --- a/paddle/tcmpt/api/all.h +++ b/paddle/pten/api/all.h @@ -15,9 +15,9 @@ limitations under the License. */ #pragma once // develop apis -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/creation.h" -#include "paddle/tcmpt/api/include/infershape.h" -#include "paddle/tcmpt/api/include/linalg.h" -#include "paddle/tcmpt/api/include/manipulation.h" -#include "paddle/tcmpt/api/include/math.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/creation.h" +#include "paddle/pten/api/include/infershape.h" +#include "paddle/pten/api/include/linalg.h" +#include "paddle/pten/api/include/manipulation.h" +#include "paddle/pten/api/include/math.h" diff --git a/paddle/tcmpt/api/include/core.h b/paddle/pten/api/include/core.h similarity index 75% rename from paddle/tcmpt/api/include/core.h rename to paddle/pten/api/include/core.h index fd863186abb30..7872580ad8d7c 100644 --- a/paddle/tcmpt/api/include/core.h +++ b/paddle/pten/api/include/core.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once // See Note: [ How do we organize the kernel directory ] -#include "paddle/tcmpt/core/convert_utils.h" -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_context.h" -#include "paddle/tcmpt/core/kernel_factory.h" -#include "paddle/tcmpt/core/scalar.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_context.h" +#include "paddle/pten/core/kernel_factory.h" +#include "paddle/pten/core/scalar.h" diff --git a/paddle/tcmpt/api/include/creation.h b/paddle/pten/api/include/creation.h similarity index 87% rename from paddle/tcmpt/api/include/creation.h rename to paddle/pten/api/include/creation.h index 2a87453b32154..d7311e6cd283b 100644 --- a/paddle/tcmpt/api/include/creation.h +++ b/paddle/pten/api/include/creation.h @@ -14,5 +14,5 @@ #pragma once -#include "paddle/tcmpt/kernels/cpu/creation.h" -#include "paddle/tcmpt/kernels/cuda/creation.h" +#include "paddle/pten/kernels/cpu/creation.h" +#include "paddle/pten/kernels/cuda/creation.h" diff --git a/paddle/tcmpt/api/include/infershape.h b/paddle/pten/api/include/infershape.h similarity index 88% rename from paddle/tcmpt/api/include/infershape.h rename to paddle/pten/api/include/infershape.h index 01ed351fb59b2..8c1bd43aaa24e 100644 --- a/paddle/tcmpt/api/include/infershape.h +++ b/paddle/pten/api/include/infershape.h @@ -15,5 +15,5 @@ limitations under the License. */ #pragma once // See Note: [ How do we organize the kernel directory ] -#include "paddle/tcmpt/infershape/binary.h" -#include "paddle/tcmpt/infershape/unary.h" +#include "paddle/pten/infershape/binary.h" +#include "paddle/pten/infershape/unary.h" diff --git a/paddle/tcmpt/api/include/linalg.h b/paddle/pten/api/include/linalg.h similarity index 88% rename from paddle/tcmpt/api/include/linalg.h rename to paddle/pten/api/include/linalg.h index 81ea68abcd0bb..d9798c3a2e0a8 100644 --- a/paddle/tcmpt/api/include/linalg.h +++ b/paddle/pten/api/include/linalg.h @@ -15,5 +15,5 @@ #pragma once // See Note: [ How do we organize the kernel directory ] -#include "paddle/tcmpt/kernels/cpu/linalg.h" -#include "paddle/tcmpt/kernels/cuda/linalg.h" +#include "paddle/pten/kernels/cpu/linalg.h" +#include "paddle/pten/kernels/cuda/linalg.h" diff --git a/paddle/tcmpt/api/include/manipulation.h b/paddle/pten/api/include/manipulation.h similarity index 87% rename from paddle/tcmpt/api/include/manipulation.h rename to paddle/pten/api/include/manipulation.h index 1746929ca181d..f2acad9649969 100644 --- a/paddle/tcmpt/api/include/manipulation.h +++ b/paddle/pten/api/include/manipulation.h @@ -15,5 +15,5 @@ #pragma once // See Note: [ How do we organize the kernel directory ] -#include "paddle/tcmpt/kernels/cpu/manipulation.h" -#include "paddle/tcmpt/kernels/cuda/manipulation.h" +#include "paddle/pten/kernels/cpu/manipulation.h" +#include "paddle/pten/kernels/cuda/manipulation.h" diff --git a/paddle/tcmpt/api/include/math.h b/paddle/pten/api/include/math.h similarity index 88% rename from paddle/tcmpt/api/include/math.h rename to paddle/pten/api/include/math.h index ab3c229806990..5145c823a5c6e 100644 --- a/paddle/tcmpt/api/include/math.h +++ b/paddle/pten/api/include/math.h @@ -15,5 +15,5 @@ limitations under the License. */ #pragma once // See Note: [ How do we organize the kernel directory ] -#include "paddle/tcmpt/kernels/cpu/math.h" -#include "paddle/tcmpt/kernels/cuda/math.h" +#include "paddle/pten/kernels/cpu/math.h" +#include "paddle/pten/kernels/cuda/math.h" diff --git a/paddle/tcmpt/api/include/symbols.h b/paddle/pten/api/include/symbols.h similarity index 94% rename from paddle/tcmpt/api/include/symbols.h rename to paddle/pten/api/include/symbols.h index 8dc75f859ce52..1ec14a41861d8 100644 --- a/paddle/tcmpt/api/include/symbols.h +++ b/paddle/pten/api/include/symbols.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/kernel_registry.h" // symbol declare PT_DECLARE_MODULE(MathCPU); diff --git a/paddle/tcmpt/common/data_type.h b/paddle/pten/common/data_type.h similarity index 99% rename from paddle/tcmpt/common/data_type.h rename to paddle/pten/common/data_type.h index 03881e6bda1ca..bd33bf70541a8 100644 --- a/paddle/tcmpt/common/data_type.h +++ b/paddle/pten/common/data_type.h @@ -176,6 +176,6 @@ inline DataType& operator++(DataType& dtype, int) { } // namespace experimental } // namespace paddle -namespace pt { +namespace pten { using DataType = paddle::experimental::DataType; } diff --git a/paddle/tcmpt/common/layout.h b/paddle/pten/common/layout.h similarity index 98% rename from paddle/tcmpt/common/layout.h rename to paddle/pten/common/layout.h index ae4e43a9f7197..da41aaaaed33a 100644 --- a/paddle/tcmpt/common/layout.h +++ b/paddle/pten/common/layout.h @@ -59,6 +59,6 @@ inline DataLayout& operator++(DataLayout& layout, int) { } // namespace experimental } // namespace paddle -namespace pt { +namespace pten { using DataLayout = paddle::experimental::DataLayout; } diff --git a/paddle/tcmpt/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/core/CMakeLists.txt rename to paddle/pten/core/CMakeLists.txt diff --git a/paddle/tcmpt/core/allocator.cc b/paddle/pten/core/allocator.cc similarity index 82% rename from paddle/tcmpt/core/allocator.cc rename to paddle/pten/core/allocator.cc index da1576f81ad71..bcf03ee5acf0a 100644 --- a/paddle/tcmpt/core/allocator.cc +++ b/paddle/pten/core/allocator.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/allocator.h" +#include "paddle/pten/core/allocator.h" -namespace paddle { -namespace tcmpt {} // namespace tcmpt -} // namespace paddle +namespace pten {} // namespace pten diff --git a/paddle/tcmpt/core/allocator.h b/paddle/pten/core/allocator.h similarity index 93% rename from paddle/tcmpt/core/allocator.h rename to paddle/pten/core/allocator.h index 592f7a4078f80..b96e695a4f8cf 100644 --- a/paddle/tcmpt/core/allocator.h +++ b/paddle/pten/core/allocator.h @@ -17,8 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/platform/place.h" -namespace paddle { -namespace tcmpt { +namespace pten { /// \brief Encapsulates strategies for access/addressing, allocation/ /// deallocation and construction/destruction of objects. @@ -44,7 +43,7 @@ class RawAllocator { /// \brief Get the place value of the allocator and the allocation. /// \return The place value of the allocator and the allocation. - virtual const platform::Place& place() const = 0; + virtual const paddle::platform::Place& place() const = 0; }; /// \brief Fancy pointer with context. The use of this data type @@ -59,18 +58,18 @@ class Allocation final { Allocation(Allocation&&) = default; Allocation& operator=(Allocation&&) = default; - Allocation(void* data, const platform::Place& place) + Allocation(void* data, const paddle::platform::Place& place) : data_(data), place_(place) {} Allocation(void* data, void* ctx, DeleterFnPtr ctx_deleter, - const platform::Place& place) + const paddle::platform::Place& place) : data_(data), ctx_(ctx, ctx_deleter), place_(place) {} void* operator->() const noexcept { return data_; } operator bool() const noexcept { return data_ || ctx_.Get(); } - const platform::Place& place() const noexcept { return place_; } + const paddle::platform::Place& place() const noexcept { return place_; } void Clear() noexcept { data_ = nullptr; @@ -133,7 +132,7 @@ class Allocation final { Context ctx_; // TODO(Shixiaowei02): Enum needs to be used instead to reduce // the construction overhead by more than 50%. - platform::Place place_; + paddle::platform::Place place_; }; inline void swap(Allocation::Context& a, Allocation::Context& b) noexcept { @@ -155,5 +154,4 @@ inline Allocation Allocate(const std::shared_ptr& a, size_t n) { return a->Allocate(n); } -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/backend.cc b/paddle/pten/core/backend.cc similarity index 94% rename from paddle/tcmpt/core/backend.cc rename to paddle/pten/core/backend.cc index 68c7adfcc2810..0e4029cfc38e2 100644 --- a/paddle/tcmpt/core/backend.cc +++ b/paddle/pten/core/backend.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/backend.h" +#include "paddle/pten/core/backend.h" -namespace pt { +namespace pten { std::ostream& operator<<(std::ostream& os, Backend backend) { switch (backend) { @@ -55,4 +55,4 @@ std::ostream& operator<<(std::ostream& os, Backend backend) { return os; } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/backend.h b/paddle/pten/core/backend.h similarity index 97% rename from paddle/tcmpt/core/backend.h rename to paddle/pten/core/backend.h index b1ee09c177f29..c10d4bd308331 100644 --- a/paddle/tcmpt/core/backend.h +++ b/paddle/pten/core/backend.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include -namespace pt { +namespace pten { /** * [ Why need Backend? ] @@ -45,4 +45,4 @@ enum class Backend { std::ostream& operator<<(std::ostream& os, Backend backend); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/convert_utils.cc b/paddle/pten/core/convert_utils.cc similarity index 94% rename from paddle/tcmpt/core/convert_utils.cc rename to paddle/pten/core/convert_utils.cc index e5b8acba19cf0..2320fc632c936 100644 --- a/paddle/tcmpt/core/convert_utils.cc +++ b/paddle/pten/core/convert_utils.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/convert_utils.h" +#include "paddle/pten/core/convert_utils.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/gpu_info.h" -namespace pt { +namespace pten { // TODO(chenweihang): Add other place branchs Backend TransToPtBackend(const paddle::platform::Place& place) { @@ -38,7 +38,7 @@ Backend TransToPtBackend(const paddle::platform::Place& place) { } } -pt::DataType TransToPtDataType( +paddle::experimental::DataType TransToPtDataType( const paddle::framework::proto::VarType::Type& dtype) { // Set the order of case branches according to the frequency with // the data type is used @@ -90,29 +90,29 @@ DataLayout TransToPtDataLayout(const paddle::framework::DataLayout& layout) { paddle::platform::Place TransToFluidPlace(const Backend& backend) { // TODO(chenweihang): add other trans cases switch (backend) { - case pt::Backend::kCPU: + case pten::Backend::kCPU: return paddle::platform::CPUPlace(); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - case pt::Backend::kCUDA: + case pten::Backend::kCUDA: return paddle::platform::CUDAPlace( paddle::platform::GetCurrentDeviceId()); #endif #ifdef PADDLE_WITH_XPU - case pt::Backend::kXPU: + case pten::Backend::kXPU: // TODO(chenweihang): add device id return paddle::platform::XPUPlace(); #endif #ifdef PADDLE_WITH_NPU - case pt::Backend::kNPU: + case pten::Backend::kNPU: // TODO(chenweihang): add device id return paddle::platform::NPUPlace(); #endif #ifdef PADDLE_WITH_MKLDNN - case pt::Backend::kMKLDNN: + case pten::Backend::kMKLDNN: return paddle::platform::CPUPlace(); #endif #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - case pt::Backend::kCUDNN: + case pten::Backend::kCUDNN: return paddle::platform::CUDAPlace( paddle::platform::GetCurrentDeviceId()); #endif @@ -124,7 +124,7 @@ paddle::platform::Place TransToFluidPlace(const Backend& backend) { } paddle::framework::proto::VarType::Type TransToProtoVarType( - const pt::DataType& dtype) { + const paddle::experimental::DataType& dtype) { // Set the order of case branches according to the frequency with // the data type is used switch (dtype) { @@ -178,4 +178,4 @@ paddle::framework::DataLayout TransToFluidDataLayout(const DataLayout& layout) { } } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/convert_utils.h b/paddle/pten/core/convert_utils.h similarity index 90% rename from paddle/tcmpt/core/convert_utils.h rename to paddle/pten/core/convert_utils.h index 011652bdc9572..2c7ad35881e7c 100644 --- a/paddle/tcmpt/core/convert_utils.h +++ b/paddle/pten/core/convert_utils.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/common/layout.h" -#include "paddle/tcmpt/core/backend.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/common/layout.h" +#include "paddle/pten/core/backend.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/data_layout.h" @@ -25,7 +25,7 @@ limitations under the License. */ // TODO(chenweihang): this file may need to be removed -namespace pt { +namespace pten { using DataType = paddle::experimental::DataType; using DataLayout = paddle::experimental::DataLayout; @@ -42,4 +42,4 @@ paddle::framework::proto::VarType::Type TransToProtoVarType( const DataType& dtype); paddle::framework::DataLayout TransToFluidDataLayout(const DataLayout& layout); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/dense_tensor.cc b/paddle/pten/core/dense_tensor.cc similarity index 95% rename from paddle/tcmpt/core/dense_tensor.cc rename to paddle/pten/core/dense_tensor.cc index 9c34b5823d590..022127773909d 100644 --- a/paddle/tcmpt/core/dense_tensor.cc +++ b/paddle/pten/core/dense_tensor.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/convert_utils.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/convert_utils.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/data_type.h" @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/place.h" -namespace pt { +namespace pten { using CPUPlace = paddle::platform::CPUPlace; using CUDAPlace = paddle::platform::CUDAPlace; @@ -43,7 +43,7 @@ const paddle::platform::Place& DenseTensor::place() const { // Inner methods void DenseTensor::ShareAllocation( - const std::shared_ptr& allocation) { + const std::shared_ptr& allocation) { // This operation can be very slow! // std::shared_ptr reference count is atomic. increasing or decreasing // the reference count requires atomic increment or decrement. @@ -137,4 +137,4 @@ void* DenseTensor::mutable_data() { reinterpret_cast(allocation_->ptr()) + meta_.offset); } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/dense_tensor.h b/paddle/pten/core/dense_tensor.h similarity index 88% rename from paddle/tcmpt/core/dense_tensor.h rename to paddle/pten/core/dense_tensor.h index a0d195b740bed..e913440a7e663 100644 --- a/paddle/tcmpt/core/dense_tensor.h +++ b/paddle/pten/core/dense_tensor.h @@ -16,9 +16,9 @@ limitations under the License. */ #include -#include "paddle/tcmpt/core/tensor_base.h" -#include "paddle/tcmpt/core/tensor_meta.h" -#include "paddle/tcmpt/core/tensor_status.h" +#include "paddle/pten/core/tensor_base.h" +#include "paddle/pten/core/tensor_meta.h" +#include "paddle/pten/core/tensor_status.h" namespace paddle { namespace memory { @@ -28,15 +28,10 @@ class Allocation; } } -namespace pt { +namespace pten { -using TensorBase = paddle::tcmpt::TensorBase; using DataType = paddle::experimental::DataType; -// TODO(chenweihang): Allocation still link to framework, Redesign and -// decoupled Allocation and Allocator? -using Allocation = paddle::memory::allocation::Allocation; - /** * The implementation of general Tensor (For CPU, CUDA, HIP, etc.), similar * to the Tensor in fluid, contains a pointer to Allocation and a series of @@ -92,7 +87,10 @@ class DenseTensor : public TensorBase { /* member methods */ - const std::shared_ptr& allocation() const { return allocation_; } + const std::shared_ptr& allocation() + const { + return allocation_; + } const TensorMeta& meta() const { return meta_; } @@ -131,7 +129,8 @@ class DenseTensor : public TensorBase { void Resize(const DDim& dims) { meta_.dims = dims; } - void ShareAllocation(const std::shared_ptr& allocation); + void ShareAllocation(const std::shared_ptr< + paddle::memory::allocation::Allocation>& allocation); paddle::platform::Place GetPlaceByBackend() const; @@ -141,11 +140,11 @@ class DenseTensor : public TensorBase { private: // The actual Tensor storage holder - std::shared_ptr allocation_; + std::shared_ptr allocation_; // The Tensor meta data TensorMeta meta_; // The Tensor status data TensorStatus status_; }; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/kernel_context.cc b/paddle/pten/core/kernel_context.cc similarity index 88% rename from paddle/tcmpt/core/kernel_context.cc rename to paddle/pten/core/kernel_context.cc index 5bfcaf137fedf..443990c07247d 100644 --- a/paddle/tcmpt/core/kernel_context.cc +++ b/paddle/pten/core/kernel_context.cc @@ -12,6 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/core/kernel_context.h" +#include "paddle/pten/core/kernel_context.h" -namespace pt {} // namespace pt +namespace pten {} // namespace pten diff --git a/paddle/tcmpt/core/kernel_context.h b/paddle/pten/core/kernel_context.h similarity index 97% rename from paddle/tcmpt/core/kernel_context.h rename to paddle/pten/core/kernel_context.h index 022d8a6713155..c17248831c10e 100644 --- a/paddle/tcmpt/core/kernel_context.h +++ b/paddle/pten/core/kernel_context.h @@ -16,17 +16,16 @@ #include -#include "paddle/tcmpt/core/tensor_base.h" +#include "paddle/pten/core/tensor_base.h" #include "paddle/utils/any.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" -namespace pt { +namespace pten { using DeviceContext = paddle::platform::DeviceContext; -using TensorBase = paddle::tcmpt::TensorBase; using DataType = paddle::experimental::DataType; using DataLayout = paddle::experimental::DataLayout; @@ -132,4 +131,4 @@ class KernelContext { std::vector output_names_{{}}; }; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/kernel_def.h b/paddle/pten/core/kernel_def.h similarity index 97% rename from paddle/tcmpt/core/kernel_def.h rename to paddle/pten/core/kernel_def.h index 70b8be19aaeea..48a579cd02b51 100644 --- a/paddle/tcmpt/core/kernel_def.h +++ b/paddle/pten/core/kernel_def.h @@ -14,7 +14,7 @@ #pragma once -namespace pt { +namespace pten { class Kernel; class KernelKey; @@ -39,4 +39,4 @@ constexpr char kContainSelectedRowsSuffix[] = "sr"; // For kernels with intermediate output constexpr char kContainMidOutputTensorSuffix[] = "mid"; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/kernel_factory.cc b/paddle/pten/core/kernel_factory.cc similarity index 91% rename from paddle/tcmpt/core/kernel_factory.cc rename to paddle/pten/core/kernel_factory.cc index a301d6a995ce7..243808c67b843 100644 --- a/paddle/tcmpt/core/kernel_factory.cc +++ b/paddle/pten/core/kernel_factory.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/core/kernel_factory.h" +#include "paddle/pten/core/kernel_factory.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/enforce.h" -namespace pt { +namespace pten { KernelFactory& KernelFactory::Instance() { static KernelFactory g_op_kernel_factory; @@ -51,9 +51,11 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( "The kernel `%s` is not registered.", kernel_name)); auto kernel_iter = iter->second.find(kernel_key); - if (kernel_key.layout() != pt::DataLayout::kAny) { - pt::KernelKey any_layout_kernel_key( - kernel_key.backend(), pt::DataLayout::kAny, kernel_key.dtype()); + if (kernel_key.layout() != paddle::experimental::DataLayout::kAny) { + pten::KernelKey any_layout_kernel_key( + kernel_key.backend(), + paddle::experimental::DataLayout::kAny, + kernel_key.dtype()); kernel_iter = iter->second.find(any_layout_kernel_key); } PADDLE_ENFORCE_NE( @@ -98,4 +100,4 @@ std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory) { return os; } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h similarity index 97% rename from paddle/tcmpt/core/kernel_factory.h rename to paddle/pten/core/kernel_factory.h index 6e4a3fa86dfda..32c8462585878 100644 --- a/paddle/tcmpt/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -19,17 +19,17 @@ #include #include -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/common/layout.h" -#include "paddle/tcmpt/core/backend.h" -#include "paddle/tcmpt/core/kernel_def.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/common/layout.h" +#include "paddle/pten/core/backend.h" +#include "paddle/pten/core/kernel_def.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/enforce.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" -namespace pt { +namespace pten { using DataType = paddle::experimental::DataType; using DataLayout = paddle::experimental::DataLayout; @@ -323,4 +323,4 @@ std::ostream& operator<<(std::ostream& os, const Kernel& kernel); std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h similarity index 91% rename from paddle/tcmpt/core/kernel_registry.h rename to paddle/pten/core/kernel_registry.h index caa42546ab054..666b700a671b9 100644 --- a/paddle/tcmpt/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -20,15 +20,15 @@ #include #include -#include "paddle/tcmpt/core/kernel_def.h" -#include "paddle/tcmpt/core/kernel_factory.h" -#include "paddle/tcmpt/core/kernel_utils.h" +#include "paddle/pten/core/kernel_def.h" +#include "paddle/pten/core/kernel_factory.h" +#include "paddle/pten/core/kernel_utils.h" -namespace pt { +namespace pten { -#define BACKEND(arg__) pt::Backend::k##arg__ -#define DATALAYOUT(arg__) pt::DataLayout::k##arg__ -#define DATATYPE(arg__) pt::DataType::k##arg__ +#define BACKEND(arg__) pten::Backend::k##arg__ +#define DATALAYOUT(arg__) paddle::experimental::DataLayout::k##arg__ +#define DATATYPE(arg__) paddle::experimental::DataType::k##arg__ template struct KernelArgsParseFunctor; @@ -45,8 +45,8 @@ struct KernelArgsParseFunctor { // TODO(chenweihang): The fluid Tensor's default layout is NCHW, // it is not same as kernel's layout, we should fix this error on // fluid Tensor - auto default_tensor_layout = pt::DataLayout::kNCHW; - if (default_key.layout() != pt::DataLayout::kAny) { + auto default_tensor_layout = paddle::experimental::DataLayout::kNCHW; + if (default_key.layout() != paddle::experimental::DataLayout::kAny) { default_tensor_layout = default_key.layout(); } auto args_type = ParseArgType(Indices{}); @@ -216,7 +216,7 @@ struct KernelRegistrar { "PT_REGISTER_KERNEL must be called in global namespace."); \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, __VA_ARGS__); \ static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pt::Kernel*); \ + func_id)(::pten::Kernel*); \ PT_KERNEL_REGISTRAR_INIT(kernel_name, \ func_id, \ backend, \ @@ -225,7 +225,8 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ __VA_ARGS__); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pt::Kernel * kernel) + void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ + func_id)(::pten::Kernel * kernel) #else #define _PT_REGISTER_KERNEL( \ kernel_name, func_id, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ @@ -233,7 +234,7 @@ struct KernelRegistrar { PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ "PT_REGISTER_KERNEL must be called in global namespace."); \ static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pt::Kernel*); \ + func_id)(::pten::Kernel*); \ PT_KERNEL_REGISTRAR_INIT(kernel_name, \ func_id, \ backend, \ @@ -242,7 +243,8 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ __VA_ARGS__); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pt::Kernel * kernel) + void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ + func_id)(::pten::Kernel * kernel) #endif #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, ...) \ @@ -345,13 +347,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); @@ -364,13 +366,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -391,13 +393,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -418,13 +420,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -445,13 +447,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -472,13 +474,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -499,13 +501,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -526,13 +528,13 @@ struct KernelRegistrar { meta_kernel_fn, \ cpp_dtype, \ ...) \ - static const ::pt::KernelRegistrar PT_CONCATENATE( \ + static const ::pten::KernelRegistrar PT_CONCATENATE( \ __reg_pt_op_kernel_##func_id##_, registrar_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::pt::KernelArgsParseFunctor)>::Parse, \ args_def_fn, \ PT_KERNEL(meta_kernel_fn)); \ @@ -557,17 +559,17 @@ struct KernelRegistrar { "_PT_REGISTER_KERNEL_STANDARD must be called in global namespace."); \ template decltype(kernel_fn) kernel_fn; \ static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pt::Kernel*); \ - static const ::pt::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ - func_id)( \ + func_id)(::pten::Kernel*); \ + static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ + func_id)( \ kernel_name, \ BACKEND(backend), \ DATALAYOUT(layout), \ DATATYPE(dtype), \ - ::pt::KernelArgsParseFunctor::Parse, \ + ::pten::KernelArgsParseFunctor::Parse, \ args_def_fn, \ PT_KERNEL(kernel_fn)); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pt::Kernel*) + void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pten::Kernel*) // use to declare symbol #define PT_REGISTER_MODULE(name) \ @@ -595,7 +597,7 @@ struct KernelRegistrar { PT_CONCATENATE(pt_op_kernel_for_test_ns_check_, func_id), \ "PT_REGISTER_KERNEL must be called in global namespace."); \ static void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \ - func_id)(::pt::Kernel*); \ + func_id)(::pten::Kernel*); \ PT_KERNEL_REGISTRAR_INIT( \ kernel_name, \ func_id, \ @@ -606,27 +608,28 @@ struct KernelRegistrar { cpp_dtype, \ __VA_ARGS__); \ void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \ - func_id)(::pt::Kernel * kernel) + func_id)(::pten::Kernel * kernel) #define PT_REGISTER_KERNEL_WITH_NO_TYPE( \ kernel_name, backend, layout, meta_kernel_fn) \ _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ kernel_name, PT_ID, backend, layout, meta_kernel_fn) -#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ - kernel_name, func_id, backend, layout, meta_kernel_fn) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - decltype(meta_kernel_fn) meta_kernel_fn; \ - static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pt::Kernel*); \ - static const ::pt::KernelRegistrar __reg_pt_op_kernel_##func_id( \ - kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::pt::KernelArgsParseFunctor::Parse, \ - &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ - PT_KERNEL(meta_kernel_fn)); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pt::Kernel * kernel) -} // namespace pt +#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ + kernel_name, func_id, backend, layout, meta_kernel_fn) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + decltype(meta_kernel_fn) meta_kernel_fn; \ + static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ + func_id)(::pten::Kernel*); \ + static const ::pten::KernelRegistrar __reg_pt_op_kernel_##func_id( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::pten::KernelArgsParseFunctor::Parse, \ + &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ + PT_KERNEL(meta_kernel_fn)); \ + void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ + func_id)(::pten::Kernel * kernel) +} // namespace pten diff --git a/paddle/tcmpt/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h similarity index 96% rename from paddle/tcmpt/core/kernel_utils.h rename to paddle/pten/core/kernel_utils.h index 54d3d373da7c7..3f8458aed6dfc 100644 --- a/paddle/tcmpt/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -14,16 +14,16 @@ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_context.h" -#include "paddle/tcmpt/core/kernel_def.h" -#include "paddle/tcmpt/core/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_context.h" +#include "paddle/pten/core/kernel_def.h" +#include "paddle/pten/core/scalar.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" -namespace pt { +namespace pten { // TODO(shixiaowei): replaced by new DeviceContext later using CPUContext = paddle::platform::CPUDeviceContext; @@ -41,7 +41,7 @@ using XPUContext = paddle::platform::XPUDeviceContext; #endif #define PT_KERNEL(...) \ - ::pt::KernelImpl::Compute + ::pten::KernelImpl::Compute #define PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx) \ template \ @@ -163,7 +163,7 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); - PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const pt::Scalar&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const pten::Scalar&); /* Output Helpers */ @@ -185,4 +185,4 @@ struct KernelImpl { }; }; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/scalar.h b/paddle/pten/core/scalar.h similarity index 97% rename from paddle/tcmpt/core/scalar.h rename to paddle/pten/core/scalar.h index 8f30d81bcfb28..f8cdd43cc5e4c 100644 --- a/paddle/tcmpt/core/scalar.h +++ b/paddle/pten/core/scalar.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -namespace pt { +namespace pten { class Scalar { public: @@ -60,4 +60,4 @@ class Scalar { } data_; }; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/spatial_tensor.h b/paddle/pten/core/spatial_tensor.h similarity index 95% rename from paddle/tcmpt/core/spatial_tensor.h rename to paddle/pten/core/spatial_tensor.h index 0e5bdd8be50a3..f1bd4add19771 100644 --- a/paddle/tcmpt/core/spatial_tensor.h +++ b/paddle/pten/core/spatial_tensor.h @@ -14,9 +14,9 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/tensor_base.h" +#include "paddle/pten/core/tensor_base.h" -namespace pt { +namespace pten { /** * SpatialTensor represents a Tensor whose memory layout is different from @@ -48,4 +48,4 @@ class MetalTensor : public SpatialTensor {}; template class OpenCLTensor : public SpatialTensor {}; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/storage.cc b/paddle/pten/core/storage.cc similarity index 85% rename from paddle/tcmpt/core/storage.cc rename to paddle/pten/core/storage.cc index 02fbea8d0b3a1..5cac122b7dee6 100644 --- a/paddle/tcmpt/core/storage.cc +++ b/paddle/pten/core/storage.cc @@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/storage.h" +#include "paddle/pten/core/storage.h" -namespace paddle { -namespace tcmpt { +namespace pten { void TensorStorage::Realloc(size_t size) { data_.Clear(); @@ -23,5 +22,4 @@ void TensorStorage::Realloc(size_t size) { size_ = size; } -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/storage.h b/paddle/pten/core/storage.h similarity index 85% rename from paddle/tcmpt/core/storage.h rename to paddle/pten/core/storage.h index d838d0cd1c957..b1c6de7fff8f6 100644 --- a/paddle/tcmpt/core/storage.h +++ b/paddle/pten/core/storage.h @@ -17,14 +17,13 @@ limitations under the License. */ #include #include "boost/intrusive_ptr.hpp" -#include "paddle/tcmpt/core/utils/intrusive_ptr.h" -#include "paddle/tcmpt/core/utils/intrusive_ref_counter.h" +#include "paddle/pten/core/utils/intrusive_ptr.h" +#include "paddle/pten/core/utils/intrusive_ref_counter.h" #include "paddle/fluid/platform/place.h" -#include "paddle/tcmpt/core/allocator.h" +#include "paddle/pten/core/allocator.h" -namespace paddle { -namespace tcmpt { +namespace pten { /// \brief The interface of contiguous storage used for the dense tensor. /// It should be used in conjunction with the intrusive pointer. We prohibit @@ -44,7 +43,7 @@ class Storage : public intrusive_ref_counter { void* data() const noexcept { return data_.operator->(); } virtual size_t size() const = 0; - virtual const platform::Place& place() const = 0; + virtual const paddle::platform::Place& place() const = 0; virtual bool OwnsMemory() const = 0; virtual void Realloc(size_t n) = 0; @@ -63,7 +62,9 @@ class TensorStorage : public Storage { void Realloc(size_t size) override; size_t size() const noexcept override { return size_; } - const platform::Place& place() const override { return data_.place(); } + const paddle::platform::Place& place() const override { + return data_.place(); + } bool OwnsMemory() const noexcept override { return true; } const std::shared_ptr& allocator() const noexcept { return alloc_; @@ -74,5 +75,4 @@ class TensorStorage : public Storage { int64_t size_{0}; }; -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/tensor_base.cc b/paddle/pten/core/tensor_base.cc similarity index 81% rename from paddle/tcmpt/core/tensor_base.cc rename to paddle/pten/core/tensor_base.cc index 05dba1206075d..f9169674a4bbe 100644 --- a/paddle/tcmpt/core/tensor_base.cc +++ b/paddle/pten/core/tensor_base.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/tensor_base.h" -#include "paddle/tcmpt/core/utils/type_registry.h" +#include "paddle/pten/core/tensor_base.h" +#include "paddle/pten/core/utils/type_registry.h" -namespace paddle { -namespace tcmpt {} -} +namespace pten {} diff --git a/paddle/tcmpt/core/tensor_base.h b/paddle/pten/core/tensor_base.h similarity index 81% rename from paddle/tcmpt/core/tensor_base.h rename to paddle/pten/core/tensor_base.h index 240808e3cc492..92b1ebaca4f1c 100644 --- a/paddle/tcmpt/core/tensor_base.h +++ b/paddle/pten/core/tensor_base.h @@ -16,20 +16,19 @@ limitations under the License. */ #include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/platform/place.h" -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/common/layout.h" -#include "paddle/tcmpt/core/storage.h" -#include "paddle/tcmpt/core/utils/type_registry.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/common/layout.h" +#include "paddle/pten/core/storage.h" +#include "paddle/pten/core/utils/type_registry.h" -#include "paddle/tcmpt/core/backend.h" +#include "paddle/pten/core/backend.h" -namespace paddle { -namespace tcmpt { +namespace pten { class TensorBase { public: - using DataType = experimental::DataType; - using DataLayout = experimental::DataLayout; + using DataType = paddle::experimental::DataType; + using DataLayout = paddle::experimental::DataLayout; virtual ~TensorBase() = default; @@ -51,7 +50,7 @@ class TensorBase { /// \brief Returns the data place of the tensor. /// \return The data place of the tensor. - virtual const platform::Place& place() const = 0; + virtual const paddle::platform::Place& place() const = 0; /// \brief Test whether the metadata is valid. /// \return Whether the metadata is valid. @@ -61,7 +60,7 @@ class TensorBase { /// return Whether the storage is allocated. virtual bool initialized() const = 0; - virtual pt::Backend backend() const = 0; + virtual pten::Backend backend() const = 0; /// \brief Return the type information of the derived class to support /// safely downcast in non-rtti environment. @@ -74,5 +73,4 @@ class TensorBase { TypeInfo type_info_{TypeInfo::kUnknownType}; }; -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/tensor_meta.h b/paddle/pten/core/tensor_meta.h similarity index 96% rename from paddle/tcmpt/core/tensor_meta.h rename to paddle/pten/core/tensor_meta.h index 3cc557e05b4c1..c305ed2a850ee 100644 --- a/paddle/tcmpt/core/tensor_meta.h +++ b/paddle/pten/core/tensor_meta.h @@ -16,9 +16,9 @@ limitations under the License. */ #include -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/common/layout.h" -#include "paddle/tcmpt/core/backend.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/common/layout.h" +#include "paddle/pten/core/backend.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/ddim.h" @@ -26,7 +26,7 @@ limitations under the License. */ // used on CUDA device? Can we use small_vector here? // #include "paddle/fluid/framework/mixed_vector.h" -namespace pt { +namespace pten { using DataType = paddle::experimental::DataType; using DataLayout = paddle::experimental::DataLayout; @@ -144,4 +144,4 @@ struct TensorMeta { LoD lod; }; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/tensor_status.h b/paddle/pten/core/tensor_status.h similarity index 92% rename from paddle/tcmpt/core/tensor_status.h rename to paddle/pten/core/tensor_status.h index 1eb56397414b5..2abc8ff1b1b92 100644 --- a/paddle/tcmpt/core/tensor_status.h +++ b/paddle/pten/core/tensor_status.h @@ -14,11 +14,11 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/common/layout.h" -#include "paddle/tcmpt/core/backend.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/common/layout.h" +#include "paddle/pten/core/backend.h" -namespace pt { +namespace pten { class TensorInplaceVersion { public: @@ -61,4 +61,4 @@ struct TensorStatus { bool is_scalar{false}; }; -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/core/utils/CMakeLists.txt b/paddle/pten/core/utils/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/core/utils/CMakeLists.txt rename to paddle/pten/core/utils/CMakeLists.txt diff --git a/paddle/tcmpt/core/utils/intrusive_ptr.h b/paddle/pten/core/utils/intrusive_ptr.h similarity index 95% rename from paddle/tcmpt/core/utils/intrusive_ptr.h rename to paddle/pten/core/utils/intrusive_ptr.h index f368d05cb47db..f0e94fadac973 100644 --- a/paddle/tcmpt/core/utils/intrusive_ptr.h +++ b/paddle/pten/core/utils/intrusive_ptr.h @@ -18,8 +18,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" -namespace paddle { -namespace tcmpt { +namespace pten { template class intrusive_ptr { @@ -58,7 +57,7 @@ class intrusive_ptr { T& operator*() const { PADDLE_ENFORCE_NOT_NULL( px, - platform::errors::PreconditionNotMet( + paddle::platform::errors::PreconditionNotMet( "The pointer must be non-null before the dereference operation.")); return *px; } @@ -66,7 +65,7 @@ class intrusive_ptr { T* operator->() const { PADDLE_ENFORCE_NOT_NULL( px, - platform::errors::PreconditionNotMet( + paddle::platform::errors::PreconditionNotMet( "The pointer must be non-null before the dereference operation.")); return px; } @@ -156,5 +155,4 @@ inline intrusive_ptr copy_intrusive(const intrusive_ptr& rhs) { return intrusive_ptr(rhs.get(), true); } -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/utils/intrusive_ref_counter.h b/paddle/pten/core/utils/intrusive_ref_counter.h similarity index 96% rename from paddle/tcmpt/core/utils/intrusive_ref_counter.h rename to paddle/pten/core/utils/intrusive_ref_counter.h index 1c93bede71df1..8e18c82197eb6 100644 --- a/paddle/tcmpt/core/utils/intrusive_ref_counter.h +++ b/paddle/pten/core/utils/intrusive_ref_counter.h @@ -16,8 +16,7 @@ limitations under the License. */ #include -namespace paddle { -namespace tcmpt { +namespace pten { template class intrusive_ref_counter; @@ -62,5 +61,4 @@ inline void intrusive_ptr_release( } } -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/utils/type_info.h b/paddle/pten/core/utils/type_info.h similarity index 95% rename from paddle/tcmpt/core/utils/type_info.h rename to paddle/pten/core/utils/type_info.h index ba5bc641b94b2..4e4084a4c785b 100644 --- a/paddle/tcmpt/core/utils/type_info.h +++ b/paddle/pten/core/utils/type_info.h @@ -16,8 +16,7 @@ limitations under the License. */ #include -namespace paddle { -namespace tcmpt { +namespace pten { template class TypeRegistry; @@ -57,5 +56,4 @@ template const TypeInfo TypeInfoTraits::kType = RegisterStaticType(DerivedT::name()); -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/tcmpt/core/utils/type_registry.h b/paddle/pten/core/utils/type_registry.h similarity index 94% rename from paddle/tcmpt/core/utils/type_registry.h rename to paddle/pten/core/utils/type_registry.h index 52b699a0dd413..82eb9ae52bd7e 100644 --- a/paddle/tcmpt/core/utils/type_registry.h +++ b/paddle/pten/core/utils/type_registry.h @@ -18,10 +18,9 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/core/utils/type_info.h" +#include "paddle/pten/core/utils/type_info.h" -namespace paddle { -namespace tcmpt { +namespace pten { template class TypeRegistry { @@ -82,5 +81,4 @@ template const TypeInfo TypeInfo::kUnknownType = RegisterStaticType("Unknown"); -} // namespace tcmpt -} // namespace paddle +} // namespace pten diff --git a/paddle/pten/hapi/CMakeLists.txt b/paddle/pten/hapi/CMakeLists.txt new file mode 100644 index 0000000000000..8a33de85bddd3 --- /dev/null +++ b/paddle/pten/hapi/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(lib) + +cc_library(pten_hapi SRCS all.cc DEPS math_api linalg_api creation_api) diff --git a/paddle/tcmpt/hapi/all.cc b/paddle/pten/hapi/all.cc similarity index 95% rename from paddle/tcmpt/hapi/all.cc rename to paddle/pten/hapi/all.cc index f43cdb9f78b53..4ea6fabeecf2e 100644 --- a/paddle/tcmpt/hapi/all.cc +++ b/paddle/pten/hapi/all.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/hapi/all.h" +#include "paddle/pten/hapi/all.h" namespace paddle { namespace experimental {} // namespace experimental diff --git a/paddle/tcmpt/hapi/all.h b/paddle/pten/hapi/all.h similarity index 77% rename from paddle/tcmpt/hapi/all.h rename to paddle/pten/hapi/all.h index bd1c51fc49ed3..de2e14db421f6 100644 --- a/paddle/tcmpt/hapi/all.h +++ b/paddle/pten/hapi/all.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once // user apis -#include "paddle/tcmpt/hapi/include/creation.h" -#include "paddle/tcmpt/hapi/include/linalg.h" -#include "paddle/tcmpt/hapi/include/math.h" -#include "paddle/tcmpt/hapi/include/tensor.h" +#include "paddle/pten/hapi/include/creation.h" +#include "paddle/pten/hapi/include/linalg.h" +#include "paddle/pten/hapi/include/math.h" +#include "paddle/pten/hapi/include/tensor.h" diff --git a/paddle/tcmpt/hapi/include/creation.h b/paddle/pten/hapi/include/creation.h similarity index 56% rename from paddle/tcmpt/hapi/include/creation.h rename to paddle/pten/hapi/include/creation.h index d2d68e3bb7e61..3929d8d026e08 100644 --- a/paddle/tcmpt/hapi/include/creation.h +++ b/paddle/pten/hapi/include/creation.h @@ -14,20 +14,25 @@ #pragma once -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/core/scalar.h" -#include "paddle/tcmpt/hapi/include/tensor.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/core/scalar.h" +#include "paddle/pten/hapi/include/tensor.h" namespace paddle { namespace experimental { Tensor full_like(const Tensor& x, - const pt::Scalar& value, - pt::DataType dtype = pt::DataType::kUndef); + const pten::Scalar& value, + paddle::experimental::DataType dtype = + paddle::experimental::DataType::kUndef); -Tensor ones_like(const Tensor& x, pt::DataType dtype = pt::DataType::kUndef); +Tensor ones_like(const Tensor& x, + paddle::experimental::DataType dtype = + paddle::experimental::DataType::kUndef); -Tensor zeros_like(const Tensor& x, pt::DataType dtype = pt::DataType::kUndef); +Tensor zeros_like(const Tensor& x, + paddle::experimental::DataType dtype = + paddle::experimental::DataType::kUndef); } // namespace experimental } // namespace paddle diff --git a/paddle/tcmpt/hapi/include/linalg.h b/paddle/pten/hapi/include/linalg.h similarity index 95% rename from paddle/tcmpt/hapi/include/linalg.h rename to paddle/pten/hapi/include/linalg.h index df709b6a3c50f..6e78b50af11c3 100644 --- a/paddle/tcmpt/hapi/include/linalg.h +++ b/paddle/pten/hapi/include/linalg.h @@ -14,7 +14,7 @@ #pragma once -#include "paddle/tcmpt/hapi/include/tensor.h" +#include "paddle/pten/hapi/include/tensor.h" namespace paddle { namespace experimental { diff --git a/paddle/tcmpt/hapi/include/manipulation.h b/paddle/pten/hapi/include/manipulation.h similarity index 94% rename from paddle/tcmpt/hapi/include/manipulation.h rename to paddle/pten/hapi/include/manipulation.h index 35695f4f6d8b6..4622032f5ad54 100644 --- a/paddle/tcmpt/hapi/include/manipulation.h +++ b/paddle/pten/hapi/include/manipulation.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/hapi/include/tensor.h" +#include "paddle/pten/hapi/include/tensor.h" namespace paddle { namespace experimental { diff --git a/paddle/tcmpt/hapi/include/math.h b/paddle/pten/hapi/include/math.h similarity index 94% rename from paddle/tcmpt/hapi/include/math.h rename to paddle/pten/hapi/include/math.h index 9245d1033c791..0b3dbab70e86f 100644 --- a/paddle/tcmpt/hapi/include/math.h +++ b/paddle/pten/hapi/include/math.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/hapi/include/tensor.h" +#include "paddle/pten/hapi/include/tensor.h" namespace paddle { namespace experimental { diff --git a/paddle/tcmpt/hapi/include/tensor.h b/paddle/pten/hapi/include/tensor.h similarity index 91% rename from paddle/tcmpt/hapi/include/tensor.h rename to paddle/pten/hapi/include/tensor.h index ccca911cf8c86..1982483fe4119 100644 --- a/paddle/tcmpt/hapi/include/tensor.h +++ b/paddle/pten/hapi/include/tensor.h @@ -18,14 +18,14 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/core/tensor_base.h" +#include "paddle/pten/core/tensor_base.h" /** * [ Why still include the fluid headers? ] * * We hope to organize the basic implementation of Tensor and the logic related * to Tensor computation into an independent library, which we call - * [Tensor Compute Library, tcmpt], so we extract or rewrite the original + * [Tensor Compute Library, pten], so we extract or rewrite the original * Kernels. * * In the future, the training library, inference library and custom operators @@ -54,7 +54,7 @@ class AutogradMetaInterface { /** * Tensor is the API description of the basic data structure in the - * [ Paddle "Tensor CoMPuTe (tcmpt)" Library ]. + * [ Paddle "Tensor CoMPuTe (pten)" Library ]. * * It is not limited to a simple n-dimensional array. * It contains a smart pointer to `TensorImpl`. The data description contained @@ -91,7 +91,7 @@ class Tensor final { * @param {shared_ptr} tensor_impl * @return {Tensor} */ - explicit Tensor(std::shared_ptr tensor_impl) + explicit Tensor(std::shared_ptr tensor_impl) : impl_(std::move(tensor_impl)) { if (impl_.get() == nullptr) { throw std::runtime_error("TensorImpl with nullptr is not supported"); @@ -118,14 +118,14 @@ class Tensor final { * @param None * @return {DataType} */ - pt::DataType type() const { return impl_->data_type(); } + paddle::experimental::DataType type() const { return impl_->data_type(); } /** * @description: Return the layout of current Tensor. * @param None * @return {DataLayout} */ - pt::DataLayout layout() const { return impl_->layout(); } + paddle::experimental::DataLayout layout() const { return impl_->layout(); } /* Part 3: Device and Backend methods */ /** @@ -138,8 +138,8 @@ class Tensor final { /** * Backend judgment APIs, shield the concept of Backend. */ - bool is_cpu() const { return impl_->backend() == pt::Backend::kCPU; } - bool is_cuda() const { return impl_->backend() == pt::Backend::kCUDA; } + bool is_cpu() const { return impl_->backend() == pten::Backend::kCPU; } + bool is_cuda() const { return impl_->backend() == pten::Backend::kCUDA; } bool is_hip() const; bool is_xpu() const; bool is_npu() const; @@ -165,16 +165,14 @@ class Tensor final { * @param None * @return {std::shared_ptr} */ - std::shared_ptr impl() const { return impl_; } + std::shared_ptr impl() const { return impl_; } /** * @description: Set the implemention of current Tensor. * @param {std::shared_ptr} * @return None */ - void set_impl(const std::shared_ptr& impl) { - impl_ = impl; - } + void set_impl(const std::shared_ptr& impl) { impl_ = impl; } // TODO(chenweihang): Whether API Tensor need `data` and `mutable_data`? @@ -245,7 +243,7 @@ class Tensor final { * heterogeneous Tensor implementation, so that the API level can be unified * to one `Tensor`. */ - std::shared_ptr impl_; + std::shared_ptr impl_; /** * [ Why need abstract AutogradMetaInterface here? ] diff --git a/paddle/pten/hapi/lib/CMakeLists.txt b/paddle/pten/hapi/lib/CMakeLists.txt new file mode 100644 index 0000000000000..54cabb7e69baa --- /dev/null +++ b/paddle/pten/hapi/lib/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library(math_api SRCS math.cc DEPS pten) +cc_library(linalg_api SRCS linalg.cc DEPS pten) +cc_library(creation_api SRCS creation.cc DEPS pten) +cc_library(manipulation_api SRCS manipulation.cc DEPS pten) diff --git a/paddle/tcmpt/hapi/lib/creation.cc b/paddle/pten/hapi/lib/creation.cc similarity index 65% rename from paddle/tcmpt/hapi/lib/creation.cc rename to paddle/pten/hapi/lib/creation.cc index 057855a3dba4c..3004f935f4833 100644 --- a/paddle/tcmpt/hapi/lib/creation.cc +++ b/paddle/pten/hapi/lib/creation.cc @@ -12,36 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/hapi/include/creation.h" +#include "paddle/pten/hapi/include/creation.h" #include #include "glog/logging.h" -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/infershape.h" -#include "paddle/tcmpt/hapi/lib/kernel_generate.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/infershape.h" +#include "paddle/pten/hapi/lib/kernel_generate.h" namespace paddle { namespace experimental { -Tensor full_like(const Tensor& x, const pt::Scalar& value, pt::DataType dtype) { +Tensor full_like(const Tensor& x, + const pten::Scalar& value, + paddle::experimental::DataType dtype) { // 1. Get kernel signature and kernel auto kernel_signature = ParseKernelNameAndKeyByArgs("fill_any_like", x); VLOG(1) << kernel_signature.first; VLOG(1) << kernel_signature.second; - VLOG(1) << pt::KernelFactory::Instance(); + VLOG(1) << pten::KernelFactory::Instance(); - auto kernel = pt::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( kernel_signature.first, kernel_signature.second); VLOG(1) << kernel; // 2. Get Device Context auto* dev_ctx = GetDeviceContextByBackend(kernel_signature.second.backend()); - auto kernel_context = pt::KernelContext(*dev_ctx); + auto kernel_context = pten::KernelContext(*dev_ctx); // 3. Auto data transform - auto dense_x = std::dynamic_pointer_cast(x.impl()); + auto dense_x = std::dynamic_pointer_cast(x.impl()); kernel_context.EmplaceBackInput(dense_x); kernel_context.EmplaceBackAttr(value); @@ -52,11 +54,11 @@ Tensor full_like(const Tensor& x, const pt::Scalar& value, pt::DataType dtype) { // 5. Prepare outputs Tensor out; // InferDataType - if (dtype != pt::DataType::kUndef) { + if (dtype != paddle::experimental::DataType::kUndef) { out_meta.type = dtype; } auto dense_out = - std::make_shared(out_meta, pt::TensorStatus()); + std::make_shared(out_meta, pten::TensorStatus()); kernel_context.EmplaceBackOutput(dense_out); out.set_impl(dense_out); @@ -66,11 +68,11 @@ Tensor full_like(const Tensor& x, const pt::Scalar& value, pt::DataType dtype) { return out; } -Tensor ones_like(const Tensor& x, pt::DataType dtype) { +Tensor ones_like(const Tensor& x, paddle::experimental::DataType dtype) { return full_like(x, 1, dtype); } -Tensor zeros_like(const Tensor& x, pt::DataType dtype) { +Tensor zeros_like(const Tensor& x, paddle::experimental::DataType dtype) { return full_like(x, 0, dtype); } diff --git a/paddle/tcmpt/hapi/lib/kernel_generate.h b/paddle/pten/hapi/lib/kernel_generate.h similarity index 86% rename from paddle/tcmpt/hapi/lib/kernel_generate.h rename to paddle/pten/hapi/lib/kernel_generate.h index 1b5f9d7ae02ac..82214c96fb5c7 100644 --- a/paddle/tcmpt/hapi/lib/kernel_generate.h +++ b/paddle/pten/hapi/lib/kernel_generate.h @@ -17,10 +17,10 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/hapi/include/tensor.h" +#include "paddle/pten/hapi/include/tensor.h" // TODO(chenweihang): split KernelName, Key, Kernel, Factory into diff files -#include "paddle/tcmpt/core/kernel_factory.h" +#include "paddle/pten/core/kernel_factory.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" @@ -61,9 +61,9 @@ struct ArgsIterator { struct KernelNameAndKeyParser : ArgsIterator { std::string kernel_name; - pt::Backend backend; - pt::DataLayout layout; - pt::DataType dtype; + pten::Backend backend; + paddle::experimental::DataLayout layout; + paddle::experimental::DataType dtype; explicit KernelNameAndKeyParser(const std::string& name) : kernel_name(name) {} @@ -72,9 +72,9 @@ struct KernelNameAndKeyParser : ArgsIterator { // TODO(chenweihang): deal with multiple diff input Tensors void operator()(const Tensor& x) { if (x.is_cpu()) { - backend = pt::Backend::kCPU; + backend = pten::Backend::kCPU; } else if (x.is_cuda()) { - backend = pt::Backend::kCUDA; + backend = pten::Backend::kCUDA; } else { throw std::runtime_error("Unsupported backend when parser args."); } @@ -97,20 +97,20 @@ struct KernelNameAndKeyParser : ArgsIterator { // suffix on the basis of the function name, or the input contains HostTensor, // and the `host` suffix should be added on the basis of the function name. template -std::pair ParseKernelNameAndKeyByArgs( +std::pair ParseKernelNameAndKeyByArgs( const std::string& fn_name, const Args&... args) { auto parser = detail::KernelNameAndKeyParser(fn_name); parser(args...); // TODO(chenweihang): polish design here - pt::KernelName kernel_name(parser.kernel_name); - pt::KernelKey kernel_key(parser.backend, parser.layout, parser.dtype); + pten::KernelName kernel_name(parser.kernel_name); + pten::KernelKey kernel_key(parser.backend, parser.layout, parser.dtype); return std::make_pair(kernel_name, kernel_key); } paddle::platform::DeviceContext* GetDeviceContextByBackend( - pt::Backend backend) { + pten::Backend backend) { auto& pool = paddle::platform::DeviceContextPool::Instance(); - auto place = pt::TransToFluidPlace(backend); + auto place = pten::TransToFluidPlace(backend); // switch (backend) { // case Backend::kCPU: // return pool.GetByPlace(paddle::platform::CPUPlace()); diff --git a/paddle/tcmpt/hapi/lib/linalg.cc b/paddle/pten/hapi/lib/linalg.cc similarity index 69% rename from paddle/tcmpt/hapi/lib/linalg.cc rename to paddle/pten/hapi/lib/linalg.cc index dc11bae3e37b7..c8198052f43b0 100644 --- a/paddle/tcmpt/hapi/lib/linalg.cc +++ b/paddle/pten/hapi/lib/linalg.cc @@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/hapi/include/linalg.h" +#include "paddle/pten/hapi/include/linalg.h" #include #include "glog/logging.h" -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/infershape.h" -#include "paddle/tcmpt/core/convert_utils.h" -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_context.h" -#include "paddle/tcmpt/hapi/lib/kernel_generate.h" -#include "paddle/tcmpt/infershape/binary.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/infershape.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_context.h" +#include "paddle/pten/hapi/lib/kernel_generate.h" +#include "paddle/pten/infershape/binary.h" namespace paddle { namespace experimental { @@ -34,20 +34,20 @@ Tensor dot(const Tensor& x, const Tensor& y) { auto kernel_signature = ParseKernelNameAndKeyByArgs("dot", x); VLOG(1) << kernel_signature.first; VLOG(1) << kernel_signature.second; - VLOG(1) << pt::KernelFactory::Instance(); + VLOG(1) << pten::KernelFactory::Instance(); - auto kernel = pt::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( kernel_signature.first, kernel_signature.second); VLOG(1) << kernel; // 2. Get Device Context auto* dev_ctx = GetDeviceContextByBackend(kernel_signature.second.backend()); - auto kernel_context = pt::KernelContext(*dev_ctx); + auto kernel_context = pten::KernelContext(*dev_ctx); // 3. Auto data transform - auto dense_x = std::dynamic_pointer_cast(x.impl()); + auto dense_x = std::dynamic_pointer_cast(x.impl()); kernel_context.EmplaceBackInput(dense_x); - auto dense_y = std::dynamic_pointer_cast(y.impl()); + auto dense_y = std::dynamic_pointer_cast(y.impl()); kernel_context.EmplaceBackInput(dense_y); // TODO(chenweihang): add transform impl @@ -59,7 +59,7 @@ Tensor dot(const Tensor& x, const Tensor& y) { Tensor out; // TODO(chenweihang): deal with multiple outputs auto dense_out = - std::make_shared(out_meta, pt::TensorStatus()); + std::make_shared(out_meta, pten::TensorStatus()); kernel_context.EmplaceBackOutput(dense_out); out.set_impl(dense_out); diff --git a/paddle/tcmpt/hapi/lib/manipulation.cc b/paddle/pten/hapi/lib/manipulation.cc similarity index 77% rename from paddle/tcmpt/hapi/lib/manipulation.cc rename to paddle/pten/hapi/lib/manipulation.cc index c8448eecfe2de..8a64d0e9f4a45 100644 --- a/paddle/tcmpt/hapi/lib/manipulation.cc +++ b/paddle/pten/hapi/lib/manipulation.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/hapi/include/manipulation.h" +#include "paddle/pten/hapi/include/manipulation.h" #include #include "glog/logging.h" -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/hapi/lib/kernel_generate.h" -#include "paddle/tcmpt/infershape/unary.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/hapi/lib/kernel_generate.h" +#include "paddle/pten/infershape/unary.h" namespace paddle { namespace experimental { @@ -30,18 +30,18 @@ Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { ParseKernelNameAndKeyByArgs("flatten_contiguous_range", x); VLOG(1) << kernel_signature.first; VLOG(1) << kernel_signature.second; - VLOG(1) << pt::KernelFactory::Instance(); + VLOG(1) << pten::KernelFactory::Instance(); - auto kernel = pt::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( kernel_signature.first, kernel_signature.second); VLOG(1) << kernel; // 2. Get Device Context auto* dev_ctx = GetDeviceContextByBackend(kernel_signature.second.backend()); - auto kernel_context = pt::KernelContext(*dev_ctx); + auto kernel_context = pten::KernelContext(*dev_ctx); // 3. Auto data transform - auto dense_x = std::dynamic_pointer_cast(x.impl()); + auto dense_x = std::dynamic_pointer_cast(x.impl()); kernel_context.EmplaceBackInput(dense_x); kernel_context.EmplaceBackAttr(start_axis); kernel_context.EmplaceBackAttr(stop_axis); @@ -54,7 +54,7 @@ Tensor flatten(const Tensor& x, int start_axis, int stop_axis) { Tensor out; // TODO(chenweihang): deal with multiple outputs auto dense_out = - std::make_shared(out_meta, pt::TensorStatus()); + std::make_shared(out_meta, pten::TensorStatus()); kernel_context.EmplaceBackOutput(dense_out); out.set_impl(dense_out); diff --git a/paddle/tcmpt/hapi/lib/math.cc b/paddle/pten/hapi/lib/math.cc similarity index 75% rename from paddle/tcmpt/hapi/lib/math.cc rename to paddle/pten/hapi/lib/math.cc index 531e85298758c..764511702f0ea 100644 --- a/paddle/tcmpt/hapi/lib/math.cc +++ b/paddle/pten/hapi/lib/math.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/hapi/include/math.h" +#include "paddle/pten/hapi/include/math.h" #include #include "glog/logging.h" -#include "paddle/tcmpt/api/include/core.h" -#include "paddle/tcmpt/api/include/infershape.h" -#include "paddle/tcmpt/hapi/lib/kernel_generate.h" -#include "paddle/tcmpt/infershape/unary.h" +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/infershape.h" +#include "paddle/pten/hapi/lib/kernel_generate.h" +#include "paddle/pten/infershape/unary.h" namespace paddle { namespace experimental { @@ -31,18 +31,18 @@ Tensor mean(const Tensor& x) { auto kernel_signature = ParseKernelNameAndKeyByArgs("mean", x); VLOG(1) << kernel_signature.first; VLOG(1) << kernel_signature.second; - VLOG(1) << pt::KernelFactory::Instance(); + VLOG(1) << pten::KernelFactory::Instance(); - auto kernel = pt::KernelFactory::Instance().SelectKernelOrThrowError( + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( kernel_signature.first, kernel_signature.second); VLOG(1) << kernel; // 2. Get Device Context auto* dev_ctx = GetDeviceContextByBackend(kernel_signature.second.backend()); - auto kernel_context = pt::KernelContext(*dev_ctx); + auto kernel_context = pten::KernelContext(*dev_ctx); // 3. Auto data transform - auto dense_x = std::dynamic_pointer_cast(x.impl()); + auto dense_x = std::dynamic_pointer_cast(x.impl()); kernel_context.EmplaceBackInput(dense_x); // TODO(chenweihang): add transform impl @@ -54,7 +54,7 @@ Tensor mean(const Tensor& x) { Tensor out; // TODO(chenweihang): deal with multiple outputs auto dense_out = - std::make_shared(out_meta, pt::TensorStatus()); + std::make_shared(out_meta, pten::TensorStatus()); kernel_context.EmplaceBackOutput(dense_out); out.set_impl(dense_out); diff --git a/paddle/tcmpt/infershape/CMakeLists.txt b/paddle/pten/infershape/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/infershape/CMakeLists.txt rename to paddle/pten/infershape/CMakeLists.txt diff --git a/paddle/tcmpt/infershape/binary.cc b/paddle/pten/infershape/binary.cc similarity index 96% rename from paddle/tcmpt/infershape/binary.cc rename to paddle/pten/infershape/binary.cc index 936af8767ca62..7d224835cc05a 100644 --- a/paddle/tcmpt/infershape/binary.cc +++ b/paddle/pten/infershape/binary.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ // See Note [ Why still include the fluid headers? ] -#include "paddle/tcmpt/infershape/binary.h" +#include "paddle/pten/infershape/binary.h" -namespace pt { +namespace pten { TensorMeta DotInferShape(const TensorMeta& x_meta, const TensorMeta& y_meta) { auto x_dims = x_meta.dims; @@ -59,4 +59,4 @@ TensorMeta DotInferShape(const TensorMeta& x_meta, const TensorMeta& y_meta) { return return_meta; } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/infershape/binary.h b/paddle/pten/infershape/binary.h similarity index 94% rename from paddle/tcmpt/infershape/binary.h rename to paddle/pten/infershape/binary.h index 816963a277ade..8e44b520e0a9f 100644 --- a/paddle/tcmpt/infershape/binary.h +++ b/paddle/pten/infershape/binary.h @@ -15,9 +15,9 @@ limitations under the License. */ #pragma once // See Note [ Why still include the fluid headers? ] -#include "paddle/tcmpt/core/tensor_meta.h" +#include "paddle/pten/core/tensor_meta.h" -namespace pt { +namespace pten { // Common InferShape Functions for binary operators, The format like: // @@ -32,4 +32,4 @@ namespace pt { TensorMeta DotInferShape(const TensorMeta& x_meta, const TensorMeta& y_meta); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/infershape/unary.cc b/paddle/pten/infershape/unary.cc similarity index 96% rename from paddle/tcmpt/infershape/unary.cc rename to paddle/pten/infershape/unary.cc index 3e4a633fa7a7c..57e74345b7d42 100644 --- a/paddle/tcmpt/infershape/unary.cc +++ b/paddle/pten/infershape/unary.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ // See Note [ Why still include the fluid headers? ] -#include "paddle/tcmpt/infershape/unary.h" +#include "paddle/pten/infershape/unary.h" -namespace pt { +namespace pten { TensorMeta UnchangedInferShape(const TensorMeta& x_meta) { return x_meta; } @@ -74,4 +74,4 @@ TensorMeta FlattenInferShape(const TensorMeta& x_meta, return return_meta; } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/infershape/unary.h b/paddle/pten/infershape/unary.h similarity index 94% rename from paddle/tcmpt/infershape/unary.h rename to paddle/pten/infershape/unary.h index b835ec4bcfa72..1d8fac05d0eaa 100644 --- a/paddle/tcmpt/infershape/unary.h +++ b/paddle/pten/infershape/unary.h @@ -15,9 +15,9 @@ limitations under the License. */ #pragma once // See Note [ Why still include the fluid headers? ] -#include "paddle/tcmpt/core/tensor_meta.h" +#include "paddle/pten/core/tensor_meta.h" -namespace pt { +namespace pten { // Common InferShape Functions for unary operators, The format like: // @@ -38,4 +38,4 @@ TensorMeta FlattenInferShape(const TensorMeta& x_meta, int start_axis, int stop_axis); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt similarity index 94% rename from paddle/tcmpt/kernels/CMakeLists.txt rename to paddle/pten/kernels/CMakeLists.txt index 26b5e16d4428d..09f7a1b102436 100644 --- a/paddle/tcmpt/kernels/CMakeLists.txt +++ b/paddle/pten/kernels/CMakeLists.txt @@ -1,4 +1,4 @@ -# tcmpt kernels for diff device +# pten kernels for diff device add_subdirectory(cpu) if(WITH_GPU OR WITH_ROCM) # TODO(chenweihang): if hip can split from cuda impl, we should add hip dir diff --git a/paddle/tcmpt/kernels/common/eigen/CMakeLists.txt b/paddle/pten/kernels/common/eigen/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/kernels/common/eigen/CMakeLists.txt rename to paddle/pten/kernels/common/eigen/CMakeLists.txt diff --git a/paddle/tcmpt/kernels/common/eigen/common.h b/paddle/pten/kernels/common/eigen/common.h similarity index 86% rename from paddle/tcmpt/kernels/common/eigen/common.h rename to paddle/pten/kernels/common/eigen/common.h index 37bed55a7d97a..f3a6f5fb51ff2 100644 --- a/paddle/tcmpt/kernels/common/eigen/common.h +++ b/paddle/pten/kernels/common/eigen/common.h @@ -16,10 +16,10 @@ limitations under the License. */ #include -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" #include "unsupported/Eigen/CXX11/Tensor" -namespace pt { +namespace pten { // EigenDim converts paddle::platform::DDim into Eigen::DSizes. template @@ -55,24 +55,24 @@ struct EigenTensor { using ConstType = Eigen::TensorMap>; - static Type From(pt::DenseTensor& tensor, DDim dims) { // NOLINT + static Type From(pten::DenseTensor& tensor, DDim dims) { // NOLINT // why tensor.data() not work? // return Type(const_cast(reinterpret_cast(tensor.data())), // EigenDim::From(dims)); return Type(const_cast(tensor.data()), EigenDim::From(dims)); } - static Type From(pt::DenseTensor& tensor) { // NOLINT + static Type From(pten::DenseTensor& tensor) { // NOLINT return From(tensor, tensor.dims()); } // NOLINT - static ConstType From(const pt::DenseTensor& tensor, DDim dims) { + static ConstType From(const pten::DenseTensor& tensor, DDim dims) { // return ConstType(reinterpret_cast(tensor.data()), // EigenDim::From(dims)); return ConstType(tensor.data(), EigenDim::From(dims)); } - static ConstType From(const pt::DenseTensor& tensor) { + static ConstType From(const pten::DenseTensor& tensor) { return From(tensor, tensor.dims()); } }; @@ -81,8 +81,9 @@ template struct EigenMatrix : public EigenTensor { - static typename EigenMatrix::Type Reshape(pt::DenseTensor& tensor, // NOLINT - int num_col_dims) { + static typename EigenMatrix::Type Reshape( + pten::DenseTensor& tensor, // NOLINT + int num_col_dims) { int rank = tensor.dims().size(); PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, @@ -95,8 +96,8 @@ struct EigenMatrix : public EigenTensor { flatten_to_2d(tensor.dims(), num_col_dims)); } - static typename EigenMatrix::ConstType Reshape(const pt::DenseTensor& tensor, - int num_col_dims) { + static typename EigenMatrix::ConstType Reshape( + const pten::DenseTensor& tensor, int num_col_dims) { int rank = tensor.dims().size(); PADDLE_ENFORCE_EQ((num_col_dims > 0 && num_col_dims < rank), true, @@ -116,12 +117,12 @@ template { // Flatten reshapes a Tensor into an EigenVector. static typename EigenVector::Type Flatten( - pt::DenseTensor& tensor) { // NOLINT + pten::DenseTensor& tensor) { // NOLINT return EigenVector::From(tensor, {product(tensor.dims())}); } static typename EigenVector::ConstType Flatten( - const pt::DenseTensor& tensor) { // NOLINT + const pten::DenseTensor& tensor) { // NOLINT return EigenVector::From(tensor, {product(tensor.dims())}); } }; @@ -136,11 +137,11 @@ struct EigenScalar { using ConstType = Eigen::TensorMap< Eigen::TensorFixedSize, MajorType, IndexType>>; - static Type From(pt::DenseTensor& tensor) { // NOLINT + static Type From(pten::DenseTensor& tensor) { // NOLINT return Type(const_cast(tensor.data())); } - static ConstType From(const pt::DenseTensor& tensor) { + static ConstType From(const pten::DenseTensor& tensor) { return ConstType(tensor.data()); } }; @@ -167,4 +168,4 @@ To32BitIndex(EigenTensor in) { return RetType(in.data(), To32BitDims(in.dimensions())); } -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/common/eigen/dot.h b/paddle/pten/kernels/common/eigen/dot.h similarity index 72% rename from paddle/tcmpt/kernels/common/eigen/dot.h rename to paddle/pten/kernels/common/eigen/dot.h index 32c1e1439fac7..8a7789f3dfb64 100644 --- a/paddle/tcmpt/kernels/common/eigen/dot.h +++ b/paddle/pten/kernels/common/eigen/dot.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/kernels/common/eigen/common.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/common/eigen/common.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/operators/eigen/eigen_function.h" -namespace pt { +namespace pten { namespace eigen { template @@ -30,16 +30,16 @@ void Dot(const DevCtx& dev_ctx, DenseTensor* out) { out->mutable_data(); if (1 == out->dims().size()) { - auto eigen_out = pt::EigenScalar::From(*out); - auto eigen_x = pt::EigenVector::Flatten(x); - auto eigen_y = pt::EigenVector::Flatten(y); + auto eigen_out = pten::EigenScalar::From(*out); + auto eigen_x = pten::EigenVector::Flatten(x); + auto eigen_y = pten::EigenVector::Flatten(y); auto& dev = *dev_ctx.eigen_device(); eigen_out.device(dev) = (eigen_x * eigen_y).sum(); } else { - auto eigen_out = pt::EigenMatrix::From(*out); - auto eigen_x = pt::EigenMatrix::From(x); - auto eigen_y = pt::EigenMatrix::From(y); + auto eigen_out = pten::EigenMatrix::From(*out); + auto eigen_x = pten::EigenMatrix::From(x); + auto eigen_y = pten::EigenMatrix::From(y); auto& dev = *dev_ctx.eigen_device(); eigen_out.device(dev) = (eigen_x * eigen_y).sum(Eigen::DSizes(1)); @@ -47,4 +47,4 @@ void Dot(const DevCtx& dev_ctx, } } // namespace eigen -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/common/eigen/fill.h b/paddle/pten/kernels/common/eigen/fill.h similarity index 91% rename from paddle/tcmpt/kernels/common/eigen/fill.h rename to paddle/pten/kernels/common/eigen/fill.h index 186163c3fedc4..df76194839ed7 100644 --- a/paddle/tcmpt/kernels/common/eigen/fill.h +++ b/paddle/pten/kernels/common/eigen/fill.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/kernels/common/eigen/common.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/common/eigen/common.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/operators/eigen/eigen_function.h" -namespace pt { +namespace pten { namespace eigen { template @@ -51,9 +51,9 @@ void fill(const DeviceContext& context, DenseTensor* tensor, VType val) { static_cast(std::numeric_limits::max()), static_cast(val))); - auto t = pt::EigenVector::Flatten(*tensor); + auto t = pten::EigenVector::Flatten(*tensor); t.device(*context.eigen_device()) = t.constant(static_cast(val)); } } // namespace eigen -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/common/eigen/mean.h b/paddle/pten/kernels/common/eigen/mean.h similarity index 82% rename from paddle/tcmpt/kernels/common/eigen/mean.h rename to paddle/pten/kernels/common/eigen/mean.h index 2b1ea95940727..9ee5ab12c9332 100644 --- a/paddle/tcmpt/kernels/common/eigen/mean.h +++ b/paddle/pten/kernels/common/eigen/mean.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/kernels/common/eigen/common.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/common/eigen/common.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/operators/eigen/eigen_function.h" -namespace pt { +namespace pten { namespace eigen { template @@ -30,12 +30,12 @@ void Mean(const DevCtx& dev_ctx, const DenseTensor& x, DenseTensor* out) { // TODO(chenweihang): if we design new tensor, we should support // the low-level calc functor use new tensor as input, // which may be a big project! - auto eigen_x = pt::EigenVector::Flatten(x); - auto eigen_out = pt::EigenScalar::From(*out); + auto eigen_x = pten::EigenVector::Flatten(x); + auto eigen_out = pten::EigenScalar::From(*out); auto& dev = *dev_ctx.eigen_device(); eigen_out.device(dev) = eigen_x.mean(); } } // namespace eigen -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/common/eigen/scale.h b/paddle/pten/kernels/common/eigen/scale.h similarity index 85% rename from paddle/tcmpt/kernels/common/eigen/scale.h rename to paddle/pten/kernels/common/eigen/scale.h index 0f3e92d9db787..fda15302e2971 100644 --- a/paddle/tcmpt/kernels/common/eigen/scale.h +++ b/paddle/pten/kernels/common/eigen/scale.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/kernels/common/eigen/common.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/common/eigen/common.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/operators/eigen/eigen_function.h" -namespace pt { +namespace pten { namespace eigen { template @@ -32,8 +32,8 @@ void Scale(const DevCtx& dev_ctx, DenseTensor* out) { // calc out->mutable_data(); - auto eigen_out = pt::EigenVector::Flatten(*out); - auto eigen_x = pt::EigenVector::Flatten(x); + auto eigen_out = pten::EigenVector::Flatten(*out); + auto eigen_x = pten::EigenVector::Flatten(x); auto& dev = *dev_ctx.eigen_device(); // TODO(chenweihang): now the eigen function here need the dtype of scale, // eigen_x, bias should be same, so here need cast for two scalar arg, @@ -48,4 +48,4 @@ void Scale(const DevCtx& dev_ctx, } } // namespace eigen -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/common/eigen/sign.h b/paddle/pten/kernels/common/eigen/sign.h similarity index 84% rename from paddle/tcmpt/kernels/common/eigen/sign.h rename to paddle/pten/kernels/common/eigen/sign.h index 3980976ac9cf5..1e60965b1d91b 100644 --- a/paddle/tcmpt/kernels/common/eigen/sign.h +++ b/paddle/pten/kernels/common/eigen/sign.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/kernels/common/eigen/common.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/common/eigen/common.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/operators/eigen/eigen_function.h" -namespace pt { +namespace pten { namespace eigen { template @@ -33,8 +33,8 @@ void Sign(const DevCtx& dev_ctx, const DenseTensor& x, DenseTensor* out) { // TODO(chenweihang): if we design new tensor, we should support // the low-level calc functor use new tensor as input, // which may be a big project! - auto eigen_out = pt::EigenVector::Flatten(*out); - auto eigen_x = pt::EigenVector::Flatten(x); + auto eigen_out = pten::EigenVector::Flatten(*out); + auto eigen_x = pten::EigenVector::Flatten(x); auto& dev = *dev_ctx.eigen_device(); paddle::operators::EigenSign, T>::Eval( @@ -42,4 +42,4 @@ void Sign(const DevCtx& dev_ctx, const DenseTensor& x, DenseTensor* out) { } } // namespace eigen -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt similarity index 89% rename from paddle/tcmpt/kernels/cpu/CMakeLists.txt rename to paddle/pten/kernels/cpu/CMakeLists.txt index b70c5f9ec81f0..9536f7e7d50f5 100644 --- a/paddle/tcmpt/kernels/cpu/CMakeLists.txt +++ b/paddle/pten/kernels/cpu/CMakeLists.txt @@ -1,5 +1,5 @@ if(WIN32) - set(CURRENT_BINARY_DIR ${PADDLE_BINARY_DIR}/paddle/tcmpt/kernels/cpu) + set(CURRENT_BINARY_DIR ${PADDLE_BINARY_DIR}/paddle/pten/kernels/cpu) kernel_instantiate(creation.cc) kernel_instantiate(math.cc) kernel_instantiate(linalg.cc) diff --git a/paddle/tcmpt/kernels/cpu/creation.cc b/paddle/pten/kernels/cpu/creation.cc similarity index 84% rename from paddle/tcmpt/kernels/cpu/creation.cc rename to paddle/pten/kernels/cpu/creation.cc index 37b589d776822..c150a7f5ae442 100644 --- a/paddle/tcmpt/kernels/cpu/creation.cc +++ b/paddle/pten/kernels/cpu/creation.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/kernels/cpu/creation.h" +#include "paddle/pten/kernels/cpu/creation.h" -#include "paddle/tcmpt/core/kernel_registry.h" -#include "paddle/tcmpt/kernels/common/eigen/fill.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/common/eigen/fill.h" -namespace pt { +namespace pten { template void FillAnyLike(const CPUContext& dev_ctx, @@ -27,14 +27,14 @@ void FillAnyLike(const CPUContext& dev_ctx, eigen::fill(dev_ctx, out, val.to()); } -} // namespace pt +} // namespace pten PT_REGISTER_MODULE(CreationCPU); PT_REGISTER_KERNEL("fill_any_like", CPU, Any, - pt::FillAnyLike, + pten::FillAnyLike, float, double, int, diff --git a/paddle/tcmpt/kernels/cpu/creation.h b/paddle/pten/kernels/cpu/creation.h similarity index 88% rename from paddle/tcmpt/kernels/cpu/creation.h rename to paddle/pten/kernels/cpu/creation.h index 2c67945892b82..7674e6bb05157 100644 --- a/paddle/tcmpt/kernels/cpu/creation.h +++ b/paddle/pten/kernels/cpu/creation.h @@ -14,12 +14,12 @@ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/scalar.h" #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CPUContext = paddle::platform::CPUDeviceContext; @@ -29,4 +29,4 @@ void FillAnyLike(const CPUContext& dev_ctx, const Scalar& val, DenseTensor* out); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc similarity index 92% rename from paddle/tcmpt/kernels/cpu/linalg.cc rename to paddle/pten/kernels/cpu/linalg.cc index 821cd5c092e85..5da375c99e91d 100644 --- a/paddle/tcmpt/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/kernels/cpu/linalg.h" +#include "paddle/pten/kernels/cpu/linalg.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/kernel_registry.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/complex.h" -namespace pt { +namespace pten { template void Dot(const CPUContext& dev_ctx, @@ -53,7 +53,7 @@ void matmul(const CPUContext& dev_ctx, bool transpose_y, DenseTensor* out) {} -} // namespace pt +} // namespace pten PT_REGISTER_MODULE(LinalgCPU); @@ -63,7 +63,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_KERNEL("dot", CPU, Any, - pt::Dot, + pten::Dot, float, double, int, diff --git a/paddle/tcmpt/kernels/cpu/linalg.h b/paddle/pten/kernels/cpu/linalg.h similarity index 93% rename from paddle/tcmpt/kernels/cpu/linalg.h rename to paddle/pten/kernels/cpu/linalg.h index 6d9550b2882b2..a9447be74934c 100644 --- a/paddle/tcmpt/kernels/cpu/linalg.h +++ b/paddle/pten/kernels/cpu/linalg.h @@ -14,12 +14,12 @@ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CPUContext = paddle::platform::CPUDeviceContext; @@ -37,4 +37,4 @@ void matmul(const CPUContext& dev_ctx, bool transpose_y, DenseTensor* out); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc similarity index 89% rename from paddle/tcmpt/kernels/cpu/manipulation.cc rename to paddle/pten/kernels/cpu/manipulation.cc index edf7f5aff0389..8bc3fcc14cf7e 100644 --- a/paddle/tcmpt/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/kernels/cpu/manipulation.h" -#include "paddle/tcmpt/infershape/unary.h" -#include "paddle/tcmpt/kernels/cpu/utils.h" +#include "paddle/pten/kernels/cpu/manipulation.h" +#include "paddle/pten/infershape/unary.h" +#include "paddle/pten/kernels/cpu/utils.h" -namespace pt { +namespace pten { template void Flatten(const CPUContext& dev_ctx, @@ -25,7 +25,7 @@ void Flatten(const CPUContext& dev_ctx, int stop_axis, DenseTensor* out) { auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis); - pt::Copy(dev_ctx, x, out); + pten::Copy(dev_ctx, x, out); out->mutable_meta()->lod = out_meta.lod; out->Resize(out_meta.dims); } @@ -51,7 +51,7 @@ void FlattenWithXShape(const CPUContext& dev_ctx, xshape->mutable_meta()->lod = x.meta().lod; } -} // namespace pt +} // namespace pten // TODO(chenweihang): replace by better impl PT_REGISTER_MODULE(ManipulationCPU); @@ -61,7 +61,7 @@ PT_REGISTER_MODULE(ManipulationCPU); PT_REGISTER_KERNEL("flatten_contiguous_range", CPU, Any, - pt::Flatten, + pten::Flatten, float, double, uint8_t, @@ -72,7 +72,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range", PT_REGISTER_KERNEL("flatten_contiguous_range.mid", CPU, Any, - pt::FlattenWithXShape, + pten::FlattenWithXShape, float, double, uint8_t, diff --git a/paddle/tcmpt/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h similarity index 88% rename from paddle/tcmpt/kernels/cpu/manipulation.h rename to paddle/pten/kernels/cpu/manipulation.h index 0147dca441b25..22dfb0d8fccba 100644 --- a/paddle/tcmpt/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CPUContext = paddle::platform::CPUDeviceContext; @@ -31,4 +31,4 @@ void Flatten(const CPUContext& dev_ctx, int stop_axis, DenseTensor* out); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc similarity index 85% rename from paddle/tcmpt/kernels/cpu/math.cc rename to paddle/pten/kernels/cpu/math.cc index 4fa14141209a1..4fbd7cf04bf45 100644 --- a/paddle/tcmpt/kernels/cpu/math.cc +++ b/paddle/pten/kernels/cpu/math.cc @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/kernels/cpu/math.h" +#include "paddle/pten/kernels/cpu/math.h" -#include "paddle/tcmpt/kernels/common/eigen/mean.h" -#include "paddle/tcmpt/kernels/common/eigen/scale.h" -#include "paddle/tcmpt/kernels/common/eigen/sign.h" +#include "paddle/pten/kernels/common/eigen/mean.h" +#include "paddle/pten/kernels/common/eigen/scale.h" +#include "paddle/pten/kernels/common/eigen/sign.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/bfloat16.h" -namespace pt { +namespace pten { template void Sign(const CPUContext& dev_ctx, const DenseTensor& x, DenseTensor* out) { @@ -61,7 +61,7 @@ void ScaleHost(const CPUContext& dev_ctx, out); } -} // namespace pt +} // namespace pten // TODO(chenweihang): replace by better impl PT_REGISTER_MODULE(MathCPU); @@ -69,12 +69,12 @@ PT_REGISTER_MODULE(MathCPU); // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // using bfloat16 = ::paddle::platform::bfloat16; -PT_REGISTER_KERNEL("sign", CPU, Any, pt::Sign, float, double) {} -PT_REGISTER_KERNEL("mean", CPU, Any, pt::Mean, float, double) {} +PT_REGISTER_KERNEL("sign", CPU, Any, pten::Sign, float, double) {} +PT_REGISTER_KERNEL("mean", CPU, Any, pten::Mean, float, double) {} PT_REGISTER_KERNEL("scale", CPU, Any, - pt::Scale, + pten::Scale, float, double, paddle::platform::bfloat16, @@ -86,7 +86,7 @@ PT_REGISTER_KERNEL("scale", PT_REGISTER_KERNEL("scale.host", CPU, Any, - pt::ScaleHost, + pten::ScaleHost, float, double, paddle::platform::bfloat16, @@ -95,5 +95,5 @@ PT_REGISTER_KERNEL("scale.host", int16_t, int, int64_t) { - kernel->InputAt(1).SetBackend(pt::Backend::kCPU); + kernel->InputAt(1).SetBackend(pten::Backend::kCPU); } diff --git a/paddle/tcmpt/kernels/cpu/math.h b/paddle/pten/kernels/cpu/math.h similarity index 91% rename from paddle/tcmpt/kernels/cpu/math.h rename to paddle/pten/kernels/cpu/math.h index 3fb669b084095..3013ad9d04d0b 100644 --- a/paddle/tcmpt/kernels/cpu/math.h +++ b/paddle/pten/kernels/cpu/math.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CPUContext = paddle::platform::CPUDeviceContext; @@ -46,4 +46,4 @@ void ScaleHost(const CPUContext& dev_ctx, bool bias_after_scale, DenseTensor* out); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/cpu/utils.cc b/paddle/pten/kernels/cpu/utils.cc similarity index 89% rename from paddle/tcmpt/kernels/cpu/utils.cc rename to paddle/pten/kernels/cpu/utils.cc index a50cfad481693..f79a0a34fa6fd 100644 --- a/paddle/tcmpt/kernels/cpu/utils.cc +++ b/paddle/pten/kernels/cpu/utils.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/kernels/cpu/utils.h" +#include "paddle/pten/kernels/cpu/utils.h" #include "paddle/fluid/memory/memcpy.h" -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/core/convert_utils.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/core/convert_utils.h" -namespace pt { +namespace pten { void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst) { auto* src_ptr = src.data(); @@ -50,9 +50,9 @@ void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst) { } } -} // namespace pt +} // namespace pten // TODO(chenweihang): replace by better impl PT_REGISTER_MODULE(UtilsCPU); -PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CPU, Any, pt::Copy) {} +PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CPU, Any, pten::Copy) {} diff --git a/paddle/tcmpt/kernels/cpu/utils.h b/paddle/pten/kernels/cpu/utils.h similarity index 87% rename from paddle/tcmpt/kernels/cpu/utils.h rename to paddle/pten/kernels/cpu/utils.h index 95ec606cc37d1..38f601b4cf91f 100644 --- a/paddle/tcmpt/kernels/cpu/utils.h +++ b/paddle/pten/kernels/cpu/utils.h @@ -14,15 +14,15 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CPUContext = paddle::platform::CPUDeviceContext; void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/cuda/CMakeLists.txt b/paddle/pten/kernels/cuda/CMakeLists.txt similarity index 94% rename from paddle/tcmpt/kernels/cuda/CMakeLists.txt rename to paddle/pten/kernels/cuda/CMakeLists.txt index e243bad09563b..1271d93558d5b 100644 --- a/paddle/tcmpt/kernels/cuda/CMakeLists.txt +++ b/paddle/pten/kernels/cuda/CMakeLists.txt @@ -1,5 +1,5 @@ if(WIN32) - set(CURRENT_BINARY_DIR ${PADDLE_BINARY_DIR}/paddle/tcmpt/kernels/cuda) + set(CURRENT_BINARY_DIR ${PADDLE_BINARY_DIR}/paddle/pten/kernels/cuda) kernel_instantiate(creation.cu) kernel_instantiate(math.cu) kernel_instantiate(linalg.cu) diff --git a/paddle/tcmpt/kernels/cuda/creation.cu b/paddle/pten/kernels/cuda/creation.cu similarity index 84% rename from paddle/tcmpt/kernels/cuda/creation.cu rename to paddle/pten/kernels/cuda/creation.cu index 54afec95735df..e0732269d874a 100644 --- a/paddle/tcmpt/kernels/cuda/creation.cu +++ b/paddle/pten/kernels/cuda/creation.cu @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/kernels/cuda/creation.h" +#include "paddle/pten/kernels/cuda/creation.h" -#include "paddle/tcmpt/core/kernel_registry.h" -#include "paddle/tcmpt/kernels/common/eigen/fill.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/common/eigen/fill.h" -namespace pt { +namespace pten { template void FillAnyLike(const CUDAContext& dev_ctx, @@ -27,14 +27,14 @@ void FillAnyLike(const CUDAContext& dev_ctx, eigen::fill(dev_ctx, out, val.to()); } -} // namespace pt +} // namespace pten PT_REGISTER_MODULE(CreationCUDA); PT_REGISTER_KERNEL("fill_any_like", CUDA, Any, - pt::FillAnyLike, + pten::FillAnyLike, float, double, int, diff --git a/paddle/tcmpt/kernels/cuda/creation.h b/paddle/pten/kernels/cuda/creation.h similarity index 89% rename from paddle/tcmpt/kernels/cuda/creation.h rename to paddle/pten/kernels/cuda/creation.h index 7de9ce1371fff..21772f1f98d07 100644 --- a/paddle/tcmpt/kernels/cuda/creation.h +++ b/paddle/pten/kernels/cuda/creation.h @@ -17,12 +17,12 @@ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/scalar.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/scalar.h" #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CUDAContext = paddle::platform::CUDADeviceContext; @@ -32,6 +32,6 @@ void FillAnyLike(const CUDAContext& dev_ctx, const Scalar& val, DenseTensor* out); -} // namespace pt +} // namespace pten #endif diff --git a/paddle/tcmpt/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu similarity index 86% rename from paddle/tcmpt/kernels/cuda/linalg.cu rename to paddle/pten/kernels/cuda/linalg.cu index 77001d988038d..a57f230244dbb 100644 --- a/paddle/tcmpt/kernels/cuda/linalg.cu +++ b/paddle/pten/kernels/cuda/linalg.cu @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/kernels/cuda/linalg.h" +#include "paddle/pten/kernels/cuda/linalg.h" -#include "paddle/tcmpt/core/kernel_registry.h" -#include "paddle/tcmpt/kernels/common/eigen/dot.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/common/eigen/dot.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/complex.h" -namespace pt { +namespace pten { template void Dot(const CUDAContext& dev_ctx, @@ -30,7 +30,7 @@ void Dot(const CUDAContext& dev_ctx, eigen::Dot(dev_ctx, x, y, out); } -} // namespace pt +} // namespace pten PT_REGISTER_MODULE(LinalgCUDA); @@ -40,7 +40,7 @@ using complex128 = ::paddle::platform::complex; PT_REGISTER_KERNEL("dot", CUDA, Any, - pt::Dot, + pten::Dot, float, double, int, diff --git a/paddle/tcmpt/kernels/cuda/linalg.h b/paddle/pten/kernels/cuda/linalg.h similarity index 92% rename from paddle/tcmpt/kernels/cuda/linalg.h rename to paddle/pten/kernels/cuda/linalg.h index 20fe0d1a4f49a..ad38f71ec080a 100644 --- a/paddle/tcmpt/kernels/cuda/linalg.h +++ b/paddle/pten/kernels/cuda/linalg.h @@ -17,12 +17,12 @@ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CUDAContext = paddle::platform::CUDADeviceContext; @@ -32,6 +32,6 @@ void Dot(const CUDAContext& dev_ctx, const DenseTensor& y, DenseTensor* out); -} // namespace pt +} // namespace pten #endif diff --git a/paddle/tcmpt/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu similarity index 90% rename from paddle/tcmpt/kernels/cuda/manipulation.cu rename to paddle/pten/kernels/cuda/manipulation.cu index 99ee2506fdf41..2b68d4a292017 100644 --- a/paddle/tcmpt/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/tcmpt/infershape/unary.h" -#include "paddle/tcmpt/kernels/cuda/manipulation.h" -#include "paddle/tcmpt/kernels/cuda/utils.h" +#include "paddle/pten/infershape/unary.h" +#include "paddle/pten/kernels/cuda/manipulation.h" +#include "paddle/pten/kernels/cuda/utils.h" -namespace pt { +namespace pten { template void Flatten(const CUDAContext& dev_ctx, @@ -25,7 +25,7 @@ void Flatten(const CUDAContext& dev_ctx, int stop_axis, DenseTensor* out) { auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis); - pt::Copy(dev_ctx, x, out); + pten::Copy(dev_ctx, x, out); out->mutable_meta()->lod = out_meta.lod; out->Resize(out_meta.dims); } @@ -51,7 +51,7 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, xshape->mutable_meta()->lod = x.meta().lod; } -} // namespace pt +} // namespace pten // TODO(chenweihang): replace by better impl PT_REGISTER_MODULE(ManipulationCUDA); @@ -62,7 +62,7 @@ using float16 = paddle::platform::float16; PT_REGISTER_KERNEL("flatten_contiguous_range", CUDA, Any, - pt::Flatten, + pten::Flatten, float, float16, double, @@ -74,7 +74,7 @@ PT_REGISTER_KERNEL("flatten_contiguous_range", PT_REGISTER_KERNEL("flatten_contiguous_range.mid", CUDA, Any, - pt::FlattenWithXShape, + pten::FlattenWithXShape, float, double, uint8_t, diff --git a/paddle/tcmpt/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h similarity index 93% rename from paddle/tcmpt/kernels/cuda/manipulation.h rename to paddle/pten/kernels/cuda/manipulation.h index ca958eab8fa47..ac1cb0324f4ec 100644 --- a/paddle/tcmpt/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -17,12 +17,12 @@ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CUDAContext = paddle::platform::CUDADeviceContext; @@ -33,6 +33,6 @@ void Flatten(const CUDAContext& dev_ctx, int stop_axis, DenseTensor* out); -} // namespace pt +} // namespace pten #endif diff --git a/paddle/tcmpt/kernels/cuda/math.cu b/paddle/pten/kernels/cuda/math.cu similarity index 85% rename from paddle/tcmpt/kernels/cuda/math.cu rename to paddle/pten/kernels/cuda/math.cu index 113971126a71f..8a2d1dff9a67b 100644 --- a/paddle/tcmpt/kernels/cuda/math.cu +++ b/paddle/pten/kernels/cuda/math.cu @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/kernels/cuda/math.h" +#include "paddle/pten/kernels/cuda/math.h" -#include "paddle/tcmpt/kernels/common/eigen/mean.h" -#include "paddle/tcmpt/kernels/common/eigen/scale.h" -#include "paddle/tcmpt/kernels/common/eigen/sign.h" +#include "paddle/pten/kernels/common/eigen/mean.h" +#include "paddle/pten/kernels/common/eigen/scale.h" +#include "paddle/pten/kernels/common/eigen/sign.h" #ifdef __NVCC__ #include "cub/cub.cuh" @@ -27,10 +27,10 @@ namespace cub = hipcub; #endif #include "paddle/fluid/platform/float16.h" -#include "paddle/tcmpt/core/convert_utils.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/core/kernel_registry.h" -namespace pt { +namespace pten { /** * Util Functors @@ -74,10 +74,10 @@ void Mean(const CUDAContext& dev_ctx, const DenseTensor& x, DenseTensor* out) { nullptr, temp_storage_bytes, trans_x, out_data, size_prob, stream); PADDLE_ENFORCE_CUDA_SUCCESS(err); - pt::DenseTensor tmp( + pten::DenseTensor tmp( TensorMeta(paddle::framework::make_ddim( {static_cast(temp_storage_bytes)}), - pt::TransToPtBackend(dev_ctx.GetPlace()), + pten::TransToPtBackend(dev_ctx.GetPlace()), x.data_type(), x.layout()), TensorStatus()); @@ -115,18 +115,18 @@ void ScaleHost(const CUDAContext& dev_ctx, out); } -} // namespace pt +} // namespace pten // TODO(chenweihang): replace by better impl PT_REGISTER_MODULE(MathCUDA); using float16 = paddle::platform::float16; -PT_REGISTER_KERNEL("sign", CUDA, Any, pt::Sign, float, double, float16) {} -PT_REGISTER_KERNEL("mean", CUDA, Any, pt::Mean, float, double, float16) {} +PT_REGISTER_KERNEL("sign", CUDA, Any, pten::Sign, float, double, float16) {} +PT_REGISTER_KERNEL("mean", CUDA, Any, pten::Mean, float, double, float16) {} PT_REGISTER_KERNEL("scale", CUDA, Any, - pt::Scale, + pten::Scale, float, double, float16, @@ -138,7 +138,7 @@ PT_REGISTER_KERNEL("scale", PT_REGISTER_KERNEL("scale.host", CUDA, Any, - pt::ScaleHost, + pten::ScaleHost, float, double, float16, @@ -147,5 +147,5 @@ PT_REGISTER_KERNEL("scale.host", int16_t, int, int64_t) { - kernel->InputAt(1).SetBackend(pt::Backend::kCPU); + kernel->InputAt(1).SetBackend(pten::Backend::kCPU); } diff --git a/paddle/tcmpt/kernels/cuda/math.h b/paddle/pten/kernels/cuda/math.h similarity index 94% rename from paddle/tcmpt/kernels/cuda/math.h rename to paddle/pten/kernels/cuda/math.h index dc8221d6345d6..65f4f41265836 100644 --- a/paddle/tcmpt/kernels/cuda/math.h +++ b/paddle/pten/kernels/cuda/math.h @@ -17,12 +17,12 @@ limitations under the License. */ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CUDAContext = paddle::platform::CUDADeviceContext; @@ -48,6 +48,6 @@ void ScaleHost(const CUDAContext& dev_ctx, bool bias_after_scale, DenseTensor* out); -} // namespace pt +} // namespace pten #endif diff --git a/paddle/tcmpt/kernels/cuda/utils.cu b/paddle/pten/kernels/cuda/utils.cu similarity index 97% rename from paddle/tcmpt/kernels/cuda/utils.cu rename to paddle/pten/kernels/cuda/utils.cu index 00b32e2fbb10a..0c83c1c5c3cae 100644 --- a/paddle/tcmpt/kernels/cuda/utils.cu +++ b/paddle/pten/kernels/cuda/utils.cu @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" -#include "paddle/tcmpt/common/data_type.h" -#include "paddle/tcmpt/core/convert_utils.h" -#include "paddle/tcmpt/core/kernel_registry.h" -#include "paddle/tcmpt/kernels/cuda/utils.h" +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/cuda/utils.h" -namespace pt { +namespace pten { void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, @@ -215,9 +215,9 @@ void Copy(const CUDAContext& dev_ctx, } } -} // namespace pt +} // namespace pten // TODO(chenweihang): replace by better impl PT_REGISTER_MODULE(UtilsCUDA); -PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CUDA, Any, pt::Copy) {} +PT_REGISTER_KERNEL_WITH_NO_TYPE("copy", CUDA, Any, pten::Copy) {} diff --git a/paddle/tcmpt/kernels/cuda/utils.h b/paddle/pten/kernels/cuda/utils.h similarity index 87% rename from paddle/tcmpt/kernels/cuda/utils.h rename to paddle/pten/kernels/cuda/utils.h index 4d3196b2f877b..a8a6838f4602a 100644 --- a/paddle/tcmpt/kernels/cuda/utils.h +++ b/paddle/pten/kernels/cuda/utils.h @@ -14,15 +14,15 @@ limitations under the License. */ #pragma once -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" -namespace pt { +namespace pten { using CUDAContext = paddle::platform::CUDADeviceContext; void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, DenseTensor* dst); -} // namespace pt +} // namespace pten diff --git a/paddle/tcmpt/kernels/mkldnn/CMakeLists.txt b/paddle/pten/kernels/mkldnn/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/kernels/mkldnn/CMakeLists.txt rename to paddle/pten/kernels/mkldnn/CMakeLists.txt diff --git a/paddle/tcmpt/kernels/npu/CMakeLists.txt b/paddle/pten/kernels/npu/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/kernels/npu/CMakeLists.txt rename to paddle/pten/kernels/npu/CMakeLists.txt diff --git a/paddle/tcmpt/kernels/xpu/CMakeLists.txt b/paddle/pten/kernels/xpu/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/kernels/xpu/CMakeLists.txt rename to paddle/pten/kernels/xpu/CMakeLists.txt diff --git a/paddle/tcmpt/module/CMakeLists.txt b/paddle/pten/module/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/module/CMakeLists.txt rename to paddle/pten/module/CMakeLists.txt diff --git a/paddle/tcmpt/tests/CMakeLists.txt b/paddle/pten/tests/CMakeLists.txt similarity index 100% rename from paddle/tcmpt/tests/CMakeLists.txt rename to paddle/pten/tests/CMakeLists.txt diff --git a/paddle/tcmpt/tests/backend_test.cc b/paddle/pten/tests/backend_test.cc similarity index 94% rename from paddle/tcmpt/tests/backend_test.cc rename to paddle/pten/tests/backend_test.cc index 026e94ec4d0e7..46e099e216c41 100644 --- a/paddle/tcmpt/tests/backend_test.cc +++ b/paddle/pten/tests/backend_test.cc @@ -12,6 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/backend.h" +#include "paddle/pten/core/backend.h" #include diff --git a/paddle/tcmpt/tests/dense_tensor_test.cc b/paddle/pten/tests/dense_tensor_test.cc similarity index 62% rename from paddle/tcmpt/tests/dense_tensor_test.cc rename to paddle/pten/tests/dense_tensor_test.cc index 138ef1e30e76e..db747e15a8db7 100644 --- a/paddle/tcmpt/tests/dense_tensor_test.cc +++ b/paddle/pten/tests/dense_tensor_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" #include @@ -20,16 +20,17 @@ namespace framework = paddle::framework; using DDim = paddle::framework::DDim; TEST(DenseTensor, Constructor) { - pt::DenseTensor tensor(pt::TensorMeta(framework::make_ddim({5, 10}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW, - 0UL), - pt::TensorStatus()); + pten::DenseTensor tensor( + pten::TensorMeta(framework::make_ddim({5, 10}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW, + 0UL), + pten::TensorStatus()); ASSERT_EQ(tensor.dims().size(), 2); - ASSERT_EQ(tensor.backend(), pt::Backend::kCPU); - ASSERT_EQ(tensor.data_type(), pt::DataType::kFLOAT32); - ASSERT_EQ(tensor.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(tensor.backend(), pten::Backend::kCPU); + ASSERT_EQ(tensor.data_type(), paddle::experimental::DataType::kFLOAT32); + ASSERT_EQ(tensor.layout(), paddle::experimental::DataLayout::kNCHW); } TEST(DenseTensor, Dims) { diff --git a/paddle/tcmpt/tests/dtype_test.cc b/paddle/pten/tests/dtype_test.cc similarity index 100% rename from paddle/tcmpt/tests/dtype_test.cc rename to paddle/pten/tests/dtype_test.cc diff --git a/paddle/tcmpt/tests/kernel_factory_test.cc b/paddle/pten/tests/kernel_factory_test.cc similarity index 75% rename from paddle/tcmpt/tests/kernel_factory_test.cc rename to paddle/pten/tests/kernel_factory_test.cc index 66ce7cd9892ef..a3ac561d6364a 100644 --- a/paddle/tcmpt/tests/kernel_factory_test.cc +++ b/paddle/pten/tests/kernel_factory_test.cc @@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/tcmpt/core/kernel_factory.h" +#include "paddle/pten/core/kernel_factory.h" #include "gtest/gtest.h" TEST(KernelFactory, KernelKey) { - pt::KernelKey key( - pt::Backend::kCPU, pt::DataLayout::kNCHW, pt::DataType::kFLOAT32); + pten::KernelKey key(pten::Backend::kCPU, + paddle::experimental::DataLayout::kNCHW, + paddle::experimental::DataType::kFLOAT32); std::cout << key; } diff --git a/paddle/tcmpt/tests/layout_test.cc b/paddle/pten/tests/layout_test.cc similarity index 100% rename from paddle/tcmpt/tests/layout_test.cc rename to paddle/pten/tests/layout_test.cc diff --git a/paddle/tcmpt/tests/test_copy_api.cc b/paddle/pten/tests/test_copy_api.cc similarity index 64% rename from paddle/tcmpt/tests/test_copy_api.cc rename to paddle/pten/tests/test_copy_api.cc index 2d70e37d051d9..3307ffeb1943b 100644 --- a/paddle/tcmpt/tests/test_copy_api.cc +++ b/paddle/pten/tests/test_copy_api.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/core/kernel_registry.h" -#include "paddle/tcmpt/kernels/cpu/utils.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/cpu/utils.h" -#include "paddle/tcmpt/core/dense_tensor.h" +#include "paddle/pten/core/dense_tensor.h" PT_DECLARE_MODULE(UtilsCPU); @@ -30,20 +30,20 @@ using DDim = paddle::framework::DDim; // 'paddle/api', TEST(API, copy) { // 1. create tensor - auto dense_src = std::make_shared( - pt::TensorMeta(framework::make_ddim({2, 3}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_src = std::make_shared( + pten::TensorMeta(framework::make_ddim({2, 3}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_src->mutable_data(); - auto dense_dst = std::make_shared( - pt::TensorMeta(framework::make_ddim({2, 3}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_dst = std::make_shared( + pten::TensorMeta(framework::make_ddim({2, 3}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); for (size_t i = 0; i < 2; ++i) { for (size_t j = 0; j < 3; ++j) { @@ -55,7 +55,7 @@ TEST(API, copy) { // 2. test API auto& pool = paddle::platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.GetByPlace(paddle::platform::CPUPlace()); - pt::Copy(*dev_ctx, *(dense_src.get()), dense_dst.get()); + pten::Copy(*dev_ctx, *(dense_src.get()), dense_dst.get()); // 3. check result for (int64_t i = 0; i < dense_src->numel(); i++) { diff --git a/paddle/tcmpt/tests/test_dot_api.cc b/paddle/pten/tests/test_dot_api.cc similarity index 67% rename from paddle/tcmpt/tests/test_dot_api.cc rename to paddle/pten/tests/test_dot_api.cc index 8fdae5050e239..967f1a8f17c1c 100644 --- a/paddle/tcmpt/tests/test_dot_api.cc +++ b/paddle/pten/tests/test_dot_api.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/hapi/include/linalg.h" +#include "paddle/pten/hapi/include/linalg.h" -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" PT_DECLARE_MODULE(LinalgCPU); @@ -31,20 +31,20 @@ using DDim = paddle::framework::DDim; TEST(API, dot) { // 1. create tensor - auto dense_x = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 10}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_x = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 10}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_x->mutable_data(); - auto dense_y = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 10}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_y = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 10}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_y_data = dense_y->mutable_data(); float sum[3] = {0.0, 0.0, 0.0}; @@ -67,12 +67,12 @@ TEST(API, dot) { ASSERT_EQ(out.shape()[0], 3); ASSERT_EQ(out.numel(), 3); ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), pt::DataType::kFLOAT32); - ASSERT_EQ(out.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(out.type(), paddle::experimental::DataType::kFLOAT32); + ASSERT_EQ(out.layout(), paddle::experimental::DataLayout::kNCHW); ASSERT_EQ(out.initialized(), true); auto expect_result = sum; - auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto dense_out = std::dynamic_pointer_cast(out.impl()); auto actual_result0 = dense_out->data()[0]; auto actual_result1 = dense_out->data()[1]; auto actual_result2 = dense_out->data()[2]; diff --git a/paddle/tcmpt/tests/test_fill_api.cc b/paddle/pten/tests/test_fill_api.cc similarity index 54% rename from paddle/tcmpt/tests/test_fill_api.cc rename to paddle/pten/tests/test_fill_api.cc index 0ed7248604654..5c044f520af07 100644 --- a/paddle/tcmpt/tests/test_fill_api.cc +++ b/paddle/pten/tests/test_fill_api.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/hapi/include/creation.h" +#include "paddle/pten/hapi/include/creation.h" -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" PT_DECLARE_MODULE(CreationCPU); @@ -31,12 +31,12 @@ using DDim = paddle::framework::DDim; TEST(API, full_like) { // 1. create tensor - auto dense_x = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 2}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_x = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 2}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_x->mutable_data(); dense_x_data[0] = 0; @@ -45,18 +45,19 @@ TEST(API, full_like) { paddle::experimental::Tensor x(dense_x); // 2. test API - auto out = paddle::experimental::full_like(x, val, pt::DataType::kFLOAT32); + auto out = paddle::experimental::full_like( + x, val, paddle::experimental::DataType::kFLOAT32); // 3. check result ASSERT_EQ(out.shape().size(), 2); ASSERT_EQ(out.shape()[0], 3); ASSERT_EQ(out.numel(), 6); ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), pt::DataType::kFLOAT32); - ASSERT_EQ(out.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(out.type(), paddle::experimental::DataType::kFLOAT32); + ASSERT_EQ(out.layout(), paddle::experimental::DataLayout::kNCHW); ASSERT_EQ(out.initialized(), true); - auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto dense_out = std::dynamic_pointer_cast(out.impl()); auto* actual_result = dense_out->data(); for (auto i = 0; i < 6; i++) { ASSERT_NEAR(actual_result[i], val, 1e-6f); @@ -65,30 +66,31 @@ TEST(API, full_like) { TEST(API, zeros_like) { // 1. create tensor - auto dense_x = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 2}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_x = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 2}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_x->mutable_data(); dense_x_data[0] = 1; paddle::experimental::Tensor x(dense_x); // 2. test API - auto out = paddle::experimental::zeros_like(x, pt::DataType::kFLOAT32); + auto out = paddle::experimental::zeros_like( + x, paddle::experimental::DataType::kFLOAT32); // 3. check result ASSERT_EQ(out.shape().size(), 2); ASSERT_EQ(out.shape()[0], 3); ASSERT_EQ(out.numel(), 6); ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), pt::DataType::kFLOAT32); - ASSERT_EQ(out.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(out.type(), paddle::experimental::DataType::kFLOAT32); + ASSERT_EQ(out.layout(), paddle::experimental::DataLayout::kNCHW); ASSERT_EQ(out.initialized(), true); - auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto dense_out = std::dynamic_pointer_cast(out.impl()); auto* actual_result = dense_out->data(); for (auto i = 0; i < 6; i++) { ASSERT_NEAR(actual_result[i], 0, 1e-6f); @@ -97,30 +99,31 @@ TEST(API, zeros_like) { TEST(API, ones_like) { // 1. create tensor - auto dense_x = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 2}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_x = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 2}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_x->mutable_data(); dense_x_data[0] = 0; paddle::experimental::Tensor x(dense_x); // 2. test API - auto out = paddle::experimental::ones_like(x, pt::DataType::kINT32); + auto out = paddle::experimental::ones_like( + x, paddle::experimental::DataType::kINT32); // 3. check result ASSERT_EQ(out.shape().size(), 2); ASSERT_EQ(out.shape()[0], 3); ASSERT_EQ(out.numel(), 6); ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), pt::DataType::kINT32); - ASSERT_EQ(out.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(out.type(), paddle::experimental::DataType::kINT32); + ASSERT_EQ(out.layout(), paddle::experimental::DataLayout::kNCHW); ASSERT_EQ(out.initialized(), true); - auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto dense_out = std::dynamic_pointer_cast(out.impl()); auto* actual_result = dense_out->data(); for (auto i = 0; i < 6; i++) { ASSERT_EQ(actual_result[i], 1); diff --git a/paddle/tcmpt/tests/test_flatten_api.cc b/paddle/pten/tests/test_flatten_api.cc similarity index 72% rename from paddle/tcmpt/tests/test_flatten_api.cc rename to paddle/pten/tests/test_flatten_api.cc index d2e3ee4278e1d..1deb41f3a6722 100644 --- a/paddle/tcmpt/tests/test_flatten_api.cc +++ b/paddle/pten/tests/test_flatten_api.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/hapi/include/manipulation.h" +#include "paddle/pten/hapi/include/manipulation.h" -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" PT_DECLARE_MODULE(ManipulationCPU); @@ -31,12 +31,12 @@ using DDim = paddle::framework::DDim; TEST(API, flatten) { // 1. create tensor - auto dense_x = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 2, 2, 3}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_x = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 2, 2, 3}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_x->mutable_data(); for (int i = 0; i < dense_x->numel(); i++) { @@ -55,11 +55,11 @@ TEST(API, flatten) { ASSERT_EQ(out.shape()[2], expect_shape[2]); ASSERT_EQ(out.numel(), 36); ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), pt::DataType::kFLOAT32); - ASSERT_EQ(out.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(out.type(), paddle::experimental::DataType::kFLOAT32); + ASSERT_EQ(out.layout(), paddle::experimental::DataLayout::kNCHW); ASSERT_EQ(out.initialized(), true); bool value_equal = true; - auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto dense_out = std::dynamic_pointer_cast(out.impl()); auto* dense_out_data = dense_out->data(); for (int i = 0; i < dense_x->numel(); i++) { if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f) diff --git a/paddle/tcmpt/tests/test_mean_api.cc b/paddle/pten/tests/test_mean_api.cc similarity index 69% rename from paddle/tcmpt/tests/test_mean_api.cc rename to paddle/pten/tests/test_mean_api.cc index 518a98738961c..fbcd375d51328 100644 --- a/paddle/tcmpt/tests/test_mean_api.cc +++ b/paddle/pten/tests/test_mean_api.cc @@ -15,10 +15,10 @@ limitations under the License. */ #include #include -#include "paddle/tcmpt/hapi/include/math.h" +#include "paddle/pten/hapi/include/math.h" -#include "paddle/tcmpt/core/dense_tensor.h" -#include "paddle/tcmpt/core/kernel_registry.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" PT_DECLARE_MODULE(MathCPU); @@ -31,12 +31,12 @@ using DDim = paddle::framework::DDim; TEST(API, mean) { // 1. create tensor - auto dense_x = std::make_shared( - pt::TensorMeta(framework::make_ddim({3, 4}), - pt::Backend::kCPU, - pt::DataType::kFLOAT32, - pt::DataLayout::kNCHW), - pt::TensorStatus()); + auto dense_x = std::make_shared( + pten::TensorMeta(framework::make_ddim({3, 4}), + pten::Backend::kCPU, + paddle::experimental::DataType::kFLOAT32, + paddle::experimental::DataLayout::kNCHW), + pten::TensorStatus()); auto* dense_x_data = dense_x->mutable_data(); float sum = 0.0; @@ -55,12 +55,12 @@ TEST(API, mean) { ASSERT_EQ(out.shape()[0], 1); ASSERT_EQ(out.numel(), 1); ASSERT_EQ(out.is_cpu(), true); - ASSERT_EQ(out.type(), pt::DataType::kFLOAT32); - ASSERT_EQ(out.layout(), pt::DataLayout::kNCHW); + ASSERT_EQ(out.type(), paddle::experimental::DataType::kFLOAT32); + ASSERT_EQ(out.layout(), paddle::experimental::DataLayout::kNCHW); ASSERT_EQ(out.initialized(), true); auto expect_result = sum / 12; - auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto dense_out = std::dynamic_pointer_cast(out.impl()); auto actual_result = dense_out->data()[0]; ASSERT_NEAR(expect_result, actual_result, 1e-6f); } diff --git a/paddle/tcmpt/CMakeLists.txt b/paddle/tcmpt/CMakeLists.txt deleted file mode 100644 index 0187a63c2ff6d..0000000000000 --- a/paddle/tcmpt/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -include(tcmpt) -# tcmpt api -add_subdirectory(api) -# tcmpt high level api -add_subdirectory(hapi) -# tcmpt core components -add_subdirectory(core) -# tcmpt kernels for diff device -add_subdirectory(kernels) -# tcmpt infershape -add_subdirectory(infershape) -# TODO(xingfeng): tcmpt inner module API designed by a high-performance team -add_subdirectory(module) -# tcmpt tests -add_subdirectory(tests) diff --git a/paddle/tcmpt/api/CMakeLists.txt b/paddle/tcmpt/api/CMakeLists.txt deleted file mode 100644 index bf4d163a62bfc..0000000000000 --- a/paddle/tcmpt/api/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -# set(declare_file ${PADDLE_BINARY_DIR}/paddle/tcmpt/api/symbols.h.tmp CACHE INTERNAL "symbols.h file") -# set(declare_file_final ${PADDLE_BINARY_DIR}/paddle/tcmpt/api/symbols.h) -# file(WRITE ${declare_file} "// Generated by the paddle/tcmpt/api/CMakeLists.txt. DO NOT EDIT!\n\n") - -# function(declare_module TARGTE) -# file(APPEND ${declare_file} "extern int RegisterSymbolsFor${TARGET}();\n") -# message(STATUS "") -# endfunction() - -# TODO(chenweihang): unify decclare into **_library -# declare_module(MathCPU) -# declare_module(MathCUDA) - -set(TCMPT_DEPS convert_utils dense_tensor kernel_factory kernel_context) -set(TCMPT_DEPS ${TCMPT_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu) -set(TCMPT_DEPS ${TCMPT_DEPS} unary binary) -if(WITH_GPU OR WITH_ROCM) - set(TCMPT_DEPS ${TCMPT_DEPS} math_cuda linalg_cuda creation_cuda manipulation_cuda) -endif() - -cc_library(tcmpt SRCS all.cc DEPS ${TCMPT_DEPS}) diff --git a/paddle/tcmpt/hapi/CMakeLists.txt b/paddle/tcmpt/hapi/CMakeLists.txt deleted file mode 100644 index ebc247ef8a2e2..0000000000000 --- a/paddle/tcmpt/hapi/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_subdirectory(lib) - -cc_library(tcmpt_hapi SRCS all.cc DEPS math_api linalg_api creation_api) diff --git a/paddle/tcmpt/hapi/lib/CMakeLists.txt b/paddle/tcmpt/hapi/lib/CMakeLists.txt deleted file mode 100644 index 74467603c62b6..0000000000000 --- a/paddle/tcmpt/hapi/lib/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -cc_library(math_api SRCS math.cc DEPS tcmpt) -cc_library(linalg_api SRCS linalg.cc DEPS tcmpt) -cc_library(creation_api SRCS creation.cc DEPS tcmpt) -cc_library(manipulation_api SRCS manipulation.cc DEPS tcmpt)