From 7c6acc6b8798adfc15bd34ee3390553ae7409efb Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 21 Dec 2021 16:22:46 +0000 Subject: [PATCH 1/2] add empty and empty_like kernel in pten --- paddle/fluid/operators/empty_op.cc | 14 ++ paddle/pten/CMakeLists.txt | 4 +- paddle/pten/api/include/kernel_signature.h | 5 + paddle/pten/api/lib/kernel_declare.h | 6 +- paddle/pten/include/creation.h | 2 +- paddle/pten/infermeta/nary.cc | 12 +- paddle/pten/infermeta/nary.h | 12 +- paddle/pten/infermeta/unary.cc | 6 +- paddle/pten/infermeta/unary.h | 6 +- paddle/pten/kernels/cpu/CMakeLists.txt | 1 + paddle/pten/kernels/cpu/empty_kernel.cc | 41 ++++++ paddle/pten/kernels/cuda/CMakeLists.txt | 2 + paddle/pten/kernels/cuda/empty_kernel.cu | 41 ++++++ paddle/pten/kernels/empty_kernel.h | 28 ++++ paddle/pten/kernels/full_kernel.h | 1 - paddle/pten/kernels/impl/empty_kernel_impl.h | 34 +++++ paddle/pten/tests/api/CMakeLists.txt | 1 + paddle/pten/tests/api/test_empty_api.cc | 127 +++++++++++++++++++ python/paddle/utils/code_gen/api.yaml | 30 ++++- python/paddle/utils/code_gen/api_gen.py | 5 +- 20 files changed, 349 insertions(+), 29 deletions(-) create mode 100644 paddle/pten/kernels/cpu/empty_kernel.cc create mode 100644 paddle/pten/kernels/cuda/empty_kernel.cu create mode 100644 paddle/pten/kernels/empty_kernel.h create mode 100644 paddle/pten/kernels/impl/empty_kernel_impl.h create mode 100644 paddle/pten/tests/api/test_empty_api.cc diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc index 3d28ca90a5a15f..71780971560172 100644 --- a/paddle/fluid/operators/empty_op.cc +++ b/paddle/fluid/operators/empty_op.cc @@ -109,6 +109,20 @@ class EmptyOp : public framework::OperatorWithKernel { framework::proto::VarType::Type(context.Attr("dtype")), context.GetPlace()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext& ctx) const override { + std::string shape; + if (ctx.HasInput("ShapeTensor")) { + shape = "ShapeTensor"; + } else if (ctx.MultiInput("ShapeTensorList").size()) { + shape = "ShapeTensorList"; + } else { + shape = "shape"; + } + + return framework::KernelSignature("empty", {}, {shape}, {"Out"}); + } }; class EmptyOpVarTypeInference : public framework::VarTypeInference { diff --git a/paddle/pten/CMakeLists.txt b/paddle/pten/CMakeLists.txt index eb9a149dd6da4f..ee4708b024d546 100644 --- a/paddle/pten/CMakeLists.txt +++ b/paddle/pten/CMakeLists.txt @@ -24,10 +24,10 @@ add_subdirectory(tests) # make an unity target for compile deps set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context) -set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu manipulation_cpu conj_kernel_cpu scale_kernel_cpu full_kernel_cpu) +set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu manipulation_cpu conj_kernel_cpu scale_kernel_cpu full_kernel_cpu empty_kernel_cpu) set(PTEN_DEPS ${PTEN_DEPS} nary unary binary) if(WITH_GPU OR WITH_ROCM) - set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda manipulation_cuda conj_kernel_cuda scale_kernel_cuda full_kernel_cuda) + set(PTEN_DEPS ${PTEN_DEPS} math_cuda linalg_cuda manipulation_cuda conj_kernel_cuda scale_kernel_cuda full_kernel_cuda empty_kernel_cuda) endif() if(WITH_XPU) set(PTEN_DEPS ${PTEN_DEPS} manipulation_xpu) diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index ebae064c336897..1c120edb9ab7bc 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -50,6 +50,11 @@ using dot_kernel = void (*)(const DeviceContext&, using flatten_kernel = void (*)(const DeviceContext&, const DenseTensor&, int, int, DenseTensor*); +using empty_kernel = void (*)(const DeviceContext&, + const ScalarArray&, + DenseTensor*); + +using empty_like_kernel = void (*)(const DeviceContext&, DenseTensor*); using full_kernel = void (*)(const DeviceContext&, const ScalarArray&, const Scalar&, diff --git a/paddle/pten/api/lib/kernel_declare.h b/paddle/pten/api/lib/kernel_declare.h index e748a51082c52f..3beb77b6f0f140 100644 --- a/paddle/pten/api/lib/kernel_declare.h +++ b/paddle/pten/api/lib/kernel_declare.h @@ -20,7 +20,8 @@ limitations under the License. */ // the kernel declare statement is automatically generated according to the // file name of the kernel, and this header file will be removed -PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(empty, CPU, ALL_LAYOUT); +PT_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT); @@ -28,7 +29,8 @@ PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(conj, CPU, ALL_LAYOUT); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT); +PT_DECLARE_KERNEL(empty, CUDA, ALL_LAYOUT); +PT_DECLARE_KERNEL(full, CUDA, ALL_LAYOUT); PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT); PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT); PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT); diff --git a/paddle/pten/include/creation.h b/paddle/pten/include/creation.h index d685d262ebc1c9..c5decb5fc5bd2b 100644 --- a/paddle/pten/include/creation.h +++ b/paddle/pten/include/creation.h @@ -30,7 +30,7 @@ DenseTensor FullLike( DataType dtype = DataType::UNDEFINED, Backend backend = Backend::UNDEFINED, // Is backend needed here? DataLayout layout = DataLayout::UNDEFINED) { - auto out_meta = FullLikeInferMeta(x.meta(), dtype, layout); + auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout); pten::DenseTensor dense_out( pten::make_intrusive( dev_ctx.GetPlace()), diff --git a/paddle/pten/infermeta/nary.cc b/paddle/pten/infermeta/nary.cc index 8b12a88f10fc01..5287c5cca1439c 100644 --- a/paddle/pten/infermeta/nary.cc +++ b/paddle/pten/infermeta/nary.cc @@ -17,16 +17,16 @@ limitations under the License. */ namespace pten { -DenseTensorMeta FullInferMeta(const std::vector& shape, - DataType dtype, - DataLayout layout) { +DenseTensorMeta CreateInferMeta(const std::vector& shape, + DataType dtype, + DataLayout layout) { const auto& out_dims = paddle::framework::make_ddim(shape); return {dtype, out_dims, layout}; } -DenseTensorMeta FullInferMeta(const ScalarArray& shape, - DataType dtype, - DataLayout layout) { +DenseTensorMeta CreateInferMeta(const ScalarArray& shape, + DataType dtype, + DataLayout layout) { const auto& out_dims = paddle::framework::make_ddim(shape.GetData()); return {dtype, out_dims, layout}; } diff --git a/paddle/pten/infermeta/nary.h b/paddle/pten/infermeta/nary.h index 010accd2e79e54..721a39bb3ac31a 100644 --- a/paddle/pten/infermeta/nary.h +++ b/paddle/pten/infermeta/nary.h @@ -27,12 +27,12 @@ namespace pten { // Because functions in this file // not only can infer shape, but alse need infer lod or other useful data. -DenseTensorMeta FullInferMeta(const std::vector& shape, - DataType dtype, - DataLayout layout); +DenseTensorMeta CreateInferMeta(const std::vector& shape, + DataType dtype, + DataLayout layout); -DenseTensorMeta FullInferMeta(const ScalarArray& shape, - DataType dtype, - DataLayout layout); +DenseTensorMeta CreateInferMeta(const ScalarArray& shape, + DataType dtype, + DataLayout layout); } // namespace pten diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 49d4a24e3a2c46..843a78f3413cf0 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -81,9 +81,9 @@ DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta, return out_meta; } -DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta, - DataType dtype, - DataLayout layout) { +DenseTensorMeta CreateLikeInferMeta(const DenseTensorMeta& x_meta, + DataType dtype, + DataLayout layout) { return {dtype == DataType::UNDEFINED ? x_meta.dtype : dtype, x_meta.dims, layout == DataLayout::UNDEFINED ? x_meta.layout : layout}; diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 3f28b2b48530ff..ae42cbd5dd2c6d 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -44,9 +44,9 @@ DenseTensorMeta FlattenInferMeta(const DenseTensorMeta& x_meta, DenseTensorMeta CastInferMeta(const DenseTensorMeta& x_meta, const DataType out_dtype); -DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta, - DataType dtype, - DataLayout layout); +DenseTensorMeta CreateLikeInferMeta(const DenseTensorMeta& x_meta, + DataType dtype, + DataLayout layout); DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta, const std::vector& shape); diff --git a/paddle/pten/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt index 036ce68ee43c1e..814fd2446c7436 100644 --- a/paddle/pten/kernels/cpu/CMakeLists.txt +++ b/paddle/pten/kernels/cpu/CMakeLists.txt @@ -5,3 +5,4 @@ cc_library(manipulation_cpu SRCS manipulation.cc DEPS dense_tensor kernel_contex cc_library(scale_kernel_cpu SRCS scale_kernel.cc DEPS dense_tensor kernel_context kernel_factory eigen_function) cc_library(full_kernel_cpu SRCS full_kernel.cc DEPS dense_tensor kernel_context kernel_factory eigen_function) cc_library(conj_kernel_cpu SRCS conj_kernel.cc DEPS dense_tensor kernel_context kernel_factory) +cc_library(empty_kernel_cpu SRCS empty_kernel.cc DEPS dense_tensor kernel_context kernel_factory) diff --git a/paddle/pten/kernels/cpu/empty_kernel.cc b/paddle/pten/kernels/cpu/empty_kernel.cc new file mode 100644 index 00000000000000..654b0fd214a9de --- /dev/null +++ b/paddle/pten/kernels/cpu/empty_kernel.cc @@ -0,0 +1,41 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/empty_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/empty_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(empty, + CPU, + ALL_LAYOUT, + pten::Empty, + bool, + int, + int64_t, + float, + double, + paddle::platform::float16) {} + +PT_REGISTER_CTX_KERNEL(empty_like, + CPU, + ALL_LAYOUT, + pten::EmptyLike, + bool, + int, + int64_t, + float, + double, + paddle::platform::float16) {} diff --git a/paddle/pten/kernels/cuda/CMakeLists.txt b/paddle/pten/kernels/cuda/CMakeLists.txt index 428b2762ca790e..50a3f1a9b30480 100644 --- a/paddle/pten/kernels/cuda/CMakeLists.txt +++ b/paddle/pten/kernels/cuda/CMakeLists.txt @@ -6,6 +6,7 @@ if(WITH_GPU) nv_library(scale_kernel_cuda SRCS scale_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) nv_library(full_kernel_cuda SRCS full_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) nv_library(conj_kernel_cuda SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) + nv_library(empty_kernel_cuda SRCS empty_kernel.cu DEPS dense_tensor kernel_context kernel_factory) elseif(WITH_ROCM) hip_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_cuda) hip_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) @@ -14,4 +15,5 @@ elseif(WITH_ROCM) hip_library(scale_kernel_cuda SRCS scale_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) hip_library(full_kernel_cuda SRCS full_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) hip_library(conj_kernel_cuda SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) + hip_library(empty_kernel_cuda SRCS empty_kernel.cu DEPS dense_tensor kernel_context kernel_factory) endif() diff --git a/paddle/pten/kernels/cuda/empty_kernel.cu b/paddle/pten/kernels/cuda/empty_kernel.cu new file mode 100644 index 00000000000000..592da1d0e3df83 --- /dev/null +++ b/paddle/pten/kernels/cuda/empty_kernel.cu @@ -0,0 +1,41 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/kernels/empty_kernel.h" + +#include "paddle/pten/backends/cuda/cuda_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/empty_kernel_impl.h" + +PT_REGISTER_CTX_KERNEL(empty, + CUDA, + ALL_LAYOUT, + pten::Empty, + bool, + int, + int64_t, + float, + double, + paddle::platform::float16) {} + +PT_REGISTER_CTX_KERNEL(empty_like, + CUDA, + ALL_LAYOUT, + pten::EmptyLike, + bool, + int, + int64_t, + float, + double, + paddle::platform::float16) {} diff --git a/paddle/pten/kernels/empty_kernel.h b/paddle/pten/kernels/empty_kernel.h new file mode 100644 index 00000000000000..7aa5a27765a198 --- /dev/null +++ b/paddle/pten/kernels/empty_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void Empty(const ContextT& dev_ctx, const ScalarArray& shape, DenseTensor* out); + +template +void EmptyLike(const ContextT& dev_ctx, DenseTensor* out); + +} // namespace pten diff --git a/paddle/pten/kernels/full_kernel.h b/paddle/pten/kernels/full_kernel.h index f8abb9436679b1..d1139cf9ecefeb 100644 --- a/paddle/pten/kernels/full_kernel.h +++ b/paddle/pten/kernels/full_kernel.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" diff --git a/paddle/pten/kernels/impl/empty_kernel_impl.h b/paddle/pten/kernels/impl/empty_kernel_impl.h new file mode 100644 index 00000000000000..1010f8dd9601fd --- /dev/null +++ b/paddle/pten/kernels/impl/empty_kernel_impl.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void Empty(const ContextT& dev_ctx, + const ScalarArray& shape, + DenseTensor* out) { + out->Resize(paddle::framework::make_ddim(shape.GetData())); +} + +template +void EmptyLike(const ContextT& dev_ctx, DenseTensor* out) { + out->mutable_data(); +} + +} // namespace pten diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index e85eb4c3294f19..bb1eab2c095518 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -12,6 +12,7 @@ cc_test(test_framework_place_utils storage SRCS test_place_utils.cc DEPS pten_ap cc_test(test_mean_api SRCS test_mean_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_dot_api SRCS test_dot_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS pten_tensor pten_api pten_api_utils) +cc_test(test_empty_api SRCS test_empty_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_fill_api SRCS test_fill_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_empty_api.cc b/paddle/pten/tests/api/test_empty_api.cc new file mode 100644 index 00000000000000..fcc01ad8a71720 --- /dev/null +++ b/paddle/pten/tests/api/test_empty_api.cc @@ -0,0 +1,127 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/api/include/api.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace paddle { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +// TODO(chenweihang): Remove this test after the API is used in the dygraph +TEST(API, empty_like) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 2}), + pten::DataLayout::NCHW)); + + paddle::experimental::Tensor x(dense_x); + + // 2. test API + auto out = paddle::experimental::empty_like(x, pten::DataType::FLOAT32); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); +} + +TEST(API, empty1) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + auto dense_shape = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::INT64, + framework::make_ddim({2}), + pten::DataLayout::NCHW)); + auto* shape_data = dense_shape->mutable_data(); + shape_data[0] = 2; + shape_data[1] = 3; + + paddle::experimental::Tensor tensor_shape(dense_shape); + + // 2. test API + auto out = paddle::experimental::empty(tensor_shape, pten::DataType::FLOAT32); + + // 3. check result + ASSERT_EQ(out.shape().size(), 2UL); + ASSERT_EQ(out.shape()[0], 2); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); +} + +TEST(API, empty2) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + auto dense_scalar = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::INT32, + framework::make_ddim({1}), + pten::DataLayout::NCHW)); + dense_scalar->mutable_data()[0] = 2; + + paddle::experimental::Tensor shape_scalar1(dense_scalar); + paddle::experimental::Tensor shape_scalar2(dense_scalar); + std::vector list_shape{shape_scalar1, + shape_scalar2}; + + auto out = paddle::experimental::empty(list_shape, pten::DataType::FLOAT32); + + ASSERT_EQ(out.shape().size(), 2UL); + ASSERT_EQ(out.shape()[0], 2); + ASSERT_EQ(out.numel(), 4); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); +} + +TEST(API, empty3) { + std::vector vector_shape{2, 3}; + + auto out = paddle::experimental::empty(vector_shape, pten::DataType::INT32); + + ASSERT_EQ(out.shape().size(), 2UL); + ASSERT_EQ(out.shape()[0], 2); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::INT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); +} + +} // namespace tests +} // namespace paddle diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 0c410d9b66fe99..ca7a8f79664950 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -36,6 +36,32 @@ kernel : func : dot +- api : empty + args : (const ScalarArray& shape, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU, DataLayout layout=DataLayout::NCHW) + output: Tensor + infer_meta : + func : CreateInferMeta + param : [shape, dtype, layout] + kernel : + func : empty + param : [shape] + data_type : dtype + backend : place + layout : layout + +- api : empty_like + args : (const Tensor& x, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED, DataLayout layout = DataLayout::UNDEFINED) + output: Tensor + infer_meta : + func : CreateLikeInferMeta + param : [x, dtype, layout] + kernel : + func : empty_like + param : [] + data_type : dtype > x + backend : place > x + layout : layout > x + - api : flatten args : (const Tensor& x, int start_axis, int stop_axis) output : Tensor @@ -48,7 +74,7 @@ args : (const ScalarArray& shape, const Scalar& value, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU, DataLayout layout=DataLayout::NCHW) output: Tensor infer_meta : - func : FullInferMeta + func : CreateInferMeta param : [shape, dtype, layout] kernel : func : full @@ -61,7 +87,7 @@ args : (const Tensor& x, const Scalar& value, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED, DataLayout layout = DataLayout::UNDEFINED) output: Tensor infer_meta : - func : FullLikeInferMeta + func : CreateLikeInferMeta param : [x, dtype, layout] kernel : func : full_like diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 029985475011ea..97f3c774f3f4d3 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -43,12 +43,11 @@ def __init__(self, api_item_yaml): if 'data_type' not in self.kernel or len(self.kernel[ 'data_type']) == 0: self.kernel['data_type'] = None - if 'param' not in self.kernel or len(self.kernel['param']) == 0: + if 'param' not in self.kernel: self.kernel['param'] = None self.infer_meta = api_item_yaml['infer_meta'] - if 'param' not in self.infer_meta or len(self.infer_meta[ - 'param']) == 0: + if 'param' not in self.infer_meta: self.infer_meta['param'] = None def parse_args(self, args_str): From 3852fff5663df6328840e82ca4ad98aedcd86650 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 22 Dec 2021 12:10:53 +0000 Subject: [PATCH 2/2] add empty dev_api --- paddle/pten/include/creation.h | 48 ++++++ paddle/pten/kernels/cpu/CMakeLists.txt | 1 - paddle/pten/kernels/cuda/CMakeLists.txt | 19 --- paddle/pten/kernels/{cpu => }/empty_kernel.cc | 43 +++++- paddle/pten/kernels/gpu/CMakeLists.txt | 2 - paddle/pten/kernels/gpu/empty_kernel.cu | 41 ----- paddle/pten/kernels/impl/empty_kernel_impl.h | 34 ----- paddle/pten/tests/kernels/CMakeLists.txt | 2 +- .../tests/kernels/test_creation_dev_api.cc | 141 ++++++++++++++++++ .../pten/tests/kernels/test_fill_dev_api.cc | 66 -------- 10 files changed, 231 insertions(+), 166 deletions(-) delete mode 100644 paddle/pten/kernels/cuda/CMakeLists.txt rename paddle/pten/kernels/{cpu => }/empty_kernel.cc (53%) delete mode 100644 paddle/pten/kernels/gpu/empty_kernel.cu delete mode 100644 paddle/pten/kernels/impl/empty_kernel_impl.h create mode 100644 paddle/pten/tests/kernels/test_creation_dev_api.cc delete mode 100644 paddle/pten/tests/kernels/test_fill_dev_api.cc diff --git a/paddle/pten/include/creation.h b/paddle/pten/include/creation.h index c5decb5fc5bd2b..73c5999ca9247a 100644 --- a/paddle/pten/include/creation.h +++ b/paddle/pten/include/creation.h @@ -16,12 +16,60 @@ #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/include/infermeta.h" +#include "paddle/pten/kernels/empty_kernel.h" #include "paddle/pten/kernels/full_kernel.h" namespace pten { // TODO(YuanRisheng) This function name should be same as User API name. // TODO(zyfncg) Automatic code generation +template +DenseTensor Empty(const ContextT& dev_ctx, + const ScalarArray& shape, + DataType dtype = DataType::FLOAT32, + Backend backend = Backend::CPU, // Is backend needed here? + DataLayout layout = DataLayout::NCHW) { + auto out_meta = CreateInferMeta(shape, dtype, layout); + pten::DenseTensor dense_out( + pten::make_intrusive( + dev_ctx.GetPlace()), + std::move(out_meta)); + Empty(dev_ctx, shape, &dense_out); + return dense_out; +} + +template +DenseTensor EmptyLike( + const ContextT& dev_ctx, + const DenseTensor& x, + DataType dtype = DataType::UNDEFINED, + Backend backend = Backend::UNDEFINED, // Is backend needed here? + DataLayout layout = DataLayout::UNDEFINED) { + auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout); + pten::DenseTensor dense_out( + pten::make_intrusive( + dev_ctx.GetPlace()), + std::move(out_meta)); + EmptyLike(dev_ctx, &dense_out); + return dense_out; +} + +template +DenseTensor Full(const ContextT& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DataType dtype = DataType::FLOAT32, + Backend backend = Backend::CPU, // Is backend needed here? + DataLayout layout = DataLayout::NCHW) { + auto out_meta = CreateInferMeta(shape, dtype, layout); + pten::DenseTensor dense_out( + pten::make_intrusive( + dev_ctx.GetPlace()), + std::move(out_meta)); + Full(dev_ctx, shape, val, &dense_out); + return dense_out; +} + template DenseTensor FullLike( const ContextT& dev_ctx, diff --git a/paddle/pten/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt index 7afe1caed8543f..7a32fab2674c34 100644 --- a/paddle/pten/kernels/cpu/CMakeLists.txt +++ b/paddle/pten/kernels/cpu/CMakeLists.txt @@ -3,4 +3,3 @@ cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_fac cc_library(utils_cpu SRCS utils.cc DEPS dense_tensor kernel_context kernel_factory memory convert_utils) cc_library(manipulation_cpu SRCS manipulation.cc DEPS dense_tensor kernel_context kernel_factory utils_cpu unary) cc_library(conj_kernel_cpu SRCS conj_kernel.cc DEPS dense_tensor kernel_context kernel_factory) -cc_library(empty_kernel_cpu SRCS empty_kernel.cc DEPS dense_tensor kernel_context kernel_factory) diff --git a/paddle/pten/kernels/cuda/CMakeLists.txt b/paddle/pten/kernels/cuda/CMakeLists.txt deleted file mode 100644 index 50a3f1a9b30480..00000000000000 --- a/paddle/pten/kernels/cuda/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -if(WITH_GPU) - nv_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_cuda) - nv_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) - nv_library(utils_cuda SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils) - nv_library(manipulation_cuda SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_cuda unary) - nv_library(scale_kernel_cuda SRCS scale_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) - nv_library(full_kernel_cuda SRCS full_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) - nv_library(conj_kernel_cuda SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) - nv_library(empty_kernel_cuda SRCS empty_kernel.cu DEPS dense_tensor kernel_context kernel_factory) -elseif(WITH_ROCM) - hip_library(math_cuda SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_cuda) - hip_library(linalg_cuda SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) - hip_library(utils_cuda SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils) - hip_library(manipulation_cuda SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_cuda unary) - hip_library(scale_kernel_cuda SRCS scale_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) - hip_library(full_kernel_cuda SRCS full_kernel.cu DEPS dense_tensor kernel_context kernel_factory eigen_function) - hip_library(conj_kernel_cuda SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) - hip_library(empty_kernel_cuda SRCS empty_kernel.cu DEPS dense_tensor kernel_context kernel_factory) -endif() diff --git a/paddle/pten/kernels/cpu/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc similarity index 53% rename from paddle/pten/kernels/cpu/empty_kernel.cc rename to paddle/pten/kernels/empty_kernel.cc index 654b0fd214a9de..4c6d8706e0ff3e 100644 --- a/paddle/pten/kernels/cpu/empty_kernel.cc +++ b/paddle/pten/kernels/empty_kernel.cc @@ -14,9 +14,24 @@ limitations under the License. */ #include "paddle/pten/kernels/empty_kernel.h" -#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/backends/all_context.h" #include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/impl/empty_kernel_impl.h" + +namespace pten { + +template +void Empty(const ContextT& dev_ctx, + const ScalarArray& shape, + DenseTensor* out) { + out->Resize(paddle::framework::make_ddim(shape.GetData())); +} + +template +void EmptyLike(const ContextT& dev_ctx, DenseTensor* out) { + out->mutable_data(); +} + +} // namespace pten PT_REGISTER_CTX_KERNEL(empty, CPU, @@ -39,3 +54,27 @@ PT_REGISTER_CTX_KERNEL(empty_like, float, double, paddle::platform::float16) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_CTX_KERNEL(empty, + GPU, + ALL_LAYOUT, + pten::Empty, + bool, + int, + int64_t, + float, + double, + paddle::platform::float16) {} + +PT_REGISTER_CTX_KERNEL(empty_like, + GPU, + ALL_LAYOUT, + pten::EmptyLike, + bool, + int, + int64_t, + float, + double, + paddle::platform::float16) {} +#endif diff --git a/paddle/pten/kernels/gpu/CMakeLists.txt b/paddle/pten/kernels/gpu/CMakeLists.txt index d7d8a3d717674d..a0646e1cb78792 100644 --- a/paddle/pten/kernels/gpu/CMakeLists.txt +++ b/paddle/pten/kernels/gpu/CMakeLists.txt @@ -4,12 +4,10 @@ if(WITH_GPU) nv_library(utils_gpu SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils) nv_library(manipulation_gpu SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_gpu unary) nv_library(conj_kernel_gpu SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) - nv_library(empty_kernel_gpu SRCS empty_kernel.cu DEPS dense_tensor kernel_context kernel_factory) elseif(WITH_ROCM) hip_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu) hip_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) hip_library(utils_gpu SRCS utils.cu DEPS dense_tensor kernel_context kernel_factory memory convert_utils) hip_library(manipulation_gpu SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory utils_gpu unary) hip_library(conj_kernel_gpu SRCS conj_kernel.cu DEPS dense_tensor kernel_context kernel_factory) - hip_library(empty_kernel_gpu SRCS empty_kernel.cu DEPS dense_tensor kernel_context kernel_factory) endif() diff --git a/paddle/pten/kernels/gpu/empty_kernel.cu b/paddle/pten/kernels/gpu/empty_kernel.cu deleted file mode 100644 index 3c22207db992ad..00000000000000 --- a/paddle/pten/kernels/gpu/empty_kernel.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/pten/kernels/empty_kernel.h" - -#include "paddle/pten/backends/gpu/gpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/impl/empty_kernel_impl.h" - -PT_REGISTER_CTX_KERNEL(empty, - GPU, - ALL_LAYOUT, - pten::Empty, - bool, - int, - int64_t, - float, - double, - paddle::platform::float16) {} - -PT_REGISTER_CTX_KERNEL(empty_like, - GPU, - ALL_LAYOUT, - pten::EmptyLike, - bool, - int, - int64_t, - float, - double, - paddle::platform::float16) {} diff --git a/paddle/pten/kernels/impl/empty_kernel_impl.h b/paddle/pten/kernels/impl/empty_kernel_impl.h deleted file mode 100644 index 1010f8dd9601fd..00000000000000 --- a/paddle/pten/kernels/impl/empty_kernel_impl.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/pten/common/scalar_array.h" -#include "paddle/pten/core/dense_tensor.h" - -namespace pten { - -template -void Empty(const ContextT& dev_ctx, - const ScalarArray& shape, - DenseTensor* out) { - out->Resize(paddle::framework::make_ddim(shape.GetData())); -} - -template -void EmptyLike(const ContextT& dev_ctx, DenseTensor* out) { - out->mutable_data(); -} - -} // namespace pten diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index 3a626aad2deb5d..5f14554a0cceeb 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -1,6 +1,6 @@ cc_test(test_copy_dev_api SRCS test_copy_dev_api.cc DEPS pten pten_api_utils) cc_test(test_dot_dev_api SRCS test_dot_dev_api.cc DEPS pten pten_api_utils) -cc_test(test_fill_dev_api SRCS test_fill_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_creation_dev_api SRCS test_creation_dev_api.cc DEPS pten pten_api_utils) cc_test(test_flatten_dev_api SRCS test_flatten_dev_api.cc DEPS pten pten_api_utils) cc_test(test_mean_dev_api SRCS test_mean_dev_api.cc DEPS pten pten_api_utils) cc_test(test_scale_dev_api SRCS test_scale_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_creation_dev_api.cc b/paddle/pten/tests/kernels/test_creation_dev_api.cc new file mode 100644 index 00000000000000..8469b94b797c87 --- /dev/null +++ b/paddle/pten/tests/kernels/test_creation_dev_api.cc @@ -0,0 +1,141 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/include/creation.h" + +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { +namespace tests { + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +TEST(DEV_API, empty) { + // 1. create input + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto out = pten::Empty( + *(static_cast(dev_ctx)), + {3, 2}, + pten::DataType::INT32); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.meta().dtype, pten::DataType::INT32); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); +} + +TEST(DEV_API, empty_like) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 2}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(); + dense_x_data[0] = 0; + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto out = pten::EmptyLike( + *(static_cast(dev_ctx)), dense_x); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); +} + +TEST(DEV_API, full) { + // 1. create input + float val = 1.0; + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto out = pten::Full( + *(static_cast(dev_ctx)), + {3, 2}, + val, + pten::DataType::FLOAT32); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); + + auto* actual_result = out.data(); + for (auto i = 0; i < 6; i++) { + ASSERT_NEAR(actual_result[i], val, 1e-6f); + } +} + +TEST(DEV_API, full_like) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::DenseTensor dense_x(alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 2}), + pten::DataLayout::NCHW)); + auto* dense_x_data = dense_x.mutable_data(); + dense_x_data[0] = 0; + float val = 1.0; + + paddle::platform::DeviceContextPool& pool = + paddle::platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); + + // 2. test API + auto out = pten::FullLike( + *(static_cast(dev_ctx)), + dense_x, + val); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 3); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); + + auto* actual_result = out.data(); + for (auto i = 0; i < 6; i++) { + ASSERT_NEAR(actual_result[i], val, 1e-6f); + } +} + +} // namespace tests +} // namespace pten diff --git a/paddle/pten/tests/kernels/test_fill_dev_api.cc b/paddle/pten/tests/kernels/test_fill_dev_api.cc deleted file mode 100644 index 9a8b1f94e731b5..00000000000000 --- a/paddle/pten/tests/kernels/test_fill_dev_api.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include - -#include "paddle/pten/include/creation.h" - -#include "paddle/pten/api/lib/utils/allocator.h" -#include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/core/kernel_registry.h" - -namespace pten { -namespace tests { - -namespace framework = paddle::framework; -using DDim = paddle::framework::DDim; - -TEST(DEV_API, fill_any_like) { - // 1. create tensor - const auto alloc = std::make_shared( - paddle::platform::CPUPlace()); - pten::DenseTensor dense_x(alloc, - pten::DenseTensorMeta(pten::DataType::FLOAT32, - framework::make_ddim({3, 2}), - pten::DataLayout::NCHW)); - auto* dense_x_data = dense_x.mutable_data(); - dense_x_data[0] = 0; - float val = 1.0; - - paddle::platform::DeviceContextPool& pool = - paddle::platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(paddle::platform::CPUPlace()); - - // 2. test API - auto out = pten::FullLike( - *(static_cast(dev_ctx)), - dense_x, - val); - - // 3. check result - ASSERT_EQ(out.dims().size(), 2); - ASSERT_EQ(out.dims()[0], 3); - ASSERT_EQ(out.numel(), 6); - ASSERT_EQ(out.meta().dtype, pten::DataType::FLOAT32); - ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW); - - auto* actual_result = out.data(); - for (auto i = 0; i < 6; i++) { - ASSERT_NEAR(actual_result[i], val, 1e-6f); - } -} - -} // namespace tests -} // namespace pten