From 854a7ab3589704499a8332b9967011c4457fd507 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Fri, 21 Jan 2022 18:46:24 +0800 Subject: [PATCH 01/14] add pten dependency to infrt (#39079) * add pten dependency to infrt * fix code style * add pten::CPUContext * revert .ignore --- paddle/infrt/CMakeLists.txt | 7 ++++-- paddle/infrt/host_context/value.h | 5 ++++ paddle/infrt/kernel/CMakeLists.txt | 1 + paddle/infrt/kernel/pten_kernels.cc | 37 +++++++++++++++++++++++++++++ paddle/infrt/kernel/pten_kernels.h | 35 +++++++++++++++++++++++++++ paddle/scripts/infrt_build.sh | 0 6 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 paddle/infrt/kernel/pten_kernels.cc create mode 100644 paddle/infrt/kernel/pten_kernels.h mode change 100644 => 100755 paddle/scripts/infrt_build.sh diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index 8af3012a220ad..e371e2391829d 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -1,3 +1,6 @@ +#TO DO:remove fluid +include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform) + if (NOT WITH_INFRT) return() endif() @@ -88,8 +91,8 @@ set(infrt_mlir_incs ) message(STATUS "infrt srcs:\n${infrt_src}") -cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto) -cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto) +cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto pten dense_tensor) +cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} paddle_framework_proto pten dense_tensor) add_dependencies(infrt ${infrt_mlir_incs}) add_custom_target(test_infrt_exec DEPENDS ${INFRT_TEST_TARGETS}) diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h index 4a2b92a7e69c5..7f68e59f8a698 100644 --- a/paddle/infrt/host_context/value.h +++ b/paddle/infrt/host_context/value.h @@ -29,6 +29,9 @@ #include "paddle/infrt/tensor/tensor_map.h" 
#include "paddle/infrt/tensor/tensor_shape.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/dense_tensor.h" + namespace infrt { namespace host_context { @@ -45,6 +48,8 @@ using ValueVariantType = Variant, std::vector, std::vector, diff --git a/paddle/infrt/kernel/CMakeLists.txt b/paddle/infrt/kernel/CMakeLists.txt index da858aad28f81..7e9ed8e5572c0 100644 --- a/paddle/infrt/kernel/CMakeLists.txt +++ b/paddle/infrt/kernel/CMakeLists.txt @@ -2,6 +2,7 @@ core_gather_headers() gather_srcs(infrt_src SRCS basic_kernels.cc + pten_kernels.cc test_kernels.cc tensor_shape_kernels.cc tensor_kernels.cc diff --git a/paddle/infrt/kernel/pten_kernels.cc b/paddle/infrt/kernel/pten_kernels.cc new file mode 100644 index 0000000000000..70c44b829f774 --- /dev/null +++ b/paddle/infrt/kernel/pten_kernels.cc @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" +#include "paddle/infrt/kernel/pten_kernels.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/kernels/math_kernel.h" + +using infrt::host_context::Attribute; + +namespace infrt { +namespace kernel { + +void RegisterPtenKernels(host_context::KernelRegistry* registry) { + registry->AddKernel("pd_cpu.add.float32", + INFRT_KERNEL(pten::AddKernel)); + registry->AddKernel("pd_cpu.add.int32", + INFRT_KERNEL(pten::AddKernel)); +} + +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/pten_kernels.h b/paddle/infrt/kernel/pten_kernels.h new file mode 100644 index 0000000000000..c290f8ea524fb --- /dev/null +++ b/paddle/infrt/kernel/pten_kernels.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace infrt { +namespace host_context { + +struct KernelRegistry; + +} // namespace host_context +} // namespace infrt + +namespace infrt { +namespace kernel { + +/** + * Register all the pten kernels to registry. 
+ */ +void RegisterPtenKernels(host_context::KernelRegistry* registry); + +} // namespace kernel +} // namespace infrt diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh old mode 100644 new mode 100755 From a0f586bc626b3fddcc104e46e521e37bc7e4e302 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 21 Jan 2022 20:03:11 +0800 Subject: [PATCH 02/14] [PTen]Separate origin Kernel and add Kernel for C++ API (#39002) * add kernel for c++ api * fix compile bugs * fix kunlun compile bugs * perfect cmake * fix compile bugs when run ci-inference * fix compile bugs * add non-raw kernel for fluid op * fix compile bugs * fix compile bugs * fix unit test bug --- cmake/pten_kernel.cmake | 61 +++-- paddle/fluid/operators/cholesky_solve_op.h | 2 +- .../elementwise/elementwise_add_op.h | 2 +- .../elementwise/elementwise_div_op.h | 2 +- .../elementwise/elementwise_mul_op.cu | 4 +- .../elementwise/elementwise_mul_op.h | 2 +- .../operators/elementwise/elementwise_op.h | 24 +- .../elementwise/elementwise_sub_op.h | 2 +- paddle/fluid/operators/lu_op.h | 4 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 13 +- paddle/pten/api/include/kernel_signature.h | 6 - paddle/pten/core/kernel_alias_name.h | 12 +- paddle/pten/kernels/cpu/math_kernel.cc | 76 +++---- paddle/pten/kernels/gpu/math_kernel.cu | 77 ++++--- paddle/pten/kernels/math_kernel.cc | 212 ++++++++++++++++++ paddle/pten/kernels/math_kernel.h | 125 ++++++----- .../tests/kernels/test_elementwise_dev_api.cc | 12 +- python/paddle/utils/code_gen/api.yaml | 7 +- 18 files changed, 453 insertions(+), 190 deletions(-) create mode 100644 paddle/pten/kernels/math_kernel.cc diff --git a/cmake/pten_kernel.cmake b/cmake/pten_kernel.cmake index bc9fefb58f452..c2928376a02f8 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten_kernel.cmake @@ -103,38 +103,55 @@ function(kernel_library TARGET) list(LENGTH gpu_srcs gpu_srcs_len) list(LENGTH xpu_srcs xpu_srcs_len) - if (${common_srcs_len} GREATER 0) - # If the kernel has a 
device independent public implementation, - # we will use this implementation and will not adopt the implementation - # under specific devices + # Build Target according different src organization + if((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR + ${xpu_srcs_len} GREATER 0) AND ${common_srcs_len} GREATER 0) + # If the common_srcs depends on specific device srcs, build target using this rule. + if (WITH_GPU) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) + endif() + elseif (WITH_ROCM) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET}_part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) + endif() + else() + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET}_part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${TARGET}_part) + endif() + endif() + elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) if (WITH_GPU) - nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() elseif (WITH_ROCM) - hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) + hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() else() - cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) + cc_library(${TARGET} 
SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) + endif() endif() else() - # If the kernel has a header file declaration, but no corresponding - # implementation can be found, this is not allowed - if (${cpu_srcs_len} EQUAL 0 AND ${gpu_srcs_len} EQUAL 0 AND - ${xpu_srcs_len} EQUAL 0) - message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") + if (${common_srcs_len} EQUAL 0) + message(FATAL_ERROR "Cannot find any implementation for ${TARGET}") else() + # If the kernel has a device independent public implementation, + # we will use this implementation and will not adopt the implementation + # under specific devices if (WITH_GPU) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) elseif (WITH_ROCM) - if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0) - hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) else() - if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0) - cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) - endif() + cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps}) endif() - endif() + endif() endif() if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h index 4b1d075de91ca..5004aad7c59bc 100644 --- a/paddle/fluid/operators/cholesky_solve_op.h +++ b/paddle/fluid/operators/cholesky_solve_op.h @@ -202,7 +202,7 @@ class CholeskySolveGradKernel : public framework::OpKernel { commonterm_for_range(commonterm_functor); commonterm_conj = helper.Transpose(commonterm_conj); - 
pten::AddKernel( + pten::AddRawKernel( static_cast::TYPE &>(dev_ctx), commonterm, commonterm_conj, -1, &commonterm); diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index a4897a06d5611..5c4f791b2270c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -61,7 +61,7 @@ class ElementwiseAddKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::AddKernel( + pten::AddRawKernel( static_cast::TYPE &>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 44f695278dca8..a45f09b63e9fe 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -51,7 +51,7 @@ class ElementwiseDivKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::DivideKernel( + pten::DivideRawKernel( static_cast::TYPE&>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 86a803106347d..0c7d12ae0ad55 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -51,8 +51,8 @@ class ElementwiseMulKernel auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod); - 
pten::MultiplyKernel(cuda_ctx, *pt_x.get(), *pt_y.get(), axis, - pt_z.get()); + pten::MultiplyRawKernel(cuda_ctx, *pt_x.get(), *pt_y.get(), axis, + pt_z.get()); } else { PADDLE_THROW(platform::errors::InvalidArgument( "X's type[%s] is not supported by elementwise_op. X's type should be " diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index d918407930d96..e7a5e48b1f1b5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -124,7 +124,7 @@ class ElementwiseMulKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod); - pten::MultiplyKernel( + pten::MultiplyRawKernel( static_cast::TYPE&>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index e1d9655e293a3..aaf33ca674488 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -140,26 +140,42 @@ class ElementwiseOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { + int axis = ctx.Attr("axis"); if (Type() == "elementwise_add") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("add", {"X", "Y"}, {"axis"}, {"Out"}); + if (axis == -1) { + return framework::KernelSignature("add", {"X", "Y"}, {}, {"Out"}); + } + return framework::KernelSignature("add_raw", {"X", "Y"}, {"axis"}, + {"Out"}); } } if (Type() == "elementwise_sub") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("subtract", {"X", "Y"}, {"axis"}, + if (axis == -1) { + return 
framework::KernelSignature("subtract", {"X", "Y"}, {}, + {"Out"}); + } + return framework::KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"}); } } if (Type() == "elementwise_div") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("divide", {"X", "Y"}, {"axis"}, + if (axis == -1) { + return framework::KernelSignature("divide", {"X", "Y"}, {}, {"Out"}); + } + return framework::KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"}); } } if (Type() == "elementwise_mul") { if (ctx.InputVar("X")->IsType()) { - return framework::KernelSignature("multiply", {"X", "Y"}, {"axis"}, + if (axis == -1) { + return framework::KernelSignature("multiply", {"X", "Y"}, {}, + {"Out"}); + } + return framework::KernelSignature("multiply_raw", {"X", "Y"}, {"axis"}, {"Out"}); } } diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 46d4a93e804f5..7d1749f20abf2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -51,7 +51,7 @@ class ElementwiseSubKernel : public framework::OpKernel { auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - pten::SubtractKernel( + pten::SubtractRawKernel( static_cast::TYPE&>(dev_ctx), *pt_x.get(), *pt_y.get(), axis, pt_z.get()); diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h index 6beef1add8e4c..c3b3552ba1329 100644 --- a/paddle/fluid/operators/lu_op.h +++ b/paddle/fluid/operators/lu_op.h @@ -221,7 +221,7 @@ void Tensor_Add(const DeviceContext& dev_ctx, const framework::Tensor& src1, out->Resize(src1.dims()); out->mutable_data(dev_ctx.GetPlace()); - pten::AddKernel< + pten::AddRawKernel< T, typename paddle::framework::ConvertToPtenContext::TYPE>( static_cast::TYPE&>(dev_ctx), @@ -234,7 +234,7 @@ void 
Tensor_Sub(const DeviceContext& dev_ctx, const framework::Tensor& src1, out->Resize(src1.dims()); out->mutable_data(dev_ctx.GetPlace()); - pten::SubtractKernel< + pten::SubtractRawKernel< T, typename paddle::framework::ConvertToPtenContext::TYPE>( static_cast::TYPE&>(dev_ctx), diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index e2002856a4d08..2e5bd7a42b1d1 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -551,17 +551,26 @@ class ReduceOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext& ctx) const override { + bool reduce_all = ctx.Attr("reduce_all"); if (Type() == "reduce_sum") { if (ctx.InputVar("X")->IsType()) { + if (!reduce_all) { + return framework::KernelSignature( + "sum", {"X"}, {"dim", "keep_dim", "out_dtype"}, {"Out"}); + } return framework::KernelSignature( - "sum", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"}, + "sum_raw", {"X"}, {"dim", "keep_dim", "reduce_all", "out_dtype"}, {"Out"}); } } if (Type() == "reduce_mean") { if (ctx.InputVar("X")->IsType()) { + if (!reduce_all) { + return framework::KernelSignature("mean", {"X"}, {"dim", "keep_dim"}, + {"Out"}); + } return framework::KernelSignature( - "mean", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + "mean_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); } } // TODO(chentianyu03): support other cases after selected rows added diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index e3929d59159c1..d750b47ef864b 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -30,7 +30,6 @@ using DeviceContext = paddle::platform::DeviceContext; using add_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using cast_kernel = void 
(*)(const DeviceContext&, @@ -46,7 +45,6 @@ using concat_kernel = void (*)(const DeviceContext&, using divide_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using dot_kernel = void (*)(const DeviceContext&, @@ -82,13 +80,11 @@ using mean_kernel = void (*)(const DeviceContext&, const DenseTensor&, const std::vector&, bool, - bool, DenseTensor*); using multiply_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using reshape_kernel = void (*)(const DeviceContext&, @@ -107,14 +103,12 @@ using sum_kernel = void (*)(const DeviceContext&, const DenseTensor&, const std::vector&, bool, - bool, DataType, DenseTensor*); using subtract_kernel = void (*)(const DeviceContext&, const DenseTensor&, const DenseTensor&, - int, DenseTensor*); using conj_kernel = void (*)(const DeviceContext&, diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 5c86787966368..8e089970f9139 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -20,10 +20,10 @@ namespace pten { // the key is kernel_name in fluid, the value is the kernel_name in pten // the key is sorted by key's alphabet const std::unordered_map kernel_alias_name_map = { - {"elementwise_add", "add"}, - {"elementwise_div", "divide"}, - {"elementwise_mul", "muliply"}, - {"elementwise_sub", "subtract"}, + {"elementwise_add", "add_raw"}, + {"elementwise_div", "divide_raw"}, + {"elementwise_mul", "muliply_raw"}, + {"elementwise_sub", "subtract_raw"}, {"fill_any_like", "full_like"}, {"fill_constant", "full"}, {"flatten_contiguous_range", "flatten"}, @@ -32,8 +32,8 @@ const std::unordered_map kernel_alias_name_map = { {"matmul_v2_grad", "matmul_grad"}, {"matmul_v2_grad_grad", "matmul_double_grad"}, {"matmul_v2_triple_grad", "matmul_triple_grad"}, - {"reduce_mean", "mean"}, - {"reduce_sum", "sum"}, + {"reduce_mean", "mean_raw"}, + {"reduce_sum", 
"sum_raw"}, {"reshape2", "reshape"}, {"reshape2_grad", "reshape_grad"}, {"reshape2_grad_grad", "reshape_double_grad"}, diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc index 7841dd4113cff..706a40936a393 100644 --- a/paddle/pten/kernels/cpu/math_kernel.cc +++ b/paddle/pten/kernels/cpu/math_kernel.cc @@ -32,11 +32,11 @@ namespace pten { #define DEFINE_CPU_ELEMENTWISE_OP(name) \ template \ - void name##Kernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& y, \ - int axis, \ - DenseTensor* out) { \ + void name##RawKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + int axis, \ + DenseTensor* out) { \ out->mutable_data(); \ if (x.dims() == y.dims()) { \ SameDimsElementwiseCompute>()( \ @@ -55,23 +55,35 @@ namespace pten { } template -void MeanKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void MeanRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { auto out_dtype = x.dtype(); pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } template -void DivideKernel(const Context& dev_ctx, +void SumRawKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType out_dtype, DenseTensor* out) { + pten::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + +template +void DivideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { // allocate memory for out out->mutable_data(); if (x.dims() == y.dims() && std::is_floating_point::value) { @@ -90,18 +102,6 @@ void DivideKernel(const Context& dev_ctx, } } -template -void SumKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, 
- bool keep_dim, - bool reduce_all, - DataType out_dtype, - DenseTensor* out) { - pten::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); -} - // Create the definition of Add DEFINE_CPU_ELEMENTWISE_OP(Add) @@ -118,42 +118,40 @@ using complex128 = ::paddle::platform::complex; // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16 // using bfloat16 = ::paddle::platform::bfloat16; -PT_REGISTER_KERNEL( - mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} -PT_REGISTER_KERNEL(add, +PT_REGISTER_KERNEL(add_raw, CPU, ALL_LAYOUT, - pten::AddKernel, + pten::AddRawKernel, float, double, int, int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL(subtract, +PT_REGISTER_KERNEL(subtract_raw, CPU, ALL_LAYOUT, - pten::SubtractKernel, + pten::SubtractRawKernel, float, double, int, int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL(divide, +PT_REGISTER_KERNEL(divide_raw, CPU, ALL_LAYOUT, - pten::DivideKernel, + pten::DivideRawKernel, float, double, int, int64_t, complex64, complex128) {} -PT_REGISTER_KERNEL(multiply, +PT_REGISTER_KERNEL(multiply_raw, CPU, ALL_LAYOUT, - pten::MultiplyKernel, + pten::MultiplyRawKernel, float, double, int, @@ -161,10 +159,10 @@ PT_REGISTER_KERNEL(multiply, bool, complex64, complex128) {} -PT_REGISTER_KERNEL(sum, +PT_REGISTER_KERNEL(sum_raw, CPU, ALL_LAYOUT, - pten::SumKernel, + pten::SumRawKernel, bool, float, double, @@ -175,3 +173,5 @@ PT_REGISTER_KERNEL(sum, complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } +PT_REGISTER_KERNEL( + mean_raw, CPU, ALL_LAYOUT, pten::MeanRawKernel, float, double, bool) {} diff --git a/paddle/pten/kernels/gpu/math_kernel.cu b/paddle/pten/kernels/gpu/math_kernel.cu index d7a16ac49b1c9..6b6383f81065b 100644 --- a/paddle/pten/kernels/gpu/math_kernel.cu +++ b/paddle/pten/kernels/gpu/math_kernel.cu @@ -37,11 +37,11 @@ namespace pten { #define DEFINE_CUDA_ELEMENTWISE_OP(name) \ template \ - void name##Kernel(const Context& dev_ctx, \ 
- const DenseTensor& x, \ - const DenseTensor& y, \ - int axis, \ - DenseTensor* out) { \ + void name##RawKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + int axis, \ + DenseTensor* out) { \ std::vector inputs; \ std::vector outputs; \ inputs.emplace_back(&x); \ @@ -57,17 +57,29 @@ namespace pten { */ template -void MeanKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void MeanRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { auto out_dtype = x.dtype(); pten::Reduce( dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); } +template +void SumRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType out_dtype, + DenseTensor* out) { + pten::Reduce( + dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); +} + // Create the definition of Add DEFINE_CUDA_ELEMENTWISE_OP(Add) // Create the definition of Subtract @@ -77,30 +89,16 @@ DEFINE_CUDA_ELEMENTWISE_OP(Multiply) // Create the definition of Divide DEFINE_CUDA_ELEMENTWISE_OP(Divide) -template -void SumKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DataType out_dtype, - DenseTensor* out) { - pten::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); -} - } // namespace pten using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; -PT_REGISTER_KERNEL( - mean, GPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool, float16) {} -PT_REGISTER_KERNEL(add, +PT_REGISTER_KERNEL(add_raw, GPU, ALL_LAYOUT, - pten::AddKernel, + pten::AddRawKernel, float, double, int, @@ -108,10 +106,10 @@ PT_REGISTER_KERNEL(add, float16, complex64, complex128) {} 
-PT_REGISTER_KERNEL(subtract, +PT_REGISTER_KERNEL(subtract_raw, GPU, ALL_LAYOUT, - pten::SubtractKernel, + pten::SubtractRawKernel, float, double, int, @@ -119,10 +117,10 @@ PT_REGISTER_KERNEL(subtract, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(divide, +PT_REGISTER_KERNEL(divide_raw, GPU, ALL_LAYOUT, - pten::DivideKernel, + pten::DivideRawKernel, float, double, int, @@ -130,10 +128,10 @@ PT_REGISTER_KERNEL(divide, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(multiply, +PT_REGISTER_KERNEL(multiply_raw, GPU, ALL_LAYOUT, - pten::MultiplyKernel, + pten::MultiplyRawKernel, float, double, int, @@ -142,10 +140,10 @@ PT_REGISTER_KERNEL(multiply, float16, complex64, complex128) {} -PT_REGISTER_KERNEL(sum, +PT_REGISTER_KERNEL(sum_raw, GPU, ALL_LAYOUT, - pten::SumKernel, + pten::SumRawKernel, bool, float, double, @@ -156,3 +154,12 @@ PT_REGISTER_KERNEL(sum, complex128) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } + +PT_REGISTER_KERNEL(mean_raw, + GPU, + ALL_LAYOUT, + pten::MeanRawKernel, + float, + double, + bool, + float16) {} diff --git a/paddle/pten/kernels/math_kernel.cc b/paddle/pten/kernels/math_kernel.cc new file mode 100644 index 0000000000000..423282ab97ca4 --- /dev/null +++ b/paddle/pten/kernels/math_kernel.cc @@ -0,0 +1,212 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/pten/kernels/math_kernel.h" + +#include "paddle/pten/backends/all_context.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { + +template +void MeanKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out) { + bool reduce_all = false; + MeanRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); +} + +template +void SumKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DataType out_dtype, + DenseTensor* out) { + bool reduce_all = false; + SumRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out); +} + +template +void AddKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + AddRawKernel(dev_ctx, x, y, axis, out); +} + +template +void SubtractKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + SubtractRawKernel(dev_ctx, x, y, axis, out); +} + +template +void DivideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + DivideRawKernel(dev_ctx, x, y, axis, out); +} + +template +void MultiplyKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + MultiplyRawKernel(dev_ctx, x, y, axis, out); +} + +} // namespace pten + +using complex64 = ::paddle::platform::complex; +using complex128 = ::paddle::platform::complex; + +PT_REGISTER_KERNEL( + mean, CPU, ALL_LAYOUT, pten::MeanKernel, float, double, bool) {} + +PT_REGISTER_KERNEL(sum, + CPU, + ALL_LAYOUT, + pten::SumKernel, + bool, + float, + double, + paddle::platform::float16, + int, + int64_t, + complex64, + complex128) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} + +PT_REGISTER_KERNEL(add, + CPU, + ALL_LAYOUT, + pten::AddKernel, + float, + double, + int, + int64_t, + 
complex64, + complex128) {} +PT_REGISTER_KERNEL(subtract, + CPU, + ALL_LAYOUT, + pten::SubtractKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(divide, + CPU, + ALL_LAYOUT, + pten::DivideKernel, + float, + double, + int, + int64_t, + complex64, + complex128) {} +PT_REGISTER_KERNEL(multiply, + CPU, + ALL_LAYOUT, + pten::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + complex64, + complex128) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_KERNEL(mean, + GPU, + ALL_LAYOUT, + pten::MeanKernel, + float, + double, + bool, + paddle::platform::float16) {} +PT_REGISTER_KERNEL(sum, + GPU, + ALL_LAYOUT, + pten::SumKernel, + bool, + float, + double, + paddle::platform::float16, + int, + int64_t, + complex64, + complex128) { + kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); +} +PT_REGISTER_KERNEL(add, + GPU, + ALL_LAYOUT, + pten::AddKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(subtract, + GPU, + ALL_LAYOUT, + pten::SubtractKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(divide, + GPU, + ALL_LAYOUT, + pten::DivideKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + complex64, + complex128) {} +PT_REGISTER_KERNEL(multiply, + GPU, + ALL_LAYOUT, + pten::MultiplyKernel, + float, + double, + int, + int64_t, + bool, + paddle::platform::float16, + complex64, + complex128) {} +#endif diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h index 65c0f84e696de..95379baaf3504 100644 --- a/paddle/pten/kernels/math_kernel.h +++ b/paddle/pten/kernels/math_kernel.h @@ -22,104 +22,127 @@ limitations under the License. 
*/ namespace pten { +template +void MeanRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out); + template void MeanKernel(const Context& dev_ctx, const DenseTensor& x, const std::vector& dims, bool keep_dim, - bool reduce_all, DenseTensor* out); +template +void SumRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + bool reduce_all, + DataType out_dtype, + DenseTensor* out); + +template +void SumKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DataType out_dtype, + DenseTensor* out); + +template +void AddRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void AddKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); +template +void SubtractRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void SubtractKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); +template +void DivideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void DivideKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); +template +void MultiplyRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + template void MultiplyKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - int axis, DenseTensor* out); -template -void SumKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - bool reduce_all, - DataType out_dtype, - DenseTensor* out); - template DenseTensor Add(const Context& dev_ctx, const DenseTensor& x, - 
const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - AddKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + AddKernel(dev_ctx, x, y, &dense_out); return dense_out; } template DenseTensor Subtract(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - SubtractKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + SubtractKernel(dev_ctx, x, y, &dense_out); return dense_out; } template DenseTensor Divide(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - DivideKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); + DivideKernel(dev_ctx, x, y, &dense_out); return dense_out; } template DenseTensor Multiply(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& y, - int axis) { - auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), axis); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - MultiplyKernel(dev_ctx, x, y, axis, &dense_out); + const DenseTensor& y) { + auto out_meta = ElementwiseInferMeta(x.meta(), y.meta(), -1); + auto dense_out = 
pten::Empty(dev_ctx, std::move(out_meta)); + MultiplyKernel(dev_ctx, x, y, &dense_out); return dense_out; } @@ -130,8 +153,7 @@ DenseTensor Mean(const Context& dev_ctx, bool keep_dim) { auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim); auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); - bool reduce_all = false; - MeanKernel(dev_ctx, x, axis, keep_dim, reduce_all, &dense_out); + MeanKernel(dev_ctx, x, axis, keep_dim, &dense_out); return dense_out; } @@ -144,12 +166,7 @@ DenseTensor Sum(const Context& dev_ctx, auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim, dtype); auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); - // The real value of reduce_all will be get in kernel - // so use default value(false) is OK. - bool reduce_all = false; - - SumKernel( - dev_ctx, x, axis, keep_dim, reduce_all, out_meta.dtype, &dense_out); + SumKernel(dev_ctx, x, axis, keep_dim, dtype, &dense_out); return dense_out; } diff --git a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc index 0bc16371c0731..e5d9b05eec7b3 100644 --- a/paddle/pten/tests/kernels/test_elementwise_dev_api.cc +++ b/paddle/pten/tests/kernels/test_elementwise_dev_api.cc @@ -54,11 +54,10 @@ TEST(DEV_API, add) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Add(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Add(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -101,11 +100,10 @@ TEST(DEV_API, subtract) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Subtract(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Subtract(dev_ctx, dense_x, dense_y); // 3. 
check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -148,11 +146,10 @@ TEST(DEV_API, divide) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0 + 1; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Divide(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Divide(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); @@ -195,11 +192,10 @@ TEST(DEV_API, multiply) { for (size_t i = 0; i < 10; ++i) { dense_y_data[i] = i * 2.0; } - int axis = 1; // 2. test API pten::CPUContext dev_ctx; - auto dense_out = pten::Multiply(dev_ctx, dense_x, dense_y, axis); + auto dense_out = pten::Multiply(dev_ctx, dense_x, dense_y); // 3. check result ASSERT_EQ(dense_out.dims().size(), 2); diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 1bf5344e83746..a0d7ce84f75fd 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -6,7 +6,6 @@ param : [x, y, -1] kernel : func : add - param : [x, y, -1] - api : cast args : (const Tensor& x, DataType out_dtype) @@ -44,7 +43,6 @@ param : [x, y, -1] kernel : func : divide - param : [x, y, -1] - api : dot args : (const Tensor& x, const Tensor& y) @@ -130,7 +128,6 @@ param: [x, axis, keep_dim] kernel : func : mean - param : [x, axis, keep_dim, false] - api : multiply args : (const Tensor& x, const Tensor& y) @@ -140,7 +137,6 @@ param : [x, y, -1] kernel : func : multiply - param : [x, y, -1] - api : ones_like args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED) @@ -172,7 +168,6 @@ param : [x, y, -1] kernel : func : subtract - param : [x, y, -1] - api : sum args : (const Tensor& x, const std::vector& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false) @@ -182,7 +177,7 @@ param: [x, axis, keep_dim, dtype] kernel : func : sum - param : [x, axis, keep_dim, false, DataType::UNDEFINED] + 
param : [x, axis, keep_dim, dtype] data_type : x - api : zeros_like From a14dc68820dbb221831b13b8c43155f537e265e9 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Fri, 21 Jan 2022 20:56:04 +0800 Subject: [PATCH 03/14] [pten] fix test concat dev api build failed (#39117) * fix test concat dev api build failed * fix conflict * fix conflict --- paddle/fluid/operators/concat_op.h | 5 ++++- paddle/pten/kernels/cpu/concat_kernel.cc | 2 +- paddle/pten/kernels/gpu/concat_kernel.cu | 2 +- paddle/pten/tests/api/test_concat_api.cc | 6 ++++-- paddle/pten/tests/kernels/test_concat_dev_api.cc | 16 +++++++--------- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 3eaffbdc8bf35..1d9c10bdb8cc6 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -80,7 +80,10 @@ class ConcatKernel : public framework::OpKernel { pt_ins.push_back(*in); } - pten::ConcatKernel(dev_ctx, pt_ins, axis, out); + pten::ConcatKernel( + static_cast::TYPE&>(dev_ctx), + pt_ins, axis, out); } }; diff --git a/paddle/pten/kernels/cpu/concat_kernel.cc b/paddle/pten/kernels/cpu/concat_kernel.cc index fb59c9c6005ff..c4aed7679bd72 100644 --- a/paddle/pten/kernels/cpu/concat_kernel.cc +++ b/paddle/pten/kernels/cpu/concat_kernel.cc @@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx, pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); out->Resize(out_dims); - out->mutable_data(); + out->mutable_data(dev_ctx.GetPlace()); // If axis is 0, the lod of the output is not the same as inputs. 
if (axis == 0 && x[0].lod().size() > 0) { diff --git a/paddle/pten/kernels/gpu/concat_kernel.cu b/paddle/pten/kernels/gpu/concat_kernel.cu index 6ddfef460fc6c..e52e3a3d6446c 100644 --- a/paddle/pten/kernels/gpu/concat_kernel.cu +++ b/paddle/pten/kernels/gpu/concat_kernel.cu @@ -43,7 +43,7 @@ void ConcatKernel(const Context& dev_ctx, pten::DDim out_dims = pten::funcs::ComputeAndCheckShape(true, x_dims, axis); out->Resize(out_dims); - out->mutable_data(); + out->mutable_data(dev_ctx.GetPlace()); // If axis is 0, the lod of the output is not the same as inputs. if (axis == 0 && x[0].lod().size() > 0) { diff --git a/paddle/pten/tests/api/test_concat_api.cc b/paddle/pten/tests/api/test_concat_api.cc index e84aee0aaaf4f..c003e89f6c009 100644 --- a/paddle/pten/tests/api/test_concat_api.cc +++ b/paddle/pten/tests/api/test_concat_api.cc @@ -37,14 +37,16 @@ TEST(API, concat) { pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_x_data = dense_x->mutable_data(); + auto* dense_x_data = + dense_x->mutable_data(paddle::platform::CPUPlace()); auto dense_y = std::make_shared( alloc.get(), pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_y_data = dense_y->mutable_data(); + auto* dense_y_data = + dense_y->mutable_data(paddle::platform::CPUPlace()); for (size_t i = 0; i < 3; ++i) { for (size_t j = 0; j < 10; ++j) { diff --git a/paddle/pten/tests/kernels/test_concat_dev_api.cc b/paddle/pten/tests/kernels/test_concat_dev_api.cc index c5d979ad908ff..6f9ea1b0d990a 100644 --- a/paddle/pten/tests/kernels/test_concat_dev_api.cc +++ b/paddle/pten/tests/kernels/test_concat_dev_api.cc @@ -25,7 +25,7 @@ namespace pten { namespace tests { namespace framework = paddle::framework; -using DDim = paddle::framework::DDim; +using DDim = pten::framework::DDim; TEST(DEV_API, concat) { // 1. 
create tensor @@ -35,13 +35,15 @@ TEST(DEV_API, concat) { pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_x_data = dense_x.mutable_data(); + auto* dense_x_data = + dense_x.mutable_data(paddle::platform::CPUPlace()); pten::DenseTensor dense_y(alloc.get(), pten::DenseTensorMeta(pten::DataType::FLOAT32, framework::make_ddim({3, 10}), pten::DataLayout::NCHW)); - auto* dense_y_data = dense_y.mutable_data(); + auto* dense_y_data = + dense_y.mutable_data(paddle::platform::CPUPlace()); for (size_t i = 0; i < 3; ++i) { for (size_t j = 0; j < 10; ++j) { @@ -50,15 +52,11 @@ TEST(DEV_API, concat) { } } - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - std::vector inputs = {dense_x, dense_y}; // 2. test API - auto out = pten::Concat( - *(static_cast(dev_ctx)), inputs, 0); + pten::CPUContext dev_ctx; + auto out = pten::Concat(dev_ctx, inputs, 0); // 3. 
check result ASSERT_EQ(out.dims().size(), 2); From e92b304032eb7e9a46d18f38b3bc52ff00ee4701 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 12:02:40 +0800 Subject: [PATCH 04/14] [PTen] Auto generate include headers (#39123) * auto gen include headers * move to pten.cmake --- .gitignore | 2 ++ cmake/{pten_kernel.cmake => pten.cmake} | 42 +++++++++++++++++++++++++ paddle/pten/CMakeLists.txt | 12 +++++++ paddle/pten/kernels/CMakeLists.txt | 2 -- 4 files changed, 56 insertions(+), 2 deletions(-) rename cmake/{pten_kernel.cmake => pten.cmake} (82%) diff --git a/.gitignore b/.gitignore index 6be36bf8c243e..708126b3bb070 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ paddle/fluid/API_PR.spec paddle/fluid/op_use_default_grad_maker_DEV.spec paddle/fluid/op_use_default_grad_maker_PR.spec paddle/pten/api/*/api* +paddle/pten/include/* +paddle/pten/extension.h *.DS_Store *.vs diff --git a/cmake/pten_kernel.cmake b/cmake/pten.cmake similarity index 82% rename from cmake/pten_kernel.cmake rename to cmake/pten.cmake index c2928376a02f8..70d61027da872 100644 --- a/cmake/pten_kernel.cmake +++ b/cmake/pten.cmake @@ -12,6 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+function(generate_unify_header DIR_NAME) + set(options "") + set(oneValueArgs HEADER_NAME SKIP_SUFFIX) + set(multiValueArgs "") + cmake_parse_arguments(generate_unify_header "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + # get header name and suffix + set(header_name "${DIR_NAME}") + list(LENGTH generate_unify_header_HEADER_NAME generate_unify_header_HEADER_NAME_len) + if(${generate_unify_header_HEADER_NAME_len} GREATER 0) + set(header_name "${generate_unify_header_HEADER_NAME}") + endif() + set(skip_suffix "") + list(LENGTH generate_unify_header_SKIP_SUFFIX generate_unify_header_SKIP_SUFFIX_len) + if(${generate_unify_header_SKIP_SUFFIX_len} GREATER 0) + set(skip_suffix "${generate_unify_header_SKIP_SUFFIX}") + endif() + + # generate target header file + set(header_file ${CMAKE_CURRENT_SOURCE_DIR}/include/${header_name}.h) + file(WRITE ${header_file} "// Header file generated by paddle/pten/CMakeLists.txt for external users,\n// DO NOT edit or include it within paddle.\n\n#pragma once\n\n") + + # get all top-level headers and write into header file + file(GLOB HEADERS "${CMAKE_CURRENT_SOURCE_DIR}\/${DIR_NAME}\/*.h") + foreach(header ${HEADERS}) + if("${skip_suffix}" STREQUAL "") + string(REPLACE "${PADDLE_SOURCE_DIR}\/" "" header "${header}") + file(APPEND ${header_file} "#include \"${header}\"\n") + else() + string(FIND "${header}" "${skip_suffix}.h" skip_suffix_found) + if(${skip_suffix_found} EQUAL -1) + string(REPLACE "${PADDLE_SOURCE_DIR}\/" "" header "${header}") + file(APPEND ${header_file} "#include \"${header}\"\n") + endif() + endif() + endforeach() + # append header into extension.h + string(REPLACE "${PADDLE_SOURCE_DIR}\/" "" header_file "${header_file}") + file(APPEND ${pten_extension_header_file} "#include \"${header_file}\"\n") +endfunction() + # call kernel_declare need to make sure whether the target of input exists function(kernel_declare TARGET_LIST) foreach(kernel_path ${TARGET_LIST}) diff --git a/paddle/pten/CMakeLists.txt 
b/paddle/pten/CMakeLists.txt index cde5e719e316d..671ed28313af9 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -1,3 +1,6 @@ +# pten auto cmake utils +include(pten) + # paddle experimental common components add_subdirectory(common) @@ -25,3 +28,12 @@ message(STATUS "All standard pten kernels: ${pten_kernels}") set(PTEN_DEPS ${PTEN_DEPS} ${pten_kernels}) cc_library(pten DEPS ${PTEN_DEPS}) + +set(pten_extension_header_file ${CMAKE_CURRENT_SOURCE_DIR}/extension.h CACHE INTERNAL "pten/extension.h file") +file(WRITE ${pten_extension_header_file} "// Header file generated by paddle/pten/CMakeLists.txt for external users,\n// DO NOT edit or include it within paddle.\n\n#pragma once\n\n") + +# generate inner headers include dir for users +generate_unify_header(backends) +generate_unify_header(core) +generate_unify_header(infermeta) +generate_unify_header(kernels SKIP_SUFFIX grad_kernel) diff --git a/paddle/pten/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt index 76e112808892d..f838c4d424c2d 100644 --- a/paddle/pten/kernels/CMakeLists.txt +++ b/paddle/pten/kernels/CMakeLists.txt @@ -1,5 +1,3 @@ -include(pten_kernel) - set(kernel_declare_file ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h.tmp CACHE INTERNAL "declarations.h file") set(kernel_declare_file_final ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h) file(WRITE ${kernel_declare_file} "// Generated by the paddle/pten/kernels/CMakeLists.txt. 
DO NOT EDIT!\n\n#pragma once\n\n") From 09f6f17c066b4dc2082a0402418116ff2ab5c96f Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 12:55:06 +0800 Subject: [PATCH 05/14] add meta tensor for unify infershape (#39131) --- paddle/pten/core/CMakeLists.txt | 3 ++ paddle/pten/core/kernel_registry.h | 31 ++--------- paddle/pten/core/macros.h | 56 +++++++++++++++++++ paddle/pten/core/meta_tensor.cc | 86 ++++++++++++++++++++++++++++++ paddle/pten/core/meta_tensor.h | 54 +++++++++++++++++++ 5 files changed, 203 insertions(+), 27 deletions(-) create mode 100644 paddle/pten/core/macros.h create mode 100644 paddle/pten/core/meta_tensor.cc create mode 100644 paddle/pten/core/meta_tensor.h diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index d89b3c9fefb59..7c8ace2bc7ef4 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -16,6 +16,9 @@ cc_library(lod_utils SRCS lod_utils.cc DEPS enforce mixed_vector) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) + +cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) + cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index 194ab52d25688..e1160ea6b7d5d 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -24,6 +24,7 @@ #include "paddle/pten/core/kernel_def.h" #include "paddle/pten/core/kernel_factory.h" #include "paddle/pten/core/kernel_utils.h" +#include "paddle/pten/core/macros.h" #include "paddle/fluid/platform/enforce.h" @@ -158,33 +159,6 @@ struct KernelRegistrar { } }; -#define PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ - _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) 
- -#define _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ - struct __test_global_namespace_##uniq_name##__ {}; \ - static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ - __test_global_namespace_##uniq_name##__>::value, \ - msg) - -#ifdef __COUNTER__ -#define PT_ID __COUNTER__ -#else -#define PT_ID __LINE__ -#endif - -#if defined(_WIN32) -#define UNUSED -#define __builtin_expect(EXP, C) (EXP) -#else -#define UNUSED __attribute__((unused)) -#endif - -#define PT_CONCATENATE(arg1, arg2) PT_CONCATENATE1(arg1, arg2) -#define PT_CONCATENATE1(arg1, arg2) PT_CONCATENATE2(arg1, arg2) -#define PT_CONCATENATE2(arg1, arg2) arg1##arg2 -#define PT_EXPAND(x) x - /** * Reference: * @@ -834,6 +808,9 @@ struct KernelRegistrar { * to avoid being removed by linker */ #define PT_DECLARE_KERNEL(kernel_name, backend, layout) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_declare_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_DECLARE_KERNEL must be called in global namespace."); \ extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \ UNUSED static int \ __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \ diff --git a/paddle/pten/core/macros.h b/paddle/pten/core/macros.h new file mode 100644 index 0000000000000..fec67b1a3dc25 --- /dev/null +++ b/paddle/pten/core/macros.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +namespace pten { + +// Disable the copy and assignment operator for a class. +#ifndef DISABLE_COPY_AND_ASSIGN +#define DISABLE_COPY_AND_ASSIGN(classname) \ + private: \ + classname(const classname&) = delete; \ + classname(classname&&) = delete; \ + classname& operator=(const classname&) = delete; \ + classname& operator=(classname&&) = delete +#endif + +#define PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) + +#define _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ + struct __test_global_namespace_##uniq_name##__ {}; \ + static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ + __test_global_namespace_##uniq_name##__>::value, \ + msg) + +#ifdef __COUNTER__ +#define PT_ID __COUNTER__ +#else +#define PT_ID __LINE__ +#endif + +#if defined(_WIN32) +#define UNUSED +#define __builtin_expect(EXP, C) (EXP) +#else +#define UNUSED __attribute__((unused)) +#endif + +#define PT_CONCATENATE(arg1, arg2) PT_CONCATENATE1(arg1, arg2) +#define PT_CONCATENATE1(arg1, arg2) PT_CONCATENATE2(arg1, arg2) +#define PT_CONCATENATE2(arg1, arg2) arg1##arg2 +#define PT_EXPAND(x) x + +} // namespace pten diff --git a/paddle/pten/core/meta_tensor.cc b/paddle/pten/core/meta_tensor.cc new file mode 100644 index 0000000000000..f52d771b73bb9 --- /dev/null +++ b/paddle/pten/core/meta_tensor.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/pten/core/meta_tensor.h" + +#include "paddle/pten/core/compat_utils.h" +#include "paddle/pten/core/dense_tensor.h" + +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +int64_t MetaTensor::numel() const { return tensor_->numel(); } + +DDim MetaTensor::dims() const { return tensor_->dims(); } + +DataType MetaTensor::dtype() const { return tensor_->dtype(); } + +DataLayout MetaTensor::layout() const { return tensor_->layout(); } + +void MetaTensor::set_dims(const DDim& dims) { + if (pten::DenseTensor::classof(tensor_)) { + CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor_)) + ->dims = dims; + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported setting dims for `%s`.", tensor_->type_info().name())); + } +} + +void MetaTensor::set_dtype(DataType dtype) { + if (pten::DenseTensor::classof(tensor_)) { + CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor_)) + ->dtype = dtype; + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported settting dtype for `%s`.", tensor_->type_info().name())); + } +} + +void MetaTensor::set_layout(DataLayout layout) { + if (pten::DenseTensor::classof(tensor_)) { + CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor_)) + ->layout = layout; + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported settting layout for `%s`.", tensor_->type_info().name())); + } +} + +void MetaTensor::share_lod(const MetaTensor& meta_tensor) { + if (pten::DenseTensor::classof(tensor_)) { + CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor_)) + ->lod = meta_tensor.lod(); + } else { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported share lod inplace for `%s`.", + tensor_->type_info().name())); + } +} + +const LoD& MetaTensor::lod() const { + if (pten::DenseTensor::classof(tensor_)) { + return static_cast(tensor_)->lod(); + } else { + 
PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported setting dims for `%s`.", tensor_->type_info().name())); + } +} + +} // namespace pten diff --git a/paddle/pten/core/meta_tensor.h b/paddle/pten/core/meta_tensor.h new file mode 100644 index 0000000000000..4273aa6f85b4e --- /dev/null +++ b/paddle/pten/core/meta_tensor.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/common/layout.h" +#include "paddle/pten/core/macros.h" +#include "paddle/pten/core/tensor_base.h" +#include "paddle/pten/core/tensor_meta.h" + +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/framework/ddim.h" + +namespace pten { + +class MetaTensor { + public: + explicit MetaTensor(TensorBase* tensor) : tensor_(tensor) {} + + MetaTensor() = default; + MetaTensor(const MetaTensor&) = default; + MetaTensor(MetaTensor&&) = default; + MetaTensor& operator=(const MetaTensor&) = delete; + MetaTensor& operator=(MetaTensor&&) = delete; + + virtual ~MetaTensor() = default; + + virtual int64_t numel() const; + virtual DDim dims() const; + virtual DataType dtype() const; + virtual DataLayout layout() const; + virtual void set_dims(const DDim& dims); + virtual void set_dtype(DataType dtype); + virtual void set_layout(DataLayout layout); + virtual void share_lod(const MetaTensor& meta_tensor); + + private: + const LoD& lod() const; + TensorBase* tensor_; +}; + +} // namespace pten From ff7f9d064083b1b06e1c1781129763815d10c0b4 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sat, 22 Jan 2022 15:45:44 +0800 Subject: [PATCH 06/14] [Move selected_rows PR #2] Added Selected_Rows and rw_lock to Pten (#39087) * Renamed selected_rows.* -> selected_rows_utils.* * Added selected_rows and rw_lock to pten * Removed useless header * Renamed the unit test target to fix CI * Use pten::framework::DDim * Set selected_rows_test properties timeout * Polish code to pten style Co-authored-by: Chen Weihang --- paddle/pten/core/CMakeLists.txt | 3 +- paddle/pten/core/selected_rows.cc | 208 +++++++++++++++++++ paddle/pten/core/selected_rows.h | 164 +++++++++++++++ paddle/pten/core/utils/rw_lock.h | 105 ++++++++++ paddle/pten/tests/core/CMakeLists.txt | 7 + paddle/pten/tests/core/test_rw_lock.cc | 83 ++++++++ paddle/pten/tests/core/test_selected_rows.cc | 187 +++++++++++++++++ 7 files changed, 756 insertions(+), 1 deletion(-) create mode 100644 paddle/pten/core/selected_rows.cc create mode 100644 paddle/pten/core/selected_rows.h create mode 100644 paddle/pten/core/utils/rw_lock.h create mode 100644 paddle/pten/tests/core/test_rw_lock.cc create mode 100644 
paddle/pten/tests/core/test_selected_rows.cc diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index 7c8ace2bc7ef4..f6f0e1f3e26ec 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -16,7 +16,6 @@ cc_library(lod_utils SRCS lod_utils.cc DEPS enforce mixed_vector) cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base) cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) - cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) @@ -28,6 +27,8 @@ elseif(WITH_ROCM) hip_test(dim_test SRCS dim_test.cu DEPS ddim) endif() +cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim) + # Will remove once we implemented MKLDNN_Tensor if(WITH_MKLDNN) add_dependencies(dense_tensor mkldnn) diff --git a/paddle/pten/core/selected_rows.cc b/paddle/pten/core/selected_rows.cc new file mode 100644 index 0000000000000..6f64602bdcf4d --- /dev/null +++ b/paddle/pten/core/selected_rows.cc @@ -0,0 +1,208 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/selected_rows.h" + +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/framework/data_type.h" + +namespace pten { + +struct ReAllocateVisitor { + ReAllocateVisitor(const pten::framework::DDim& dims, + pten::DenseTensor* tensor) + : dims_(dims), tensor_(tensor) {} + + template + void operator()() const { + pten::DenseTensor cpu_tensor; + paddle::platform::CPUPlace cpu; + T* ptr = cpu_tensor.mutable_data(dims_, cpu); + const T* old_ptr = + tensor_->memory_size() == 0 ? nullptr : tensor_->data(); + if (old_ptr != nullptr) { + std::copy(old_ptr, old_ptr + tensor_->numel(), ptr); + } + tensor_->ShareDataWith(cpu_tensor); + } + + pten::framework::DDim dims_; + pten::DenseTensor* tensor_; +}; + +struct TensorCopyVisitor { + TensorCopyVisitor(pten::DenseTensor* dst, + int64_t dst_offset, + const pten::DenseTensor src, + int64_t src_offset, + int64_t size) + : dst_(dst), + dst_offset_(dst_offset), + src_(src), + src_offset_(src_offset), + size_(size) {} + + template + void apply() const { + // TODO(Yancey1989): support other place + paddle::platform::CPUPlace cpu; + paddle::memory::Copy(cpu, + dst_->mutable_data(cpu) + dst_offset_, + cpu, + src_.data() + src_offset_, + size_ * sizeof(T)); + } + + pten::DenseTensor* dst_; + int64_t dst_offset_; + pten::DenseTensor src_; + int64_t src_offset_; + int64_t size_; +}; + +struct TensorFillVisitor { + TensorFillVisitor(pten::DenseTensor* dst, + int64_t dst_offset, + int64_t size, + float value) + : dst_(dst), dst_offset_(dst_offset), size_(size) {} + + template + void apply() const { + // TODO(qiao): support other place + paddle::platform::CPUPlace cpu; + auto* tensor_data = dst_->mutable_data(cpu); + auto* start = tensor_data + dst_offset_; + auto* end = start + size_; + std::fill(start, end, static_cast(0.0)); + } + + pten::DenseTensor* dst_; + int64_t dst_offset_; + int64_t size_; +}; + +bool SelectedRows::HasKey(int64_t key) const { + return std::find(rows_.begin(), rows_.end(), key) == rows_.end() ? 
false + : true; +} + +int64_t SelectedRows::AutoGrownIndex(int64_t key, + bool auto_grown, + bool is_test) { + if (is_test) { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + PADDLE_ENFORCE_EQ(auto_grown, + true, + paddle::platform::errors::NotFound( + "Input key(%lld) is not found.", key)); + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Row map size(%zu) should be equal to rows size(%zu).", + map_size, + vector_size)); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + int row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Selected rows is full, then length exceed the length of first " + "dimension (%d).", + row_num)); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } +} + +void SelectedRows::SyncIndex() { + rwlock_->WRLock(); + id_to_index_.clear(); + for (size_t i = 0; i < rows_.size(); ++i) { + id_to_index_[rows_[i]] = i; + } + rwlock_->UNLock(); +} + +void SelectedRows::Get(const pten::DenseTensor& ids, + pten::DenseTensor* value, + bool auto_grown, + bool is_test) { + PADDLE_ENFORCE_EQ(value->IsInitialized(), + true, + paddle::platform::errors::InvalidArgument( + "The value tensor is not initialized.")); + if (ids.numel() == 0) { + VLOG(3) << "keys is empty, please check 
data!"; + } else { + int64_t value_width = value_->numel() / value_->dims()[0]; + PADDLE_ENFORCE_EQ( + value_width, + value->numel() / value->dims()[0], + paddle::platform::errors::InvalidArgument( + "Output tensor should have the same shape with table " + "except the first dimmension, excepted value width not counting " + "the first dimension is %d, actual value width is %d.", + value_width, + value->numel() / value->dims()[0])); + for (int i = 0; i < ids.numel(); ++i) { + auto id = ids.data()[i]; + int64_t index = AutoGrownIndex(id, auto_grown, is_test); + if (index < 0) { + VLOG(5) << "id " << id << " not in the table, return 0"; + paddle::framework::VisitDataType( + value_->type(), + TensorFillVisitor(value, i * value_width, value_width, 0.0)); + } else { + paddle::framework::VisitDataType(value_->type(), + TensorCopyVisitor(value, + i * value_width, + *value_.get(), + index * value_width, + value_width)); + } + } + } +} +} // namespace pten diff --git a/paddle/pten/core/selected_rows.h b/paddle/pten/core/selected_rows.h new file mode 100644 index 0000000000000..f5be0a906dbdb --- /dev/null +++ b/paddle/pten/core/selected_rows.h @@ -0,0 +1,164 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include + +#include "paddle/pten/common/place.h" +#include "paddle/pten/core/ddim.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/utils/rw_lock.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/framework/mixed_vector.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/enforce.h" + +namespace pten { +class SelectedRows { + /* + * @brief We can use the SelectedRows structure to reproduce a sparse table. + * A sparse table is a key-value structure that the key is an `int64_t`, + * and the value is a Tensor which the first dimension is 0. + * You can use the following interface to operate the sparse table, and you + * can find + * some detail information from the comments of each interface: + * + * HasKey(key), whether the sparse table has the specified key. + * Set(key, value), set a key-value pair into the sparse table. + * Get(keys, value*), get value by given key list and apply it to the given + * value pointer + * with the specified offset. 
+ * + */ + public: + SelectedRows(const std::vector& rows, const int64_t& height) + : rows_(rows), height_(height) { + value_.reset(new pten::DenseTensor()); + rwlock_.reset(new RWLock); + } + + SelectedRows() { + height_ = 0; + value_.reset(new pten::DenseTensor()); + rwlock_.reset(new RWLock); + } + + const pten::Place& place() const { return value_->place(); } + + const pten::DenseTensor& value() const { return *value_; } + + pten::DenseTensor* mutable_value() { return value_.get(); } + + int64_t height() const { return height_; } + + void set_height(int64_t height) { height_ = height; } + + const paddle::framework::Vector& rows() const { return rows_; } + + paddle::framework::Vector* mutable_rows() { return &rows_; } + + void set_rows(const paddle::framework::Vector& rows) { + rows_ = rows; + } + + /* + * @brief Get the index of key in rows + * + * @return -1 if the key does not exists. + */ + int64_t Index(int64_t key) const { + auto it = std::find(rows_.begin(), rows_.end(), key); + if (it == rows_.end()) { + PADDLE_THROW(paddle::platform::errors::NotFound( + "Input id (%lld) is not in current rows table.", key)); + } + return static_cast(std::distance(rows_.begin(), it)); + } + + /* + * @brief whether has the specified key in the table. + * + * @return true if the key is exists. + */ + bool HasKey(int64_t key) const; + + /* + * @brief Get value by the key list. + * Note!!! this interface is only used when selected_rows is used as + * parameters + * for distribute lookup table. + * + * @return a list of pair which contains the non-exists key and the index in + * the value + */ + void Get(const pten::DenseTensor& ids, + pten::DenseTensor* value, + bool auto_grown = false, + bool is_test = false); + + /* + * @brief Get the index of the key from id_to_index_ map. If the key not + * exist, + * add the key into id_to_index_. + * + * Note!!! this interface is only used when selected_rows is used as + * parameters + * for distribute lookup table. 
+ * + * @return index of the key. + */ + int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false); + + /* + * @brief Get the index of the key from id_to_index_ map. + */ + inline int64_t GetIndexFromId(int64_t key) const { + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + return -1; + } else { + return iter->second; + } + } + + void SyncIndex(); + /* + * @brief Get complete Dims before + */ + pten::framework::DDim GetCompleteDims() const { + std::vector dims = vectorize(value_->dims()); + dims[0] = height_; + return pten::framework::make_ddim(dims); + } + + private: + // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here. + // SelectedRows are simply concated when adding together. Until a + // SelectedRows add a Tensor, will the duplicate rows be handled. + paddle::framework::Vector rows_; + std::unordered_map + id_to_index_; // should not be used when rows_ has duplicate member + std::unique_ptr value_{nullptr}; + int64_t height_; // height indicates the underline tensor's height + std::unique_ptr rwlock_{nullptr}; +}; + +} // namespace pten diff --git a/paddle/pten/core/utils/rw_lock.h b/paddle/pten/core/utils/rw_lock.h new file mode 100644 index 0000000000000..7bd190c901bc6 --- /dev/null +++ b/paddle/pten/core/utils/rw_lock.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#if !defined(_WIN32) +#include +#else +#include // NOLINT +#endif // !_WIN32 + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/platform/enforce.h" + +namespace pten { + +#if !defined(_WIN32) +struct RWLock { + RWLock() { pthread_rwlock_init(&lock_, nullptr); } + + ~RWLock() { pthread_rwlock_destroy(&lock_); } + + inline void RDLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), + 0, + paddle::platform::errors::External( + "The pthread failed to acquire read lock.")); + } + + inline void WRLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), + 0, + paddle::platform::errors::External( + "The pthread failed to acquire write lock.")); + } + + inline void UNLock() { + PADDLE_ENFORCE_EQ( + pthread_rwlock_unlock(&lock_), + 0, + paddle::platform::errors::External("The pthread failed to unlock.")); + } + + private: + pthread_rwlock_t lock_; +}; +// TODO(paddle-dev): Support RWLock for WIN32 for correctness. +#else +// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive +// In windows, rw_lock seems like a hack. Use empty object and do nothing. 
+struct RWLock { + // FIXME(minqiyang): use mutex here to do fake lock + inline void RDLock() { mutex_.lock(); } + + inline void WRLock() { mutex_.lock(); } + + inline void UNLock() { mutex_.unlock(); } + + private: + std::mutex mutex_; +}; +#endif + +class AutoWRLock { + public: + explicit AutoWRLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); } + + ~AutoWRLock() { UnLock(); } + + private: + inline void Lock() { lock_->WRLock(); } + + inline void UnLock() { lock_->UNLock(); } + + private: + RWLock* lock_; +}; + +class AutoRDLock { + public: + explicit AutoRDLock(RWLock* rw_lock) : lock_(rw_lock) { Lock(); } + + ~AutoRDLock() { UnLock(); } + + private: + inline void Lock() { lock_->RDLock(); } + + inline void UnLock() { lock_->UNLock(); } + + private: + RWLock* lock_; +}; + +} // namespace pten diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 117d6a29252c1..363a57f036b9b 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -4,3 +4,10 @@ cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context) +cc_test(selected_rows_test SRCS test_selected_rows.cc DEPS selected_rows) +if(WITH_TESTING AND TEST selected_rows_test) + set_tests_properties(selected_rows_test PROPERTIES TIMEOUT 120) +endif() +if (NOT WIN32) +cc_test(test_rw_lock SRCS test_rw_lock.cc) +endif (NOT WIN32) diff --git a/paddle/pten/tests/core/test_rw_lock.cc b/paddle/pten/tests/core/test_rw_lock.cc new file mode 100644 index 0000000000000..5cd81fa76b40e --- /dev/null +++ b/paddle/pten/tests/core/test_rw_lock.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/core/utils/rw_lock.h" + +#include // NOLINT +#include // NOLINT + +namespace pten { +namespace tests { + +void f1(pten::RWLock *lock) { + lock->RDLock(); + lock->UNLock(); +} + +TEST(RWLOCK, read_read) { + pten::RWLock lock; + lock.RDLock(); + std::thread t1(f1, &lock); + std::thread t2(f1, &lock); + t1.join(); + t2.join(); + lock.UNLock(); +} + +void f2(pten::RWLock *lock, std::vector *result) { + lock->RDLock(); + ASSERT_EQ(result->size(), 0UL); + lock->UNLock(); +} + +void f3(pten::RWLock *lock, std::vector *result) { + lock->WRLock(); + result->push_back(1); + lock->UNLock(); +} + +TEST(RWLOCK, read_write) { + pten::RWLock lock; + std::vector result; + + lock.RDLock(); + std::thread t1(f2, &lock, &result); + t1.join(); + std::thread t2(f3, &lock, &result); + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_EQ(result.size(), 0UL); + lock.UNLock(); + t2.join(); + ASSERT_EQ(result.size(), 1UL); +} + +void f4(pten::RWLock *lock, std::vector *result) { + lock->RDLock(); + ASSERT_EQ(result->size(), 1UL); + lock->UNLock(); +} + +TEST(RWLOCK, write_read) { + pten::RWLock lock; + std::vector result; + + lock.WRLock(); + std::thread t1(f4, &lock, &result); + std::this_thread::sleep_for(std::chrono::seconds(1)); + result.push_back(1); + lock.UNLock(); + t1.join(); +} +} // namespace tests +} // namespace pten diff --git a/paddle/pten/tests/core/test_selected_rows.cc 
b/paddle/pten/tests/core/test_selected_rows.cc new file mode 100644 index 0000000000000..81c7ff4a838a7 --- /dev/null +++ b/paddle/pten/tests/core/test_selected_rows.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT + +#include "gtest/gtest.h" +#include "paddle/pten/core/selected_rows.h" + +namespace pten { +namespace tests { +class SelectedRowsTester : public ::testing::Test { + public: + void SetUp() override { + std::vector rows{0, 4, 7}; + int64_t height = 10; + int64_t row_numel = 100; + selected_rows_.reset(new SelectedRows(rows, height)); + + pten::DenseTensor* value = selected_rows_->mutable_value(); + auto* data = value->mutable_data( + pten::framework::make_ddim( + {static_cast(rows.size()), row_numel}), + place_); + for (int64_t i = 0; i < value->numel(); ++i) { + data[i] = static_cast(i); + } + } + + protected: + pten::CPUPlace place_; + std::unique_ptr selected_rows_{nullptr}; +}; + +TEST_F(SelectedRowsTester, height) { ASSERT_EQ(selected_rows_->height(), 10); } + +TEST_F(SelectedRowsTester, dims) { + ASSERT_EQ(selected_rows_->value().dims(), + pten::framework::make_ddim({3, 100})); +} + +TEST_F(SelectedRowsTester, complete_dims) { + ASSERT_EQ(selected_rows_->GetCompleteDims(), + pten::framework::make_ddim({10, 100})); +} + +TEST(SelectedRows, SparseTable) { + pten::CPUPlace cpu; + SelectedRows table; + + int64_t table_size = 100; + int64_t embedding_width = 
8; + // initialize a sparse table + table.mutable_value()->Resize( + pten::framework::make_ddim({table_size, embedding_width})); + auto* data = table.mutable_value()->mutable_data(cpu); + for (int64_t i = 0; i < table_size; ++i) { + for (int64_t j = 0; j < embedding_width; ++j) { + data[i * embedding_width + j] = static_cast(i); + } + } + ASSERT_EQ(table.AutoGrownIndex(10, true, false), 0); + ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1); + ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1); + ASSERT_EQ(table.AutoGrownIndex(6, true, false), 2); + for (int64_t i = 11; i < 20; i++) { + ASSERT_EQ(table.AutoGrownIndex(i, true, true), -1); + ASSERT_TRUE(!table.HasKey(i)); + } + ASSERT_TRUE(table.HasKey(10)); + ASSERT_TRUE(table.HasKey(8)); + ASSERT_TRUE(table.HasKey(6)); + ASSERT_EQ(table.rows().size(), 3UL); + + pten::DenseTensor ids; + ids.Resize(pten::framework::make_ddim({4})); + auto* ids_data = ids.mutable_data(cpu); + ids_data[0] = static_cast(6); + ids_data[1] = static_cast(6); + ids_data[2] = static_cast(8); + ids_data[3] = static_cast(10); + + pten::DenseTensor get_value; + auto* value_data = get_value.mutable_data( + pten::framework::make_ddim({4, embedding_width}), cpu); + table.Get(ids, &get_value); + + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[0 * embedding_width + j], 2); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[1 * embedding_width + j], 2); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[2 * embedding_width + j], 1); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[3 * embedding_width + j], 0); + } +} + +void f1(SelectedRows* table, int table_size) { + for (int i = 1000000; i > 0; --i) { + auto id = i % table_size; + int64_t index1 = table->AutoGrownIndex(id, true); + int64_t index2 = table->AutoGrownIndex(id, false); + int64_t index3 = table->AutoGrownIndex(id, true); + ASSERT_EQ(index1, index2); + ASSERT_EQ(index2, index3); + } +} + +void 
f2(SelectedRows* table, int table_size) { + for (int i = 0; i < 1000000; ++i) { + auto id = i % table_size; + int64_t index1 = table->AutoGrownIndex(id, true); + int64_t index2 = table->AutoGrownIndex(id, false); + int64_t index3 = table->AutoGrownIndex(id, true); + ASSERT_EQ(index1, index2); + ASSERT_EQ(index2, index3); + } +} + +void f3(SelectedRows* table, int table_size) { + clock_t t1 = clock(); + for (int i = 100000; i > 0; --i) { + auto id1 = table->AutoGrownIndex(i % table_size, true); + auto id2 = table->Index(i % table_size); + ASSERT_EQ(id1, id2); + } + clock_t t2 = clock(); + std::cout << "f3 run time:" << t2 - t1 << std::endl; +} + +void f4(SelectedRows* table, int table_size) { + clock_t t1 = clock(); + for (int i = 0; i < 100000; ++i) { + auto id1 = table->AutoGrownIndex(i % table_size, true); + auto id2 = table->Index(i % table_size); + ASSERT_EQ(id1, id2); + } + clock_t t2 = clock(); + std::cout << "f4 run time:" << t2 - t1 << std::endl; +} + +TEST(SelectedRows, MultiThreadAutoIndex) { + pten::CPUPlace cpu; + SelectedRows table; + + int64_t table_size = 100000; + int64_t embedding_width = 8; + // initialize a sparse table + table.mutable_value()->Resize( + pten::framework::make_ddim({table_size, embedding_width})); + auto* data = table.mutable_value()->mutable_data(cpu); + for (int64_t i = 0; i < table_size; ++i) { + for (int64_t j = 0; j < embedding_width; ++j) { + data[i * embedding_width + j] = static_cast(i); + } + } + + std::thread t1(f1, &table, table_size); + std::thread t11(f1, &table, table_size); + std::thread t2(f2, &table, table_size); + std::thread t22(f2, &table, table_size); + t1.join(); + t11.join(); + t2.join(); + t22.join(); + std::thread t3(f3, &table, table_size); + std::thread t4(f4, &table, table_size); + t3.join(); + t4.join(); +} +} // namespace tests +} // namespace pten From 7ac2f80f371a11c8b46fef1c2c6c51ab9e682931 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 16:21:04 +0800 Subject: [PATCH 07/14] 
[PTen] Add attr method for ArgumentMappingContext (#39130) * add attr for arg map context * add argument fn declare * add attr test for get attr value method * polish details --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/attribute.cc | 33 +++++ paddle/fluid/framework/attribute.h | 23 +++- paddle/fluid/framework/attribute_test.cc | 114 ++++++++++++++++++ paddle/fluid/framework/operator.h | 7 +- paddle/fluid/framework/pten_utils.h | 4 +- paddle/pten/core/CMakeLists.txt | 4 +- paddle/pten/core/compat/CMakeLists.txt | 1 + .../pten/core/{ => compat}/arg_map_context.cc | 31 +---- .../pten/core/{ => compat}/arg_map_context.h | 41 ++----- paddle/pten/core/kernel_def.h | 9 ++ paddle/pten/core/macros.h | 2 +- paddle/pten/core/meta_tensor.h | 3 + paddle/pten/ops/compat/scale_args_fn.h | 2 +- 14 files changed, 207 insertions(+), 68 deletions(-) create mode 100644 paddle/fluid/framework/attribute_test.cc create mode 100644 paddle/pten/core/compat/CMakeLists.txt rename paddle/pten/core/{ => compat}/arg_map_context.cc (53%) rename paddle/pten/core/{ => compat}/arg_map_context.h (79%) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 286a8684127a9..5c3b24463ef4b 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -167,6 +167,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor framework_proto selected_rows_utils data_device_transform data_type_transform data_layout_transform) cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce) +cc_test(attribute_test SRCS attribute_test.cc DEPS attribute framework_proto proto_desc) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc device_context) diff --git a/paddle/fluid/framework/attribute.cc b/paddle/fluid/framework/attribute.cc index 63934d17f9964..cf7a7c3c9f43d 100644 --- a/paddle/fluid/framework/attribute.cc +++ b/paddle/fluid/framework/attribute.cc @@ -17,6 
+17,39 @@ limitations under the License. */ namespace paddle { namespace framework { +paddle::any GetAttrValue(const Attribute& attr) { + if (attr.type() == typeid(int)) { + return paddle::any(BOOST_GET_CONST(int, attr)); + } else if (attr.type() == typeid(float)) { + return paddle::any(BOOST_GET_CONST(float, attr)); + } else if (attr.type() == typeid(std::string)) { + return paddle::any(BOOST_GET_CONST(std::string, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else if (attr.type() == typeid(bool)) { + return paddle::any(BOOST_GET_CONST(bool, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else if (attr.type() == typeid(BlockDesc*)) { + return paddle::any(BOOST_GET_CONST(BlockDesc*, attr)); + } else if (attr.type() == typeid(int64_t)) { + return paddle::any(BOOST_GET_CONST(int64_t, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else if (attr.type() == typeid(std::vector)) { + return paddle::any(BOOST_GET_CONST(std::vector, attr)); + } else { + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported Attribute value type.")); + } +} + Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) { switch (attr_desc.type()) { case proto::AttrType::BOOLEAN: { diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 37d399b7779a7..7026cc7cf1aa3 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -27,10 +27,15 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" +#include "paddle/utils/any.h" namespace paddle { namespace framework { +paddle::any GetAttrValue(const Attribute& attr); + +Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); + template struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) @@ -204,8 +209,6 @@ inline proto::AttrType AttrTypeID() { return static_cast(tmp.which() - 1); } -Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); - class AttrReader { public: explicit AttrReader(const AttributeMap& attrs) @@ -234,6 +237,22 @@ class AttrReader { return *attr_value; } + inline const Attribute& GetAttr(const std::string& name) const { + auto it = attrs_.find(name); + bool found = it != attrs_.end(); + if (!found) { + if (default_attrs_ != nullptr) { + it = default_attrs_->find(name); + found = it != default_attrs_->end(); + } + } + PADDLE_ENFORCE_EQ(found, true, + platform::errors::NotFound( + "Attribute (%s) should be in AttributeMap.", name)); + + return it->second; + } + private: const AttributeMap& attrs_; const AttributeMap* default_attrs_; diff --git a/paddle/fluid/framework/attribute_test.cc b/paddle/fluid/framework/attribute_test.cc new file mode 100644 index 0000000000000..27a6afb49f5e8 --- /dev/null +++ b/paddle/fluid/framework/attribute_test.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/program_desc.h" + +#include "gtest/gtest.h" +#include "paddle/utils/any.h" + +TEST(Attribute, GetAttrValueToAny) { + paddle::framework::Attribute x_int(100); + auto rlt_int = paddle::framework::GetAttrValue(x_int); + EXPECT_EQ(paddle::any_cast(rlt_int), 100); + + float float_value = 3.14; + paddle::framework::Attribute x_float(float_value); + auto rlt_float = paddle::framework::GetAttrValue(x_float); + EXPECT_NEAR(paddle::any_cast(rlt_float), 3.14, 1e-6); + + std::string str_value("test"); + paddle::framework::Attribute x_str(str_value); + auto rlt_str = paddle::framework::GetAttrValue(x_str); + EXPECT_EQ(paddle::any_cast(rlt_str), "test"); + + std::vector vec_int_var(2, 100); + paddle::framework::Attribute x_vec_int = vec_int_var; + auto rlt_vec_int = paddle::framework::GetAttrValue(x_vec_int); + auto vec_int = paddle::any_cast>(rlt_vec_int); + EXPECT_EQ(vec_int.size(), 2UL); + EXPECT_EQ(vec_int[0], 100); + EXPECT_EQ(vec_int[1], 100); + + std::vector vec_float_var(2, 3.14); + paddle::framework::Attribute x_vec_float = vec_float_var; + auto rlt_vec_float = paddle::framework::GetAttrValue(x_vec_float); + auto vec_float = paddle::any_cast>(rlt_vec_float); + EXPECT_EQ(vec_float.size(), 2UL); + EXPECT_NEAR(vec_float[0], 3.14, 1e-6); + EXPECT_NEAR(vec_float[1], 3.14, 1e-6); + + std::vector vec_str_var(2, "test"); + paddle::framework::Attribute x_vec_str = vec_str_var; + auto rlt_vec_str = paddle::framework::GetAttrValue(x_vec_str); + auto vec_str = paddle::any_cast>(rlt_vec_str); + EXPECT_EQ(vec_str.size(), 2UL); + EXPECT_EQ(vec_str[0], "test"); + EXPECT_EQ(vec_str[1], "test"); + + paddle::framework::Attribute x_bool(true); + auto rlt_bool = paddle::framework::GetAttrValue(x_bool); + EXPECT_EQ(paddle::any_cast(rlt_bool), true); + + std::vector vec_bool_var(2, 
true); + paddle::framework::Attribute x_vec_bool = vec_bool_var; + auto rlt_vec_bool = paddle::framework::GetAttrValue(x_vec_bool); + auto vec_bool = paddle::any_cast>(rlt_vec_bool); + EXPECT_EQ(vec_bool.size(), 2UL); + EXPECT_EQ(vec_bool[0], true); + EXPECT_EQ(vec_bool[1], true); + + paddle::framework::ProgramDesc prog; + paddle::framework::proto::BlockDesc proto_block; + paddle::framework::BlockDesc block_desc(&prog, &proto_block); + paddle::framework::Attribute x_block_desc(&block_desc); + auto rlt_block_desc = paddle::framework::GetAttrValue(x_block_desc); + auto block_desc_ptr = + paddle::any_cast(rlt_block_desc); + EXPECT_NE(block_desc_ptr, nullptr); + + std::vector vec_block_desc_var; + vec_block_desc_var.emplace_back(&block_desc); + paddle::framework::Attribute x_vec_block_desc(vec_block_desc_var); + auto rlt_vec_block_desc = paddle::framework::GetAttrValue(x_vec_block_desc); + auto vec_block_desc = + paddle::any_cast>( + rlt_vec_block_desc); + EXPECT_EQ(vec_block_desc.size(), 1UL); + EXPECT_NE(vec_block_desc[0], nullptr); + + int64_t int64_value = 100; + paddle::framework::Attribute x_int64(int64_value); + auto rlt_int64 = paddle::framework::GetAttrValue(x_int64); + EXPECT_EQ(paddle::any_cast(rlt_int64), 100); + + std::vector vec_int64_var(2, 100); + paddle::framework::Attribute x_vec_int64 = vec_int64_var; + auto rlt_vec_int64 = paddle::framework::GetAttrValue(x_vec_int64); + auto vec_int64 = paddle::any_cast>(rlt_vec_int64); + EXPECT_EQ(vec_int64.size(), 2UL); + EXPECT_EQ(vec_int64[0], 100); + EXPECT_EQ(vec_int64[1], 100); + + std::vector vec_double_var(2, 3.14); + paddle::framework::Attribute x_vec_double = vec_double_var; + auto rlt_vec_double = paddle::framework::GetAttrValue(x_vec_double); + auto vec_double = paddle::any_cast>(rlt_vec_double); + EXPECT_EQ(vec_double.size(), 2UL); + EXPECT_NEAR(vec_double[0], 3.14, 1e-6); + EXPECT_NEAR(vec_double[1], 3.14, 1e-6); +} diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h 
index 8e000ef9985bd..40c80ec5f2d65 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -40,7 +40,7 @@ limitations under the License. */ #include "paddle/fluid/platform/variant.h" #include "paddle/utils/flat_hash_map.h" -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_factory.h" @@ -454,8 +454,9 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext { return ctx_.HasOutput(name); } - bool HasAttr(const std::string& name) const override { - return ctx_.HasAttr(name); + paddle::any Attr(const std::string& name) const override { + auto& attr = ctx_.GetAttr(name); + return GetAttrValue(attr); } size_t InputSize(const std::string& name) const override { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index a4493f3d3e5c0..ab129c6313dab 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -25,7 +25,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" #include "paddle/pten/api/lib/utils/tensor_utils.h" -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" #include "paddle/pten/core/kernel_factory.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" @@ -85,6 +85,6 @@ template <> struct ConvertToPtenContext { using TYPE = pten::CPUContext; }; - + } // namespace framework } // namespace paddle diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index f6f0e1f3e26ec..cd3a1755a9df4 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ -1,3 +1,6 @@ +# utils used for compatible for fluid op system +add_subdirectory(compat) + if(WITH_GPU) cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info) elseif(WITH_ROCM) @@ -8,7 +11,6 @@ endif() cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils) cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context) -cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce) cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector) diff --git a/paddle/pten/core/compat/CMakeLists.txt b/paddle/pten/core/compat/CMakeLists.txt new file mode 100644 index 0000000000000..253f60daf1f89 --- /dev/null +++ b/paddle/pten/core/compat/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce) diff --git a/paddle/pten/core/arg_map_context.cc b/paddle/pten/core/compat/arg_map_context.cc similarity index 53% rename from paddle/pten/core/arg_map_context.cc rename to paddle/pten/core/compat/arg_map_context.cc index d7aea11ddf043..3914a8a684eda 100644 --- a/paddle/pten/core/arg_map_context.cc +++ b/paddle/pten/core/compat/arg_map_context.cc @@ -12,41 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" -#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/string_helper.h" namespace pten { -OpArgumentMappingFnMap& OpArgumentMappingFnMap::Instance() { - static OpArgumentMappingFnMap g_op_arg_mapping_fn_map; - return g_op_arg_mapping_fn_map; -} - -bool OpArgumentMappingFnMap::Has(const std::string& op_type) const { - return fn_map_.find(op_type) != fn_map_.end(); -} - -const ArgumentMappingFn& OpArgumentMappingFnMap::Get( - const std::string& op_type) const { - auto it = fn_map_.find(op_type); - PADDLE_ENFORCE_NE( - it, - fn_map_.end(), - paddle::platform::errors::NotFound( - "Operator `%s`'s argument mapping funciton is not registered.", - op_type)); - return it->second; -} - -void OpArgumentMappingFnMap::Emplace(const std::string& op_type, - const std::string api_name, - ArgumentMappingFn fn) { - name_map_.emplace(op_type, api_name); - fn_map_.emplace(op_type, fn); -} - std::ostream& operator<<(std::ostream& os, KernelSignature signature) { os << "Kernel Signature - name: " << signature.name << "; inputs: " << paddle::string::join_strings(std::get<0>(signature.args), ", ") diff --git a/paddle/pten/core/arg_map_context.h b/paddle/pten/core/compat/arg_map_context.h similarity index 79% rename from paddle/pten/core/arg_map_context.h rename to paddle/pten/core/compat/arg_map_context.h index be9eb3af76a36..e7dfc0706544c 100644 --- a/paddle/pten/core/arg_map_context.h +++ b/paddle/pten/core/compat/arg_map_context.h @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include #include +#include "paddle/utils/any.h" #include "paddle/utils/flat_hash_map.h" #include "paddle/utils/small_vector.h" @@ -28,22 +29,6 @@ using KernelArgsTuple = std::tuple, paddle::SmallVector, paddle::SmallVector>; -// TODO(chenweihang): Add more methods if needed in future -class ArgumentMappingContext { - public: - virtual ~ArgumentMappingContext() = default; - - virtual bool HasInput(const std::string& name) const = 0; - virtual bool HasOutput(const std::string& name) const = 0; - virtual bool HasAttr(const std::string& name) const = 0; - - virtual size_t InputSize(const std::string& name) const = 0; - virtual size_t OutputSize(const std::string& name) const = 0; - - virtual bool IsDenseTensorInput(const std::string& name) const = 0; - virtual bool IsSelectedRowsInput(const std::string& name) const = 0; -}; - struct KernelSignature { std::string name; KernelArgsTuple args; @@ -64,23 +49,23 @@ struct KernelSignature { std::ostream& operator<<(std::ostream& os, KernelSignature signature); -using ArgumentMappingFn = KernelSignature (*)(const ArgumentMappingContext&); - -class OpArgumentMappingFnMap { +// TODO(chenweihang): Add more methods if needed in future +class ArgumentMappingContext { public: - static OpArgumentMappingFnMap& Instance(); + virtual ~ArgumentMappingContext() = default; - bool Has(const std::string& op_type) const; + virtual bool HasInput(const std::string& name) const = 0; + virtual bool HasOutput(const std::string& name) const = 0; - const ArgumentMappingFn& Get(const std::string& op_type) const; + // now we can't use Attribute here, it will cause pten to rely on + // boost::variant and BlockDesc + virtual paddle::any Attr(const std::string& name) const = 0; - void Emplace(const std::string& op_type, - const std::string api_name, - ArgumentMappingFn fn); + virtual size_t InputSize(const std::string& name) const = 0; + virtual size_t OutputSize(const std::string& name) const = 0; - private: - paddle::flat_hash_map name_map_; -
paddle::flat_hash_map fn_map_; + virtual bool IsDenseTensorInput(const std::string& name) const = 0; + virtual bool IsSelectedRowsInput(const std::string& name) const = 0; }; } // namespace pten diff --git a/paddle/pten/core/kernel_def.h b/paddle/pten/core/kernel_def.h index 875083cfb59e3..3884bb55e47b8 100644 --- a/paddle/pten/core/kernel_def.h +++ b/paddle/pten/core/kernel_def.h @@ -14,16 +14,25 @@ #pragma once +#include + namespace pten { class Kernel; class KernelKey; class KernelArgsDef; class KernelContext; +class KernelSignature; +class ArgumentMappingContext; +class InferMetaContext; using KernelFn = void (*)(KernelContext* ctx); using KernelArgsDefFn = void (*)(Kernel* kernel); using KernelArgsParseFn = void (*)(const KernelKey& default_key, KernelArgsDef* args_def); +using ArgumentMappingFn = + std::function; +using InferMetaFn = void (*)(InferMetaContext* ctx); + } // namespace pten diff --git a/paddle/pten/core/macros.h b/paddle/pten/core/macros.h index fec67b1a3dc25..20a39fdda2ced 100644 --- a/paddle/pten/core/macros.h +++ b/paddle/pten/core/macros.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/paddle/pten/core/meta_tensor.h b/paddle/pten/core/meta_tensor.h index 4273aa6f85b4e..442ff4137de42 100644 --- a/paddle/pten/core/meta_tensor.h +++ b/paddle/pten/core/meta_tensor.h @@ -47,7 +47,10 @@ class MetaTensor { virtual void share_lod(const MetaTensor& meta_tensor); private: + // Because the lod at compile time and at runtime is different, + // `LoD` cannot appear in public methods const LoD& lod() const; + TensorBase* tensor_; }; diff --git a/paddle/pten/ops/compat/scale_args_fn.h b/paddle/pten/ops/compat/scale_args_fn.h index b9a20400f971a..91f0db389d9d5 100644 --- a/paddle/pten/ops/compat/scale_args_fn.h +++ b/paddle/pten/ops/compat/scale_args_fn.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -#include "paddle/pten/core/arg_map_context.h" +#include "paddle/pten/core/compat/arg_map_context.h" namespace pten { From ec24bc989c3f55695c281dc177930c7a0e74c09a Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 17:10:28 +0800 Subject: [PATCH 08/14] add get inout var ptr for dygraph (#39134) --- .../fluid/eager/legacy/infer_shape_context.h | 23 +++++++++++++++---- paddle/fluid/imperative/infer_shape_context.h | 23 +++++++++++++++---- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/eager/legacy/infer_shape_context.h b/paddle/fluid/eager/legacy/infer_shape_context.h index a1032fd404f85..0979abc63d658 100644 --- a/paddle/fluid/eager/legacy/infer_shape_context.h +++ b/paddle/fluid/eager/legacy/infer_shape_context.h @@ -222,17 +222,30 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext { paddle::framework::DataLayout::kMKLDNN)); } - // TODO(paddle-dev): Can this be template?
std::vector GetInputVarPtrs( const std::string& name) const override { - PADDLE_THROW(paddle::platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = tensor_in_->find(name); + PADDLE_ENFORCE_NE(it, tensor_in_->end(), + paddle::platform::errors::NotFound( + "Can not find [%s] in inputs.", name)); + for (auto& tensor : it->second) { + res.emplace_back(tensor->MutableVar()); + } + return res; } std::vector GetOutputVarPtrs( const std::string& name) const override { - PADDLE_THROW(paddle::platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = tensor_out_->find(name); + PADDLE_ENFORCE_NE(it, tensor_out_->end(), + paddle::platform::errors::NotFound( + "Can not find [%s] in outputs.", name)); + for (auto& tensor : it->second) { + res.emplace_back(tensor->MutableVar()); + } + return res; } DDim GetInputDim(const std::string& name) const override { diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index a16ad1688fbac..71f7fb7387eff 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -220,17 +220,30 @@ class DygraphInferShapeContext : public framework::InferShapeContext { (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN)); } - // TODO(paddle-dev): Can this be template? 
std::vector GetInputVarPtrs( const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetInputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = var_base_map_in_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_in_->end(), + platform::errors::NotFound("Can not find [%s] in inputs.", name)); + for (auto& var : it->second) { + res.emplace_back(var->MutableVar()); + } + return res; } std::vector GetOutputVarPtrs( const std::string& name) const override { - PADDLE_THROW(platform::errors::PermissionDenied( - "GetOutputVarPtrs not support in dygraph runtime context")); + std::vector res; + auto it = var_base_map_out_->find(name); + PADDLE_ENFORCE_NE( + it, var_base_map_out_->end(), + platform::errors::NotFound("Can not find [%s] in outputs.", name)); + for (auto& var : it->second) { + res.emplace_back(var->MutableVar()); + } + return res; } DDim GetInputDim(const std::string& name) const override { From 60df925419086bffce3478be484398c2bdfbc592 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 21:08:26 +0800 Subject: [PATCH 09/14] remove useless cmake list (#39141) --- paddle/pten/kernels/CMakeLists.txt | 14 -------------- paddle/pten/kernels/cpu/CMakeLists.txt | 0 paddle/pten/kernels/dnnl/CMakeLists.txt | 0 paddle/pten/kernels/gpu/CMakeLists.txt | 0 paddle/pten/kernels/primitive/CMakeLists.txt | 0 paddle/pten/kernels/xpu/CMakeLists.txt | 0 6 files changed, 14 deletions(-) delete mode 100644 paddle/pten/kernels/cpu/CMakeLists.txt delete mode 100644 paddle/pten/kernels/dnnl/CMakeLists.txt delete mode 100644 paddle/pten/kernels/gpu/CMakeLists.txt delete mode 100644 paddle/pten/kernels/primitive/CMakeLists.txt delete mode 100644 paddle/pten/kernels/xpu/CMakeLists.txt diff --git a/paddle/pten/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt index f838c4d424c2d..999f72a7e6b65 100644 --- a/paddle/pten/kernels/CMakeLists.txt +++ b/paddle/pten/kernels/CMakeLists.txt @@ -2,23 +2,9 
@@ set(kernel_declare_file ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h. set(kernel_declare_file_final ${PADDLE_BINARY_DIR}/paddle/pten/kernels/declarations.h) file(WRITE ${kernel_declare_file} "// Generated by the paddle/pten/kernels/CMakeLists.txt. DO NOT EDIT!\n\n#pragma once\n\n") -# kernel primitive api -add_subdirectory(primitive) # pten functors and functions called by kernels add_subdirectory(funcs) -add_subdirectory(cpu) -if(WITH_GPU OR WITH_ROCM) - add_subdirectory(gpu) -endif() -if(WITH_MKLDNN) - # mkldnn will be deprecated and use the new name dnnl - add_subdirectory(dnnl) -endif() -if(WITH_XPU) - add_subdirectory(xpu) -endif() - # pten depends all pten kernel targets set_property(GLOBAL PROPERTY PTEN_KERNELS "") diff --git a/paddle/pten/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/paddle/pten/kernels/dnnl/CMakeLists.txt b/paddle/pten/kernels/dnnl/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/paddle/pten/kernels/gpu/CMakeLists.txt b/paddle/pten/kernels/gpu/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/paddle/pten/kernels/primitive/CMakeLists.txt b/paddle/pten/kernels/primitive/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/paddle/pten/kernels/xpu/CMakeLists.txt b/paddle/pten/kernels/xpu/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 36d9a364d2822c8034cdd41c1f06cd9feefd2fdb Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 22 Jan 2022 21:08:38 +0800 Subject: [PATCH 10/14] add infershape utils (#39140) --- paddle/fluid/framework/CMakeLists.txt | 5 +- paddle/fluid/framework/infershape_utils.cc | 190 +++++++++++++++++++++ paddle/fluid/framework/infershape_utils.h | 44 +++++ 3 files changed, 237 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/framework/infershape_utils.cc create 
mode 100644 paddle/fluid/framework/infershape_utils.h diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 5c3b24463ef4b..0220e5fd59476 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va IF(WITH_XPU) cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils - pten pten_utils kernel_factory) + pten pten_utils kernel_factory infershape_utils) ELSE() cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils - pten pten_utils kernel_factory) + pten pten_utils kernel_factory infershape_utils) ENDIF() cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) @@ -408,6 +408,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info) +cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc new file mode 100644 index 0000000000000..9a91a5208ebbc --- /dev/null +++ b/paddle/fluid/framework/infershape_utils.cc @@ -0,0 +1,190 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/infershape_utils.h" + +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/pten/core/compat/arg_map_context.h" +#include "paddle/pten/core/compat_utils.h" +#include "paddle/pten/core/convert_utils.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/meta_tensor.h" + +namespace paddle { +namespace framework { + +class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext { + public: + explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx) + : ctx_(ctx) {} + + bool HasInput(const std::string& name) const override { + return ctx_.HasInput(name); + } + + bool HasOutput(const std::string& name) const override { + return ctx_.HasOutput(name); + } + + paddle::any Attr(const std::string& name) const override { + auto& attr = ctx_.Attrs().GetAttr(name); + return GetAttrValue(attr); + } + + size_t InputSize(const std::string& name) const override { + return ctx_.Inputs(name).size(); + } + + size_t OutputSize(const std::string& name) const override { + return ctx_.Outputs(name).size(); + } + + bool IsDenseTensorInput(const std::string& name) const override { + auto var_types = ctx_.GetInputsVarType(name); + return var_types[0] == proto::VarType::LOD_TENSOR; + } + + bool IsSelectedRowsInput(const std::string& name) const override { + auto var_types = ctx_.GetInputsVarType(name); + return var_types[0] 
== proto::VarType::SELECTED_ROWS; + } + + private: + const InferShapeContext& ctx_; +}; + +// TODO(chenweihang): Support SelectedRows later +// TODO(chenweihang): Support TensorArray later +class CompatMetaTensor : public pten::MetaTensor { + public: + CompatMetaTensor(InferShapeVarPtr var, bool is_runtime) + : var_(std::move(var)), is_runtime_(is_runtime) {} + + CompatMetaTensor() = default; + CompatMetaTensor(const CompatMetaTensor&) = default; + CompatMetaTensor(CompatMetaTensor&&) = default; + CompatMetaTensor& operator=(const CompatMetaTensor&) = delete; + CompatMetaTensor& operator=(CompatMetaTensor&&) = delete; + + int64_t numel() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().numel(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return var->ElementSize(); + } + } + + DDim dims() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().dims(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return make_ddim(var->GetShape()); + } + } + + pten::DataType dtype() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().dtype(); + } else { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return pten::TransToPtenDataType(var->GetDataType()); + } + } + + DataLayout layout() const override { + if (is_runtime_) { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().layout(); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported get layout for VarDesc now.")); + } + } + + void set_dims(const DDim& dims) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->dims = dims; + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetShape(vectorize(dims)); + } + } + + void set_dtype(pten::DataType dtype) override { + 
if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->dtype = dtype; + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetDataType(pten::TransToProtoVarType(dtype)); + } + } + + void set_layout(DataLayout layout) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->layout = layout; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported set layout for VarDesc now.")); + } + } + + void share_lod(const MetaTensor& meta_tensor) override { + if (is_runtime_) { + auto* var = BOOST_GET(Variable*, var_); + LoDTensor* tensor = var->GetMutable(); + pten::CompatibleDenseTensorUtils::GetMutableMeta( + static_cast(tensor)) + ->lod = + static_cast(meta_tensor).GetRuntimeLoD(); + } else { + auto* var = BOOST_GET(VarDesc*, var_); + var->SetLoDLevel(static_cast(meta_tensor) + .GetCompileTimeLoD()); + } + } + + private: + const LoD& GetRuntimeLoD() const { + auto* var = BOOST_GET_CONST(Variable*, var_); + return var->Get().lod(); + } + int32_t GetCompileTimeLoD() const { + auto* var = BOOST_GET_CONST(VarDesc*, var_); + return var->GetLoDLevel(); + } + + InferShapeVarPtr var_; + bool is_runtime_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/infershape_utils.h b/paddle/fluid/framework/infershape_utils.h new file mode 100644 index 0000000000000..f943989523e50 --- /dev/null +++ b/paddle/fluid/framework/infershape_utils.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_info.h" +#include "paddle/fluid/framework/shape_inference.h" + +namespace pten { +class InferMetaContext; +} // namespace pten + +namespace paddle { +namespace framework { + +// TODO(chenweihang): impl this function in next PR +pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, + const std::string& op_type); + +#define DELCARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \ + struct functor_name : public paddle::framework::InferShapeBase { \ + void operator()( \ + paddle::framework::InferShapeContext* ctx) const override { \ + auto infer_meta_context = \ + paddle::framework::BuildInferMetaContext(ctx, #op_type); \ + fn(&infer_meta_context); \ + } \ + } + +} // namespace framework +} // namespace paddle From 85334b04b92870f44c0f0266f56b5cfcf0b7df02 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 23 Jan 2022 10:45:58 +0800 Subject: [PATCH 11/14] [PTen] Add infermeta utils for register infermeta funtion (#39135) * add infermeta utils for register infermeta * polish license format --- paddle/pten/core/CMakeLists.txt | 1 + paddle/pten/core/infermeta_utils.cc | 73 ++++++++++++ paddle/pten/core/infermeta_utils.h | 170 ++++++++++++++++++++++++++++ 3 files changed, 244 insertions(+) create mode 100644 paddle/pten/core/infermeta_utils.cc create mode 100644 paddle/pten/core/infermeta_utils.h diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt index cd3a1755a9df4..181012732fa35 100644 --- a/paddle/pten/core/CMakeLists.txt +++ b/paddle/pten/core/CMakeLists.txt @@ 
-19,6 +19,7 @@ cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tens cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base ) cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor) +cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce) diff --git a/paddle/pten/core/infermeta_utils.cc b/paddle/pten/core/infermeta_utils.cc new file mode 100644 index 0000000000000..9f0037d18edf6 --- /dev/null +++ b/paddle/pten/core/infermeta_utils.cc @@ -0,0 +1,73 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/pten/core/infermeta_utils.h" + +namespace pten { + +void InferMetaContext::SetMetaConfig(MetaConfig config) { + config_ = std::move(config); +} + +void InferMetaContext::EmplaceBackInput( + std::shared_ptr input) { + int index = inputs_.size(); + inputs_.emplace_back(std::move(input)); + input_range_.emplace_back(std::pair(index, index + 1)); +} +void InferMetaContext::EmplaceBackOutput( + std::shared_ptr output) { + int index = outputs_.size(); + outputs_.emplace_back(std::move(output)); + output_range_.emplace_back(std::pair(index, index + 1)); +} +void InferMetaContext::EmplaceBackAttr(paddle::any attr) { + attrs_.emplace_back(std::move(attr)); +} + +void InferMetaContext::EmplaceBackInputs( + paddle::SmallVector> inputs) { + int index = inputs_.size(); + input_range_.emplace_back(std::pair(index, index + inputs.size())); + inputs_.insert(inputs_.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); +} +void InferMetaContext::EmplaceBackOutputs( + paddle::SmallVector> outputs) { + int index = outputs_.size(); + output_range_.emplace_back( + std::pair(index, index + outputs.size())); + outputs_.insert(outputs_.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); +} + +const std::pair& InferMetaContext::InputRangeAt(size_t idx) const { + return input_range_.at(idx); +} +const std::pair& InferMetaContext::OutputRangeAt(size_t idx) const { + return output_range_.at(idx); +} + +const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; } + +const MetaTensor& InferMetaContext::InputAt(size_t idx) const { + return *inputs_.at(idx); +} +MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { + return outputs_.at(idx).get(); +} + +} // namespace pten diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h new file mode 100644 index 0000000000000..c6812dee92b6a --- /dev/null +++ b/paddle/pten/core/infermeta_utils.h @@ -0,0 
+1,170 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/pten/core/meta_tensor.h" +#include "paddle/utils/small_vector.h" + +namespace pten { + +// TODO(chenweihang): add other flags if needed +struct MetaConfig { + bool is_runtime{true}; + + MetaConfig() = default; + + // supporting implicit construction is easier to use + MetaConfig(bool is_runtime) : is_runtime(is_runtime) {} // NOLINT +}; + +class InferMetaContext { + public: + InferMetaContext() = default; + explicit InferMetaContext(MetaConfig config) : config_(config) {} + + void SetMetaConfig(MetaConfig config); + void EmplaceBackInput(std::shared_ptr input); + void EmplaceBackOutput(std::shared_ptr output); + void EmplaceBackAttr(paddle::any attr); + + void EmplaceBackInputs( + paddle::SmallVector> inputs); + void EmplaceBackOutputs( + paddle::SmallVector> outputs); + + const std::pair& InputRangeAt(size_t idx) const; + const std::pair& OutputRangeAt(size_t idx) const; + + const MetaConfig& GetMetaConfig() const; + const MetaTensor& InputAt(size_t idx) const; + MetaTensor* MutableOutputAt(size_t idx); + + template + AttrType AttrAt(size_t idx) { + try { + return paddle::any_cast(attrs_.at(idx)); + } catch (paddle::bad_any_cast&) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Attribute cast error in InferMeta Context.")); + } + } + + private: + MetaConfig config_; + + // NOTE(chenweihang): 
Because the MetaTensor is a base class, and MetaTensor + // objects are all created in each round, so we have to use smart pointer + // here, maybe we can implement a new InferMetaContext and a series of utils + // specifically for fluid to avoid using shared_ptr + paddle::SmallVector> inputs_; + paddle::SmallVector> outputs_; + paddle::SmallVector attrs_; + + paddle::SmallVector> input_range_; + paddle::SmallVector> output_range_; +}; + +#define PT_INFER_META(...) \ + ::pten::InferMetaFnImpl::Call + +#define PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(attr_type) \ + template \ + struct InferMetaFnCallHelper { \ + template \ + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \ + static_assert(out_idx == 0, \ + "InferMeta's Attributes should appear before Outputs."); \ + attr_type arg = ctx->AttrAt(attr_idx); \ + InferMetaFnCallHelper< \ + Tail...>::template Call(pargs..., \ + arg); \ + } \ + } + +template +struct InferMetaTypeTag {}; + +template +struct InferMetaFnImpl; + +template +struct InferMetaFnImpl { + static void Call(InferMetaContext* ctx) { + InferMetaFnCallHelper>::template Call<0, 0, 0>(ctx); + } + + private: + template + struct InferMetaFnCallHelper; + + template + struct InferMetaFnCallHelper { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferMeta's Input should appear before Attributes."); + static_assert(out_idx == 0, + "InferMeta's Input should appear before Outputs."); + const std::pair range = ctx->InputRangeAt(in_idx); + const MetaTensor& arg = ctx->InputAt(range.first); + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + + // TODO(chenweihang): support vector input later + + template + struct InferMetaFnCallHelper { + template + static void Call(InferMetaContext* ctx, PreviousArgs&...
pargs) { + const std::pair range = ctx->OutputRangeAt(out_idx); + MetaTensor* arg = ctx->MutableOutputAt(range.first); + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + + // TODO(chenweihang): support vector output later + + template + struct InferMetaFnCallHelper { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + const MetaConfig& arg = ctx->GetMetaConfig(); + InferMetaFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + + /* End case */ + template + struct InferMetaFnCallHelper> { + template + static void Call(InferMetaContext* ctx, Args&... args) { + return infer_meta_fn(args...); + } + }; +}; + +} // namespace pten From 8c5c1046dd02d41dfcb2a012ccd62f4d90b59a18 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sun, 23 Jan 2022 12:11:32 +0800 Subject: [PATCH 12/14] Support test_imperative apply and Add a setter for EagerTensor (#39016) * Rearranged Eager AutoCodeGen directory structure * Removed USE_OP in Eager AutoCodeGen * Enabled generation for Operators without Grad/Inputs/Outputs * Resolved operators without input * Fixed merge conflicts * Enabled Eager AutoCodeGen for 10+ more operators * Refactored Eager AutoCodeGen with more organized helper objects * Enabled Eager AutoCodeGen for operators with multiple OpBases * Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument * Handled Dispensable Inputs/Outputs in Eager AutoCodeGen * Adjusted function generation/call between Python-C API & Dygraph API * Synchronized auto-generated Python-C API with Dygraph Forward Functions * support more eager tensor api * fix merge compile error * fix compile error and fit develop code * support pure CPU * fix some logic error in eager_mode * support _varbase_creator in eager mode * Added safe_initialized interface to EagerTensor for use in processing dispensable inputs * for eager mode * refine * support multiple constructor for eager tensor * add place related code * polish 
code * specific randint with dtype of int64 * Support pure cpu test * eager logic * refine test in pure cpu * eager logic * eager logic * eager logic, test=develop * skip core.eager when in inference, test=develop * refine, test=develop * refine, test=develop * call RetainGrad after run forward kernel, test=develop * refine, test=develop * support dygraph util, meta, guard test * eager test case * support inference test * refine test and fix initializer failed * modify eagertensor patch method * add eagertensor.clear_grandint, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * support create varbase and fix retain grad error * call monkey_patch_varbase in _test_eager_guard, test=develop * fix windows error * split clear_gradient to clear_gradient and zero_grads, test=develop * refine, test=develop * refine, test=develop * support test_imperative_basic test in eager mode * remove additional log in variable.h * remove additional log in variable.h * remove additional code create in merge * eager * fix some eager logic, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * patch_tensor_method_func, test=develop * refine, test=develop * eager test case, test=develop * refine, test=develop * eager, test=develop * eager, test=develop * eager optimizer, test=develop * eager optimizer, test=develop * eager test_imperative_optimizer_v2, test=develop * eager, test=develop * refine, test=develop * refine, test=develop * eager, test=develop * add resize in share buffer to, test=develop * eager, test=develop * fix _share_buffer_to, test=develop * refine, test=develop * refine, test=develop * support eager for dataloader,test=develop * Exposed EagerTensor's set func to implement set_value func * Rename set to _set_value, Supplement the corresponding test case * fix test concat dev api build failed * fix conflict * fix conflict * Use extern to Polish code Co-authored-by: jim19930609 Co-authored-by: JiabinYang 
<360788950@qq.com> Co-authored-by: Wang Huan Co-authored-by: wanghuancoder Co-authored-by: chentianyu03 --- paddle/fluid/pybind/eager_method.cc | 20 +++++++++++++ .../fluid/dygraph/varbase_patch_methods.py | 4 +++ .../tests/unittests/test_egr_python_api.py | 29 +++++++++++++++++++ .../unittests/test_imperative_layer_apply.py | 8 ++++- 4 files changed, 60 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index b254b5d41d3ab..bb3464665411c 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -34,6 +34,10 @@ limitations under the License. */ namespace paddle { namespace pybind { +extern void InitEagerTensorWithNumpyValue(EagerTensorObject* self, + const pybind11::object& array, + bool zero_copy); + extern PyTypeObject* p_eager_tensor_type; static PyObject* eager_tensor_method_numpy(EagerTensorObject* self, @@ -359,6 +363,20 @@ static PyObject* eager_tensor_method_get_underline_tensor( EAGER_CATCH_AND_THROW_RETURN_NULL } +// NOTE(wuweilong): Set value and not change self's original place +static PyObject* eager_tensor_method_set_value(EagerTensorObject* self, + PyObject* args, + PyObject* kwargs) { + EAGER_TRY + VLOG(4) << "Value " << self->eager_tensor.name(); + pybind11::object numpy_value = + pybind11::object(pybind11::handle(PyTuple_GET_ITEM(args, 0)), true); + InitEagerTensorWithNumpyValue(self, numpy_value, false); + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + PyMethodDef variable_methods[] = { {"numpy", (PyCFunction)(void (*)(void))eager_tensor_method_numpy, METH_VARARGS | METH_KEYWORDS, NULL}, @@ -393,6 +411,8 @@ PyMethodDef variable_methods[] = { {"get_tensor", (PyCFunction)(void (*)(void))eager_tensor_method_get_underline_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, + {"_set_value", (PyCFunction)(void (*)(void))eager_tensor_method_set_value, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL}}; } // namespace 
pybind diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 8fc6bd818bc8f..f5d569828775e 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -180,6 +180,10 @@ def set_value(self, value): "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( self.name, self_tensor_np.dtype, value_np.dtype) + # NOTE(wuweilong): self could be VarBase or EagerTensor, the subsequent behavior are defined in different files + # if self is VarBase, method value() return Variable that bindded in imperative.cc, get_tensor() bindded in pybind.cc + # if self is EagerTensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc + # this Interface behavior will be unifed in the future. self.value().get_tensor().set(value_np, framework._current_expected_place()) diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index ba0421d6eb32d..d6bf768bee774 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -763,6 +763,24 @@ def test_value(self): paddle.fluid.framework._current_expected_place()) self.assertTrue(egr_tensor0.value().get_tensor()._is_initialized()) + def test_set_value(self): + with _test_eager_guard(): + ori_arr = np.random.rand(4, 16, 16, 32).astype('float32') + egr_tensor = core.eager.EagerTensor(value=ori_arr) + self.assertEqual(egr_tensor.stop_gradient, True) + self.assertEqual(egr_tensor.shape, [4, 16, 16, 32]) + self.assertTrue(np.array_equal(egr_tensor.numpy(), ori_arr)) + ori_place = egr_tensor.place + + new_arr = np.random.rand(4, 4, 16, 32).astype('float32') + self.assertFalse(np.array_equal(egr_tensor.numpy(), new_arr)) + + egr_tensor._set_value(new_arr) + 
self.assertEqual(egr_tensor.stop_gradient, True) + self.assertTrue(egr_tensor.place._equals(ori_place)) + self.assertEqual(egr_tensor.shape, [4, 4, 16, 32]) + self.assertTrue(np.array_equal(egr_tensor.numpy(), new_arr)) + class EagerParamBaseUsageTestCase(unittest.TestCase): def test_print(self): @@ -856,6 +874,17 @@ def test_backward_with_single_tensor(self): egr_tensor12.backward() self.assertTrue(np.array_equal(egr_tensor12.gradient(), arr)) + def test_set_value(self): + with _test_eager_guard(): + linear = paddle.nn.Linear(1, 3) + ori_place = linear.weight.place + new_weight = np.ones([1, 3]).astype('float32') + self.assertFalse(np.array_equal(linear.weight.numpy(), new_weight)) + + linear.weight._set_value(new_weight) + self.assertTrue(np.array_equal(linear.weight.numpy(), new_weight)) + self.assertTrue(linear.weight.place._equals(ori_place)) + class EagerGuardTestCase(unittest.TestCase): def test__test_eager_guard(self): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py index c18dab61fc5ab..0bc56294876d3 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layer_apply.py @@ -21,6 +21,7 @@ import paddle.fluid as fluid import numpy as np +from paddle.fluid.framework import _test_eager_guard class LeNetDygraph(fluid.dygraph.Layer): @@ -70,7 +71,7 @@ def init_weights(layer): class TestLayerApply(unittest.TestCase): - def test_apply_init_weight(self): + def func_apply_init_weight(self): with fluid.dygraph.guard(): net = LeNetDygraph() @@ -84,6 +85,11 @@ def test_apply_init_weight(self): np.testing.assert_allclose(layer.weight.numpy(), 0.7) np.testing.assert_allclose(layer.bias.numpy(), -0.2) + def test_apply_init_weight(self): + with _test_eager_guard(): + self.func_apply_init_weight() + self.func_apply_init_weight() + if __name__ == '__main__': unittest.main() From 
c3796061c385491738ef7e27e0e89fccd75877f3 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Mon, 24 Jan 2022 09:57:43 +0800 Subject: [PATCH 13/14] Refactored python-level trace_op to call through _C_ops instead of Tracer::TraceOp, under eager_mode (#38338) * Replaced core.ops with _C_ops * Refactored python-level trace_op to call through _C_ops instead of Tracer::TraceOp, under eager_mode * Modified trace_op interface * Refactored trace_op logic for eager mode * Added Eager Dygraph support for OpTest * Fixed ci issues * Fixed CI failures * Fixed Coverage CI Issues * Fixed XPU CI Issues --- .../auto_code_generator/eager_generator.cc | 33 ++++- paddle/fluid/pybind/eager_method.cc | 27 ++++ .../pybind/eager_op_function_generator.cc | 19 +++ python/paddle/fluid/dygraph/tracer.py | 85 ++++++++++- .../fluid/tests/unittests/CMakeLists.txt | 4 + .../paddle/fluid/tests/unittests/op_test.py | 140 ++++++++++++++++-- .../fluid/tests/unittests/op_test_xpu.py | 6 +- .../fluid/tests/unittests/test_diag_v2.py | 2 +- .../fluid/tests/unittests/test_digamma_op.py | 4 +- .../tests/unittests/test_eager_trace_op.py | 50 +++++++ .../fluid/tests/unittests/test_trunc_op.py | 4 +- 11 files changed, 347 insertions(+), 27 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_eager_trace_op.py diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 3ffe02e2bc0ed..b79b69356b3ac 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -37,6 +37,8 @@ std::unordered_map> core_ops_returns_info = {}; std::unordered_map> core_ops_args_info = {}; +std::unordered_map> + core_ops_args_type_info = {}; /* --- Static maps to handle corner cases --- */ static std::unordered_map @@ -1225,10 +1227,16 @@ static std::pair GenerateForwardFunctionContents( */ VLOG(6) << "Generating Dygraph Forward Function"; - std::string 
generated_function_body = ""; + const char* FORWARD_FUNCTION_TEMPLATE = + " VLOG(3) << \"Running Eager Forward Op: %s\";\n"; + std::string generated_function_body = + paddle::string::Sprintf(FORWARD_FUNCTION_TEMPLATE, op_type); + std::string dygraph_function_args_str = ""; core_ops_args_info[op_type] = {}; + core_ops_args_type_info[op_type] = {}; core_ops_args_info[op_type].resize(in_vars.size()); + core_ops_args_type_info[op_type].resize(in_vars.size()); /* ------ Dygraph forward function generation ------ */ generated_function_body += " // Dygraph Forward Pass\n"; @@ -1246,10 +1254,14 @@ static std::pair GenerateForwardFunctionContents( "const std::vector& %s"; input_args_str_list[input_position] = paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); + + core_ops_args_type_info[op_type][input_position] = "list"; } else { const char* FWD_INS_ARG_TEMPLATE = "const egr::EagerTensor& %s"; input_args_str_list[input_position] = paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); + + core_ops_args_type_info[op_type][input_position] = "tensor"; } core_ops_args_info[op_type][input_position] = input_name; @@ -1318,11 +1330,14 @@ static std::pair GenerateForwardFunctionContents( paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); dygraph_function_args_str += arg_str; + core_ops_args_type_info[op_type].push_back("list"); } else { const char* FWD_NUM_ARG_TEMPLATE = ", egr::EagerTensor* %s"; std::string arg_str = paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name); dygraph_function_args_str += arg_str; + + core_ops_args_type_info[op_type].push_back("tensor"); } const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", egr::EagerUtils::TrySyncToVars(%s) },"; @@ -1344,6 +1359,7 @@ static std::pair GenerateForwardFunctionContents( outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name, outnum); core_ops_args_info[op_type].push_back(outnum); + core_ops_args_type_info[op_type].push_back("int"); } else { const char* 
FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", " @@ -1811,6 +1827,11 @@ static std::string GenerateGradNodeCCContents( } */ + const char* EAGER_LOG_TEMPLATE = + " VLOG(3) << \"Running Eager Backward Node: GradNode%s\";\n"; + std::string generated_grad_function_body = + paddle::string::Sprintf(EAGER_LOG_TEMPLATE, fwd_op_type); + // This is a Copy auto op_base_infos = bwd_info.GetOpBaseInfos(); @@ -1829,7 +1850,6 @@ static std::string GenerateGradNodeCCContents( op_base_infos.emplace_back(std::move(op_base_info)); } - std::string generated_grad_function_body = ""; size_t outs_size = 0; for (size_t i = 0; i < op_base_infos.size(); i++) { const auto& op_base_info = op_base_infos[i]; @@ -2030,6 +2050,9 @@ static std::string GenerateDygraphHFileIncludes() { dygraph_forward_api_includes_str += "extern std::unordered_map> " "core_ops_args_info;\n"; + dygraph_forward_api_includes_str += + "extern std::unordered_map> " + "core_ops_args_type_info;\n"; dygraph_forward_api_includes_str += "extern std::unordered_map> " "core_ops_returns_info;\n\n"; @@ -2126,16 +2149,20 @@ static std::string GenerateCoreOpsReturnsInfo() { "std::unordered_map> " "core_ops_args_info = { %s };\n" "std::unordered_map> " + "core_ops_args_type_info = { %s };\n" + "std::unordered_map> " "core_ops_returns_info = { %s };\n"; std::string core_ops_args_info_init_str = ConvertCoreOpsInfosToString(core_ops_args_info); + std::string core_ops_args_type_info_init_str = + ConvertCoreOpsInfosToString(core_ops_args_type_info); std::string core_ops_returns_info_init_str = ConvertCoreOpsInfosToString(core_ops_returns_info); std::string core_ops_info_str = paddle::string::Sprintf( Core_Ops_Returns_MAP_TEMPLATE, core_ops_args_info_init_str, - core_ops_returns_info_init_str); + core_ops_args_type_info_init_str, core_ops_returns_info_init_str); return core_ops_info_str; } diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index bb3464665411c..4835d8873af19 100644 --- 
a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -121,6 +121,30 @@ static PyObject* eager_tensor_method__copy_to(EagerTensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } +static PyObject* eager_tensor_method_reconstruct_from_(EagerTensorObject* self, + PyObject* args, + PyObject* kwargs) { + EAGER_SYNC_TRY + egr::EagerTensor src_tensor = + CastPyArg2EagerTensor(PyTuple_GET_ITEM(args, 0), 0); + bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1); + std::string orig_name = self->eager_tensor.name(); + VLOG(6) << "Start Reconstructing Tensor from" << src_tensor.name() << " to " + << orig_name; + self->eager_tensor.copy_(src_tensor, blocking); + // Steal Tensor from src tensor + self->eager_tensor.set_tensor(src_tensor.Tensor()); + + // Recover source name + self->eager_tensor.set_name(orig_name); + + VLOG(6) << "Finished Reconstructing Tensor from" << src_tensor.name() + << " to " << self->eager_tensor.name(); + Py_INCREF(Py_None); + return Py_None; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + static PyObject* eager_tensor_method_copy_(EagerTensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_SYNC_TRY @@ -387,6 +411,9 @@ PyMethodDef variable_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"copy_", (PyCFunction)(void (*)(void))eager_tensor_method_copy_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"reconstruct_from_", + (PyCFunction)(void (*)(void))eager_tensor_method_reconstruct_from_, + METH_VARARGS | METH_KEYWORDS, NULL}, {"retain_grads", (PyCFunction)(void (*)(void))eager_tensor_retain_grads, METH_VARARGS | METH_KEYWORDS, NULL}, {"_clear_gradient", diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index 3d0a4d0de75bd..090604ab4ee1a 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -313,6 +313,21 @@ static std::string GenerateCoreOpsInfoMap() { " }\n" "}\n" "\n" + 
"static PyObject * eager_get_core_ops_args_type_info(PyObject *self) {\n" + " PyThreadState *tstate = nullptr;\n" + " try\n" + " {\n" + " return ToPyObject(core_ops_args_type_info);\n" + " }\n" + " catch(...) {\n" + " if (tstate) {\n" + " PyEval_RestoreThread(tstate);\n" + " }\n" + " ThrowExceptionToPython(std::current_exception());\n" + " return nullptr;\n" + " }\n" + "}\n" + "\n" "static PyObject * eager_get_core_ops_returns_info(PyObject *self) {\n" " PyThreadState *tstate = nullptr;\n" " try\n" @@ -399,6 +414,10 @@ int main(int argc, char* argv[]) { "{\"get_core_ops_args_info\", " "(PyCFunction)(void(*)(void))eager_get_core_ops_args_info, METH_NOARGS, " "\"C++ interface function for eager_get_core_ops_args_info.\"},\n" + "{\"get_core_ops_args_type_info\", " + "(PyCFunction)(void(*)(void))eager_get_core_ops_args_type_info, " + "METH_NOARGS, " + "\"C++ interface function for eager_get_core_ops_args_type_info.\"},\n" " {\"get_core_ops_returns_info\", " "(PyCFunction)(void(*)(void))eager_get_core_ops_returns_info, " "METH_NOARGS, \"C++ interface function for " diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 2ecb0998dd355..f31edf142b22d 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -19,6 +19,7 @@ from collections import defaultdict from paddle.fluid import core from paddle.fluid import framework +from paddle import _C_ops class Tracer(core.Tracer): @@ -46,9 +47,87 @@ def trace_op(self, attrs, stop_gradient=False, inplace_map=None): - self.trace(type, inputs, outputs, attrs, - framework._current_expected_place(), self._has_grad and - not stop_gradient, inplace_map if inplace_map else {}) + if framework._in_eager_mode(): + # inputs : {"sum": [tensor], ...} + # outputs : {"sum": [tensor], ...} + + function_ptr = _C_ops.__dict__[type] + + core_ops_args_info = _C_ops.get_core_ops_args_info() + core_ops_args_type_info = _C_ops.get_core_ops_args_type_info() + 
core_ops_returns_info = _C_ops.get_core_ops_returns_info() + + op_args = core_ops_args_info[type] + op_args_type = core_ops_args_type_info[type] + op_returns = core_ops_returns_info[type] + + arg_list = [] + for i in range(len(op_args)): + arg_name = op_args[i] + arg_type = op_args_type[i] + if arg_name in inputs.keys(): + arg_to_append = inputs[arg_name] + elif arg_name in outputs.keys(): + arg_to_append = outputs[arg_name] + else: + if "Num" in arg_name: + # Remove "Num" suffix to get out_name + out_name = arg_name[:-3] + assert out_name in outputs.keys() + num_outs = len(outputs[out_name]) + arg_to_append = num_outs + else: + arg_to_append = None + + if arg_to_append is None: + arg_list.append(arg_to_append) + elif arg_type == "tensor": + if isinstance(arg_to_append, list): + arg_list.append(arg_to_append[0]) + else: + arg_list.append(arg_to_append) + elif arg_type == "list": + assert isinstance(arg_to_append, list) + arg_list.append(arg_to_append) + else: + assert arg_type == "int" + assert isinstance(arg_to_append, int) + arg_list.append(arg_to_append) + + attrs_list = [] + for k, v in attrs.items(): + attrs_list.append(k) + attrs_list.append(v) + returns = function_ptr(*arg_list, *attrs_list) + + if isinstance(returns, tuple): + for i in range(len(op_returns)): + retname = op_returns[i] + if retname in outputs.keys(): + # Replaced outputs by function returns + if isinstance(returns[i], list): + for j in range(len(returns[i])): + outputs[retname][j].reconstruct_from_( + returns[i][j], False) + else: + outputs[retname][0].reconstruct_from_(returns[i], + False) + elif isinstance(returns, list): + assert len(outputs.keys()) == 1 + key = list(outputs.keys())[0] + for j in range(len(returns)): + outputs[key][j].reconstruct_from_(returns[j], False) + else: + assert len(outputs.keys()) == 1 + key = list(outputs.keys())[0] + if isinstance(outputs[key], list): + outputs[key][0].reconstruct_from_(returns, False) + else: + outputs[key].reconstruct_from_(returns, False) + 
else: + self.trace(type, inputs, outputs, attrs, + framework._current_expected_place(), self._has_grad and + not stop_gradient, inplace_map if inplace_map else {}) def train_mode(self): self._train_mode = True diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 5c57d1a21bce6..2ac5e9404c1ba 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -104,6 +104,10 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() +if(ON_INFER) + LIST(REMOVE_ITEM TEST_OPS test_eager_trace_op) +endif() + if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 01d851469a8d1..e05acdd6b42cc 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -25,10 +25,12 @@ import itertools import collections from collections import defaultdict +from copy import copy import paddle import paddle.fluid as fluid import paddle.fluid.core as core +from paddle.fluid.framework import _test_eager_guard from paddle.fluid.backward import append_backward from paddle.fluid.op import Operator from paddle.fluid.executor import Executor @@ -495,7 +497,7 @@ def _append_ops(self, block): type=self.op_type, inputs=inputs, outputs=outputs, - attrs=self.attrs if hasattr(self, "attrs") else dict()) + attrs=copy(self.attrs) if hasattr(self, "attrs") else dict()) # infer variable type and infer shape in compile-time op.desc.infer_var_type(block.desc) op.desc.infer_shape(block.desc) @@ -1111,7 +1113,8 @@ def check_output_with_place(self, no_check_set=None, equal_nan=False, check_dygraph=True, - inplace_atol=None): + inplace_atol=None, + check_eager=False): self.infer_dtype_from_inputs_outputs(self.inputs, 
self.outputs) if self.dtype == np.float64 and \ self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST: @@ -1120,6 +1123,7 @@ def check_output_with_place(self, if self.is_bfloat16_op(): if self.is_mkldnn_op(): check_dygraph = False + check_eager = False if hasattr(self, 'force_fp32_output') and getattr( self, 'force_fp32_output'): atol = 1e-2 @@ -1136,6 +1140,10 @@ def check_output_with_place(self, if check_dygraph: dygraph_outs = self._calc_dygraph_output( place, no_check_set=no_check_set) + if check_eager: + with _test_eager_guard(): + eager_dygraph_outs = self._calc_dygraph_output( + place, no_check_set=no_check_set) outs, fetch_list = self._calc_output(place, no_check_set=no_check_set) for out_name, out_dup in Operator.get_op_outputs(self.op_type): @@ -1178,6 +1186,13 @@ def find_actual(target_name, fetch_list): sub_out_name, dygraph_outs, place) imperative_actual_t = np.array(imperative_actual.value() .get_tensor()) + if check_eager: + with _test_eager_guard(): + eager_imperative_actual = find_imperative_actual( + sub_out_name, eager_dygraph_outs, place) + eager_imperative_actual_t = eager_imperative_actual.numpy( + ) + idx = find_actual(sub_out_name, fetch_list) actual = outs[idx] actual_t = np.array(actual) @@ -1197,6 +1212,16 @@ def find_actual(target_name, fetch_list): equal_nan=equal_nan), "Output (" + sub_out_name + ") has diff at " + str(place) + " in dygraph mode") + if check_eager: + with _test_eager_guard(): + self.assertTrue( + np.allclose( + eager_imperative_actual_t, + expect_t, + atol=atol, + equal_nan=equal_nan), + "Output (" + sub_out_name + ") has diff at " + + str(place) + " in eager dygraph mode") if isinstance(expect, tuple): self.assertListEqual( actual.recursive_sequence_lengths(), expect[1], @@ -1209,12 +1234,27 @@ def find_actual(target_name, fetch_list): "Output (" + out_name + ") has different lod at " + str(place) + " in dygraph mode") + if check_eager: + with _test_eager_guard(): + 
self.assertListEqual( + eager_imperative_actual.value().get_tensor() + .recursive_sequence_lengths(), expect[1], + "Output (" + out_name + + ") has different lod at " + str(place) + + " in eager dygraph mode") else: if check_dygraph: imperative_actual = find_imperative_actual( out_name, dygraph_outs, place) imperative_actual_t = np.array(imperative_actual.value() .get_tensor()) + if check_eager: + with _test_eager_guard(): + eager_imperative_actual = find_imperative_actual( + out_name, eager_dygraph_outs, place) + eager_imperative_actual_t = eager_imperative_actual.numpy( + ) + idx = find_actual(out_name, fetch_list) actual = outs[idx] actual_t = np.array(actual) @@ -1275,6 +1315,32 @@ def find_actual(target_name, fetch_list): str(place) + "\nExpect " + str(expect_t) + "\n" + "But Got" + str(imperative_actual_t) + " in class " + self.__class__.__name__) + if check_eager: + with _test_eager_guard(): + if self.is_bfloat16_op(): + if eager_imperative_actual_t.dtype == np.uint16: + eager_imperative_actual_t = convert_uint16_to_float( + eager_imperative_actual_t) + if expect_t.dtype == np.uint16: + expect_t = convert_uint16_to_float(expect_t) + if six.moves.reduce(lambda x, y: x * y, + eager_imperative_actual_t.shape, + 1) == 0 and six.moves.reduce( + lambda x, y: x * y, + expect_t.shape, 1) == 0: + pass + else: + self.assertTrue( + np.allclose( + eager_imperative_actual_t, + expect_t, + atol=atol, + rtol=rtol, + equal_nan=equal_nan), + "Output (" + out_name + ") has diff at " + + str(place) + "\nExpect " + str(expect_t) + "\n" + + "But Got" + str(eager_imperative_actual_t) + + " in class " + self.__class__.__name__) if isinstance(expect, tuple): self.assertListEqual(actual.recursive_sequence_lengths(), expect[1], "Output (" + out_name + @@ -1284,7 +1350,15 @@ def find_actual(target_name, fetch_list): imperative_actual.value().get_tensor() .recursive_sequence_lengths(), expect[1], "Output (" + out_name + ") has different lod at " + - str(place) + " in dygraph mode") + 
str(place) + " in eager dygraph mode") + if check_eager: + with _test_eager_guard(): + self.assertListEqual( + eager_imperative_actual.value().get_tensor() + .recursive_sequence_lengths(), expect[1], + "Output (" + out_name + + ") has different lod at " + str(place) + + " in eager dygraph mode") # Note(zhiqiu): inplace_atol should be only set when op doesn't ensure # computational consistency. @@ -1306,7 +1380,9 @@ def find_actual(target_name, fetch_list): self.check_inplace_output_with_place( place, no_check_set=no_check_set, inplace_atol=inplace_atol) - if check_dygraph: + if check_eager: + return outs, dygraph_outs, eager_dygraph_outs, fetch_list + elif check_dygraph: return outs, dygraph_outs, fetch_list else: return outs, fetch_list @@ -1377,7 +1453,8 @@ def check_output(self, no_check_set=None, equal_nan=False, check_dygraph=True, - inplace_atol=None): + inplace_atol=None, + check_eager=False): self.__class__.op_type = self.op_type if self.is_mkldnn_op(): self.__class__.use_mkldnn = True @@ -1387,10 +1464,18 @@ def check_output(self, places = self._get_places() for place in places: - res = self.check_output_with_place(place, atol, no_check_set, - equal_nan, check_dygraph, - inplace_atol) - if check_dygraph: + res = self.check_output_with_place( + place, + atol, + no_check_set, + equal_nan, + check_dygraph, + inplace_atol, + check_eager=check_eager) + if check_eager: + assert check_dygraph == True + outs, dygraph_outs, eager_dygraph_outs, fetch_list = res + elif check_dygraph: outs, dygraph_outs, fetch_list = res else: outs, fetch_list = res @@ -1461,14 +1546,23 @@ def check_grad(self, max_relative_error=0.005, user_defined_grads=None, user_defined_grad_outputs=None, - check_dygraph=True): + check_dygraph=True, + check_eager=False): self._check_grad_helper() places = self._get_places() for place in places: self.check_grad_with_place( - place, inputs_to_check, output_names, no_grad_set, - numeric_grad_delta, in_place, max_relative_error, - user_defined_grads, 
user_defined_grad_outputs, check_dygraph) + place, + inputs_to_check, + output_names, + no_grad_set, + numeric_grad_delta, + in_place, + max_relative_error, + user_defined_grads, + user_defined_grad_outputs, + check_dygraph, + check_eager=check_eager) def check_grad_with_place(self, place, @@ -1481,7 +1575,8 @@ def check_grad_with_place(self, user_defined_grads=None, user_defined_grad_outputs=None, check_dygraph=True, - numeric_place=None): + numeric_place=None, + check_eager=False): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict() @@ -1490,6 +1585,7 @@ def check_grad_with_place(self, self._check_grad_helper() if self.is_bfloat16_op() and self.is_mkldnn_op(): check_dygraph = False + check_eager = False if self.dtype == np.float64 and \ self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST: @@ -1592,6 +1688,22 @@ def check_grad_with_place(self, max_relative_error, "Gradient Check On %s" % str(place)) + if check_eager: + with _test_eager_guard(): + eager_dygraph_grad = self._get_dygraph_grad( + inputs_to_check, place, output_names, + user_defined_grad_outputs, no_grad_set) + fp32_grads = [] + for grad in eager_dygraph_grad: + if grad.dtype == np.uint16: + grad = convert_uint16_to_float(grad) + max_relative_error = 0.03 if max_relative_error < 0.03 else max_relative_error + fp32_grads.append(grad) + eager_dygraph_grad = fp32_grads + self._assert_is_close(numeric_grads, eager_dygraph_grad, + inputs_to_check, max_relative_error, + "Gradient Check On %s" % str(place)) + def _find_var_in_dygraph(self, output_vars, name): if name in output_vars: return output_vars[name] diff --git a/python/paddle/fluid/tests/unittests/op_test_xpu.py b/python/paddle/fluid/tests/unittests/op_test_xpu.py index 187d78ba04aee..e77e3956c34b5 100644 --- a/python/paddle/fluid/tests/unittests/op_test_xpu.py +++ 
b/python/paddle/fluid/tests/unittests/op_test_xpu.py @@ -78,7 +78,8 @@ def check_output_with_place(self, no_check_set=None, equal_nan=False, check_dygraph=True, - inplace_atol=None): + inplace_atol=None, + check_eager=False): self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) #xpu not support float64 if self.dtype == np.float64: @@ -105,7 +106,8 @@ def check_grad_with_place(self, user_defined_grads=None, user_defined_grad_outputs=None, check_dygraph=True, - numeric_place=None): + numeric_place=None, + check_eager=False): if place == None: place = paddle.XPUPlace(0) diff --git a/python/paddle/fluid/tests/unittests/test_diag_v2.py b/python/paddle/fluid/tests/unittests/test_diag_v2.py index c364fb0a19335..1478cd888c47c 100644 --- a/python/paddle/fluid/tests/unittests/test_diag_v2.py +++ b/python/paddle/fluid/tests/unittests/test_diag_v2.py @@ -41,7 +41,7 @@ def setUp(self): def test_check_output(self): paddle.enable_static() - self.check_output() + self.check_output(check_eager=True) def init_config(self): pass diff --git a/python/paddle/fluid/tests/unittests/test_digamma_op.py b/python/paddle/fluid/tests/unittests/test_digamma_op.py index 86f59af19346c..503094779a3ae 100644 --- a/python/paddle/fluid/tests/unittests/test_digamma_op.py +++ b/python/paddle/fluid/tests/unittests/test_digamma_op.py @@ -40,10 +40,10 @@ def init_dtype_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad_normal(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestDigammaOpFp32(TestDigammaOp): diff --git a/python/paddle/fluid/tests/unittests/test_eager_trace_op.py b/python/paddle/fluid/tests/unittests/test_eager_trace_op.py new file mode 100644 index 0000000000000..79de17fdb6641 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_trace_op.py @@ -0,0 +1,50 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid.core as core +from paddle import _C_ops +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard +from paddle.fluid.framework import _test_eager_guard + + +class TestEagerTraceOp(unittest.TestCase): + def test_branches(self): + with _test_eager_guard(): + data = np.random.random([1, 1]).astype(np.float32) + x = paddle.to_tensor(data) + + paddle.fluid.framework._dygraph_tracer().trace_op( + 'broadcast_tensors', {'X': [x, x], + 'Out': [x, x]}, {'Out': [x, x]}, {}) + paddle.fluid.framework._dygraph_tracer().trace_op( + 'scale', {'X': x}, {'Out': x}, {'scale': 0.5}) + + scale = paddle.to_tensor(np.random.random([1]).astype(np.float32)) + paddle.fluid.framework._dygraph_tracer().trace_op( + 'instance_norm', {'Scale': [scale], + 'X': [x]}, {'Y': [x]}, {}) + paddle.fluid.framework._dygraph_tracer().trace_op( + 'coalesce_tensor', {'Input': [x]}, {'Output': [x]}, + {'dtype': int(core.VarDesc.VarType.FP32)}) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_trunc_op.py b/python/paddle/fluid/tests/unittests/test_trunc_op.py index 51844071138c7..b4482b402ea96 100644 --- a/python/paddle/fluid/tests/unittests/test_trunc_op.py +++ b/python/paddle/fluid/tests/unittests/test_trunc_op.py @@ -37,10 +37,10 @@ def 
init_dtype_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5) + self.check_grad(['X'], 'Out', numeric_grad_delta=1e-5, check_eager=True) class TestFloatTruncOp(TestTruncOp): From e106901ecab756876a066e0af1304e0b82cab0c3 Mon Sep 17 00:00:00 2001 From: helen88 Date: Mon, 24 Jan 2022 10:24:16 +0800 Subject: [PATCH 14/14] support sparse of adam, *test=kunlun (#38483) * support sparse of adam, *test=kunlun * add pre-commit-config.yaml * support sparse of adam in KL2,*test=kunlun * support sparse of adam in KL2, *test=kunlun * modify xpu.cmake, *test=kunlun * support sparse of adam, rm some wait, *test=kunlun * support sparse of adam, rm some wait, *test=kunlun * support sparse of adam, *test=kunlun * support sparse of adam, *test=kunlun * support sparse of adam, *test=kunlun * support sparse of adam, *test=kunlun * support sparse of adam, *test=kunlun --- cmake/external/xpu.cmake | 2 +- .../operators/math/selected_rows_functor.cc | 153 ++++++++++++++++++ .../fluid/operators/optimizers/adam_op_xpu.cc | 127 ++++++++++++++- .../tests/unittests/xpu/test_adam_op_xpu.py | 138 ++++++++++++++++ 4 files changed, 411 insertions(+), 9 deletions(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index c7a6f04b5f40a..578fb1621603f 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -36,7 +36,7 @@ ENDIF() if(NOT DEFINED XPU_BASE_URL) SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220104") + SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220116") else() SET(XPU_BASE_URL "${XPU_BASE_URL}") endif() diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 67176f26b079f..f6178eb0a1eb6 100644 --- 
a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -477,6 +477,155 @@ struct MergeAdd { } }; +#ifdef PADDLE_WITH_XPU +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::XPUDeviceContext& context, + const framework::SelectedRows& input, + const bool sorted_result = false) { + framework::SelectedRows out; + (*this)(context, input, &out, sorted_result); + return out; + } + + void operator()(const platform::XPUDeviceContext& context, + const framework::SelectedRows& input, + framework::SelectedRows* output, + const bool sorted_result = false) { + framework::Vector input_rows(input.rows()); + if (input_rows.size() == 0) { + return; + } + + framework::SelectedRows& out = *output; + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + auto input_width = input.value().dims()[1]; + + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + int r = + xpu::constant(context.x_context(), out.mutable_value()->data(), + merge_rows.size() * input_width, static_cast(0.f)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU constant op return" + " wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); + + std::unordered_map rows_to_id; + for (size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; + } + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + int n = input_width; + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_to_id[input_rows[i]]; + auto r = xpu::add(context.x_context(), &input_data[i * input_width], + &out_data[out_i * input_width], + &out_data[out_i * input_width], n); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API return wrong value[%d 
%s], ", r, + XPUAPIErrorMsg[r])); + } + } + + void operator()(const platform::XPUDeviceContext& context, + const std::vector& inputs, + framework::SelectedRows* output, + const bool sorted_result = false) { + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; + } + const framework::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (in->rows().size() > 0) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); + framework::SelectedRows& out = *output; + std::set merged_row_set; + size_t row_num = 0; + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], + platform::errors::InvalidArgument( + "All inputs should have same " + "dimension except for the first one.")); + PADDLE_ENFORCE_EQ(input_height, input->height(), + platform::errors::InvalidArgument( + "All inputs should have same height.")); + row_num += input->rows().size(); + merged_row_set.insert(input->rows().begin(), input->rows().end()); + } + + std::vector merge_rows(merged_row_set.begin(), + merged_row_set.end()); + + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); + } + + out.set_rows(merge_rows); + out.set_height(input_height); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merged_row_set.size()), input_width}), + context.GetPlace()); + + int r = + xpu::constant(context.x_context(), out.mutable_value()->data(), + merge_rows.size() * input_width, static_cast(0.f)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU constant op return" + " wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); + + float* out_data = reinterpret_cast(out.mutable_value()->data()); + + std::unordered_map rows_to_id; + for 
(size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; + } + + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + auto& input_rows = input->rows(); + + int n = input_width; + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_to_id[input_rows[i]]; + auto r = xpu::add( + context.x_context(), input->value().data() + i * input_width, + &out_data[out_i * input_width], &out_data[out_i * input_width], n); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API return wrong value[%d %s], ", r, + XPUAPIErrorMsg[r])); + } + } + } +}; + +#endif template struct MergeAverage { framework::SelectedRows operator()(const platform::CPUDeviceContext& context, @@ -589,6 +738,10 @@ template struct MergeAdd; +#ifdef PADDLE_WITH_XPU +template struct MergeAdd; +#endif + template struct MergeAverage; template struct MergeAverage; template struct MergeAverage; diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc index 0a653c4011719..e462c20c7f51d 100644 --- a/paddle/fluid/operators/optimizers/adam_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/adam_op_xpu.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/optimizers/adam_op.h" #include "gflags/gflags.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" namespace paddle { namespace operators { @@ -155,6 +156,11 @@ class AdamOpXPUKernel : public framework::OpKernel { mom2_out.template mutable_data(ctx.GetPlace()), param_out.template mutable_data(ctx.GetPlace()), beta1, beta2, epsilon, param.numel()); + + xpu_wait(dev_ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_EQ( + r == xpu::Error_t::SUCCESS, true, + platform::errors::External("XPU API return wrong value[%d],", r)); if (!use_global_beta_pow) { // update in cpu and then copy to xpu if (beta1_pow.place() == platform::CPUPlace() && @@ -165,7 +171,6 @@ class AdamOpXPUKernel : public framework::OpKernel { const float* beta2_pow_p = beta2_pow.template data(); beta2_pow_out->mutable_data(platform::CPUPlace())[0] = beta2 * beta2_pow_p[0]; - xpu_wait(dev_ctx.x_context()->xpu_stream); } else { float* beta1_pow_out_p = beta1_pow_out->mutable_data(ctx.GetPlace()); @@ -177,23 +182,129 @@ class AdamOpXPUKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( r, xpu::SUCCESS, platform::errors::External( - "XPU kernel scale occur error in adamw error code ", r, + "XPU kernel scale occur error in adam error code ", r, XPUAPIErrorMsg[r])); r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p, beta2_pow.numel(), false, beta2, 0.0f); PADDLE_ENFORCE_EQ( r, xpu::SUCCESS, platform::errors::External( - "XPU kernel scale occur error in adamw error code ", r, + "XPU kernel scale occur error in adam error code ", r, XPUAPIErrorMsg[r])); + + xpu_wait(dev_ctx.x_context()->xpu_stream); + } + } + } else if (grad_var->IsType()) { + auto* grad = ctx.Input("Grad"); + auto& dev_ctx = ctx.template device_context(); + + if (grad->rows().size() == 0) { + VLOG(3) << "grad row size is 0!!"; + return; + } + + std::vector cpu_rows(grad->rows().begin(), grad->rows().end()); + bool is_strict_sorted = true; + for (size_t i = 1; i < 
cpu_rows.size(); ++i) { + if (cpu_rows[i - 1] >= cpu_rows[i]) { + is_strict_sorted = false; + break; } + } + + framework::SelectedRows tmp_grad_merge; + const framework::SelectedRows* grad_merge_ptr; + if (is_strict_sorted) { + grad_merge_ptr = grad; + } else { + scatter::MergeAdd merge_func; + merge_func(ctx.template device_context(), + *grad, &tmp_grad_merge, true); + + xpu_wait(dev_ctx.x_context()->xpu_stream); + grad_merge_ptr = &tmp_grad_merge; + } + const T* beta1_pow_ptr = beta1_pow.template data(); + const T* beta2_pow_ptr = beta2_pow.template data(); + Tensor xpu_beta1_pow; + Tensor xpu_beta2_pow; + if (beta1_pow.place() == platform::CPUPlace() && + beta2_pow.place() == platform::CPUPlace()) { + paddle::framework::TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, + &xpu_beta1_pow); + paddle::framework::TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, + &xpu_beta2_pow); + dev_ctx.Wait(); + beta1_pow_ptr = xpu_beta1_pow.template data(); + beta2_pow_ptr = xpu_beta2_pow.template data(); + } + auto& grad_merge = *grad_merge_ptr; + auto& grad_tensor = grad_merge.value(); + const T* grad_data = grad_tensor.template data(); + int row_count = grad_merge.rows().size(); + std::vector rows(row_count); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* xpu_rows = RAII_GUARD.alloc_l3_or_gm(row_count); + std::vector merge_rows(grad_merge.rows().begin(), + grad_merge.rows().end()); + for (size_t i = 0; i < grad_merge.rows().size(); ++i) { + rows[i] = static_cast(merge_rows[i]); + } + xpu_wait(dev_ctx.x_context()->xpu_stream); + memory::Copy(ctx.GetPlace(), xpu_rows, platform::CPUPlace(), rows.data(), + row_count * sizeof(int)); + auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); + auto ori_rows = param.numel() / row_numel; - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External( - "XPU API return wrong value[%d], please check " - "where Baidu Kunlun Card is properly installed.", - r)); + int lazy_mode = 
static_cast(ctx.Attr("lazy_mode")); + int r = xpu::sparse_adam( + dev_ctx.x_context(), grad_data, mom1.template data(), + mom2.template data(), param.template data(), beta1_pow_ptr, + beta2_pow_ptr, lr.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2_out.template mutable_data(ctx.GetPlace()), + param_out.template mutable_data(ctx.GetPlace()), beta1, beta2, + epsilon, ori_rows, xpu_rows, row_numel, grad_merge.rows().size(), + lazy_mode); + + PADDLE_ENFORCE_EQ( + r == xpu::Error_t::SUCCESS, true, + platform::errors::External("XPU API return wrong value[%d],", r)); + + if (!use_global_beta_pow) { + // update in cpu and then copy to xpu + if (beta1_pow.place() == platform::CPUPlace() && + beta2_pow.place() == platform::CPUPlace()) { + const float* beta1_pow_p = beta1_pow.template data(); + beta1_pow_out->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow_p[0]; + const float* beta2_pow_p = beta2_pow.template data(); + beta2_pow_out->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow_p[0]; + } else { + float* beta1_pow_out_p = + beta1_pow_out->mutable_data(ctx.GetPlace()); + float* beta2_pow_out_p = + beta2_pow_out->mutable_data(ctx.GetPlace()); + int r = + xpu::scale(dev_ctx.x_context(), beta1_pow_ptr, beta1_pow_out_p, + beta1_pow.numel(), false, beta1, 0.0f); + PADDLE_ENFORCE_EQ( + r, xpu::SUCCESS, + platform::errors::External( + "XPU kernel scale occur error in adam error code ", r, + XPUAPIErrorMsg[r])); + r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p, + beta2_pow.numel(), false, beta2, 0.0f); + PADDLE_ENFORCE_EQ( + r, xpu::SUCCESS, + platform::errors::External( + "XPU kernel scale occur error in adam error code ", r, + XPUAPIErrorMsg[r])); + } } + xpu_wait(dev_ctx.x_context()->xpu_stream); } else { PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument( "Variable type not supported by adam_op")); diff --git a/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py 
b/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py index 147824f341be4..a36c0bf071332 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_adam_op_xpu.py @@ -216,6 +216,144 @@ def adam_step(inputs, attributes): return param_out, moment1_out, moment2_out +def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad, + lazy_mode): + ''' + Simulate one step of the adam optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + # grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment1_out = np.zeros(shape=[height, row_numel]) + moment2_out = np.zeros(shape=[height, row_numel]) + param_out = np.zeros(shape=[height, row_numel]) + + def update_row(row_id, update_value): + moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 + ) * update_value + moment2_out[row_id] = beta2 * moment2[row_id] + ( + 1 - beta2) * np.square(update_value) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) + param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / ( + np.sqrt(moment2_out[row_id]) + epsilon)) + + if lazy_mode: + for idx, row_id in enumerate(rows): + update_row(row_id, np_grad[idx]) + else: + for row_id in range(param_out.shape[0]): + update_value = np.zeros(np_grad[0].shape).astype("float32") + if row_id in rows: + update_value = np_grad[rows.index(row_id)] + update_row(row_id, update_value) + + return param_out, moment1_out, moment2_out + + +class TestSparseAdamOp(unittest.TestCase): + def setup(self, scope, place, lazy_mode): + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + 
beta1_pow = np.array([beta1**10]).astype("float32") + beta2_pow = np.array([beta2**10]).astype("float32") + + height = 10 + rows = [0, 4, 7] + self.rows = rows + row_numel = 12 + self.row_numel = row_numel + self.dense_inputs = { + "Param": np.full((height, row_numel), 5.0).astype("float32"), + "Moment1": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + 'Beta1Pow': beta1_pow, + 'Beta2Pow': beta2_pow, + "LearningRate": np.full((1), 2.0).astype("float32") + } + self.init_output = np.full((height, row_numel), 0.0).astype("float32") + self.attrs = { + 'epsilon': epsilon, + 'beta1': beta1, + 'beta2': beta2, + 'min_row_size_to_use_multithread': 2 + } + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + self.sparse_inputs = ["Grad"] + + param_out, mom1, mom2 = adam_step_sparse(self.dense_inputs, self.attrs, + height, rows, row_numel, + np_array, lazy_mode) + self.outputs = { + "ParamOut": param_out, + "Moment1Out": mom1, + "Moment2Out": mom2, + 'Beta1PowOut': beta1_pow * beta1, + 'Beta2PowOut': beta2_pow * beta2 + } + + def check_with_place(self, place, lazy_mode): + scope = core.Scope() + self.setup(scope, place, lazy_mode) + + op_args = dict() + op_args['lazy_mode'] = lazy_mode + for key, np_array in self.dense_inputs.items(): + var = scope.var(key).get_tensor() + var.set(np_array, place) + op_args[key] = key + for s in self.sparse_inputs: + op_args[s] = s + for s in self.outputs: + var = scope.var(s).get_tensor() + var.set(self.init_output, place) + op_args[s] = s + for k in self.attrs: + op_args[k] = self.attrs[k] + + # create and run adam operator + adam_op = Operator("adam", **op_args) + adam_op.run(scope, place) + + 
for key, np_array in self.outputs.items(): + out_var = scope.var(key).get_tensor() + actual = np.array(out_var) + actual = actual.reshape([actual.size]) + np_array = np_array.reshape([np_array.size]) + + for i in range(np_array.size): + self.assertLess((actual[i] - np_array[i]), 0.00001) + + def test_sparse_adam(self): + xpu_version = core.get_xpu_device_version(0) + version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1" + if "xpu2" == version_str: + self.check_with_place(paddle.XPUPlace(0), False) + + class TestAdamOpBetaVariable(OpTest): def setUp(self): '''Test Adam Op with beta as Variable