From e11ecfce86376857f2f1624a1a0866d538700c72 Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Tue, 2 Nov 2021 11:31:10 +0800
Subject: [PATCH 1/6] Add matmul_v2 kernel in pten (#36844)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* initial tensor design & sign kernel demo
* add move constructor for meta & add lodtensor
* add dirs & sign xpu kernel
* add mean cpu&cuda kernel impl
* move sign & mean xpu & npu kernel
* add selected_rows basic impl
* refactor design, BaseTensor to DenseTensor, etc.
* add scale mkldnn kernel
* polish xpu & npu impl details
* fix mkldnn reuse compile failure
* change tensor operation lib name
* rename util filename
* add more comments
* change TensorImplInterface to TensorInterface
* add kernel key and factory
* remove MKLDNNTensorMeta, add MKLDNNDenseTensor
* change XXDeviceContext to XXContext
* add base kernel registrar utils & test on sign
* replace boost::any by paddle::any
* fix several CI failures
* fix npu compile error
* add ordered map util
* fix multiple ordered_map compile errors
* move dev into include dir
* support sign op in static op run
* fix static op run error
* fix new executor compile failure
* add dygraph branch & remove sign_op.h
* fix test_infer_no_need_buffer_slots
* fix rocm compile link error
* fix unitybuild error & clear glog
* fix npu compile failure
* skip quant trans test
* fix part of windows compile problems
* fix xpu enforce error
* fix inference test failure
* remove ordered_map to solve quant failure
* fix part of rocm compile failures
* add more register kernels
* revert scale kernel temporarily
* fix code format error
* add new kernel registrar macro
* rename top to tcmpt
* revert xpu, npu, mkldnn impl & remove op def
* add kernel args parse functor to auto parse args
* revert some change & add scale kernels
* add op proto in dygraph kernelcontext building
* polish kernel dispatch logic & naming rule
* fix scale kernel match error
* fix scale test failure
* add mean API and unittest
* test mean api success
* add branch to solve compile error
* skip clang format error
* add mean skip rule in op_library
* add dot kernel, api and unittest (#6)
* remove old kernel and add symbol link
* fix dot compile failure
* add macro for module declaration
* fix npu and xpu compile error
* revert sign, mean, scale, dot kernel removing
* add comment for keeping old kernel impl
* fix mutable_data error
* fix bfloat16 conflict
* fix inference undef error
* adapt to msvc compile rules
* polish comment for template inst
* add cmake template instantiation for win
* fix backend to place device id bug
* fix ifdef error
* Op2functor (#7)
* add kernel args maker class
* make args maker non-const
* remove debug log
* modify codes by review options
* split constructPrKernelContext function
* fix output name bug
* fix test_mean_op and test_sign_op failures
* fill_any_like kernel refactor (#10)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* skip dtype for fill_any_like
* add attrs for kernel key construction
* add use_pt_kernel Flags to control whether to use pt kernel (#13)
* add use_pt_kernel Flags to control whether to use pt kernel
* change the default value to true for checking pt kernels
* fix mutable_data cuda place error
* move high level apis into hapi
* remove selectedrows adapting temporarily
* Support Scalar in Tensor Compute Library (#14)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* remove mkldnn tensor & polish details
* use flat_hash_map and small_vector in kernel factory
* Refactor flatten kernel (#12)
* refactor flatten kernel
* update infershape function
* fix compile bugs
* fix bugs when merging
* fix compiler bugs
* fix bugs when running test_flatten_api
* fix bugs when running tests
* Revert "use flat_hash_map and small_vector in kernel factory"
This reverts commit 23091495cfdd3df8cc1be592d30f09ea66a7c72b.
* Move cpu, cuda and other device code into kernels (#15)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Perfect unittests (#16)
* perfect unittest
* update license
* replace with flat_hash_map, small_vector (#19)
* fix small_vector build error on windows platform
* replace with flat_hash_map, small_vector
* remove todo
* Perfect unittests (#20)
* perfect unittest
* update license
* fix bug when running tcmpt_utils_test
* refactor execution adapting impl
* fix insert conflict
* Fix CI bug of test_yolov3 (#21)
* fill_any_like kernel refactor
* remove useless code of full_like c++ api
* Support Scalar in Tensor Compute Library
* add scalar in dygraph and static graph mode
* keep the basic type for attr, instead of using scalar for all
* merge the code
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Fix CI bug of test_yolov3
* add the tensor base class, test=develop (#17)
* update the tensor base class, test=develop
* remove two funcs, test=develop
* update the error msg, test=develop
Co-authored-by: Chen Weihang
* [no-verify] commit backend and tensor signature changes
* Rename tcmpt to pten (#23)
* rename tcmpt to pten
* update omitted files for rename to pten
* update omitted file for rename to pten
* remove k of all enum var
* remove kernel_instantiate (#26)
* remove symbols and spatial_tensor
* change common to functions
* re-add share tensor impl methods
* add a candidate dense tensor class, test=develop (#28)
* change all Pt to Pten
* resolve conflict with xiaowei
* Op2functor opt1 (#27)
* replace to small vector and change to const &
* add std::move
Co-authored-by: Chen Weihang
* polish kernel factory and kernel registry
* fix operator test error msg mismatch
* remove tensor signature and backend set member
* move scalar and polish enforce
* revert dtype layout change to fix error
* fix enum operator override error
* add several base unittests
* add pten utils tests
* polish some details
* Dev/op2func refactor 3 (#30)
* add a candidate dense tensor class, test=develop
* remove TensorBase::backend(), test=develop
* remove some ops, test=develop
* cherry-pick the pr of tensor meta, test=develop
* moves the dense tensor and some ops, test=develop
* update the linalg operator, test=develop
* update other operators, test=develop
* fix errors, test=develop
* fix bugs, test=develop
* try to resolve the problem of windows ci, test=develop
* updates codes, test=develop
* fix the tensor_utils.cc, test=develop
* modify the dense tensor, test=develop
* fix the data type, test=develop
Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
* polish some details
* polish kernel signature details
* fix a bug about offsets of the tensor, test=develop (#31) Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> * add matmul kernel in pten * add unittest for new matmul_v2 kernel * fix bug of CI compile * fix bug of CI compile * merge conflict * remove useless file Co-authored-by: Chen Weihang Co-authored-by: chentianyu03 Co-authored-by: YuanRisheng Co-authored-by: 石晓伟 <39303645+Shixiaowei02@users.noreply.github.com> --- paddle/fluid/framework/operator.cc | 1 + paddle/fluid/imperative/prepared_operator.cc | 1 + paddle/fluid/operators/matmul_v2_op.h | 25 +- paddle/pten/hapi/include/linalg.h | 5 + paddle/pten/hapi/lib/linalg.cc | 43 +- paddle/pten/infershape/binary.cc | 70 +++ paddle/pten/infershape/binary.h | 5 + paddle/pten/kernels/cpu/linalg.cc | 27 + paddle/pten/kernels/cpu/linalg.h | 2 +- paddle/pten/kernels/cuda/linalg.cu | 33 ++ paddle/pten/kernels/cuda/linalg.h | 8 + paddle/pten/kernels/cuda/utils.h | 5 + .../pten/kernels/functions/math/matmul_func.h | 491 ++++++++++++++++++ paddle/pten/tests/CMakeLists.txt | 1 + paddle/pten/tests/test_fill_api.cc | 1 - paddle/pten/tests/test_matmul_api.cc | 160 ++++++ 16 files changed, 866 insertions(+), 12 deletions(-) create mode 100644 paddle/pten/kernels/functions/math/matmul_func.h create mode 100644 paddle/pten/tests/test_matmul_api.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d317aac8594b4c..16e63e433e6403 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/op_call_stack.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index b2d55babc7e1c1..db26c66958140b 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" +#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/imperative/infer_shape_context.h" #include "paddle/pten/common/scalar.h" #include "paddle/utils/small_vector.h" diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index 2ba82243fe6b5b..1d361ae45dad9b 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -25,6 +25,11 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/include/core.h" +#include "paddle/pten/api/include/linalg.h" +#include "paddle/pten/hapi/lib/utils/tensor_utils.h" + #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #endif @@ -380,15 +385,17 @@ class MatMulV2Kernel : public framework::OpKernel { auto* Out = ctx.Output("Out"); bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); - PADDLE_ENFORCE_NE(framework::product(X->dims()), 0, - platform::errors::InvalidArgument( - "The Input(X) dims size must not be equal 0," - " but reviced dims size is 0. ")); - PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0, - platform::errors::InvalidArgument( - "The Input(Y) dims size must not be equal 0," - " but reviced dims size is 0. ")); - MatMulFunction(X, Y, Out, trans_x, trans_y, ctx); + + auto& dev_ctx = ctx.device_context(); + Out->mutable_data(X->place()); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(*X); + auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y); + auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out); + + // call new kernel + pten::Matmul(dev_ctx, *pt_x.get(), *pt_y.get(), trans_x, trans_y, + pt_out.get()); } }; diff --git a/paddle/pten/hapi/include/linalg.h b/paddle/pten/hapi/include/linalg.h index fd628ea19334e8..6e78b50af11c36 100644 --- a/paddle/pten/hapi/include/linalg.h +++ b/paddle/pten/hapi/include/linalg.h @@ -21,5 +21,10 @@ namespace experimental { Tensor dot(const Tensor& x, const Tensor& y); +Tensor matmul(const Tensor& x, + const Tensor& y, + bool transpose_x, + bool transpose_y); + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/hapi/lib/linalg.cc b/paddle/pten/hapi/lib/linalg.cc index 54829feb43a246..3f13f546ee25e0 100644 --- a/paddle/pten/hapi/lib/linalg.cc +++ b/paddle/pten/hapi/lib/linalg.cc @@ -25,7 +25,6 @@ limitations under the License. */ #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/hapi/lib/kernel_dispatch.h" #include "paddle/pten/hapi/lib/utils/allocator.h" -#include "paddle/pten/infershape/binary.h" namespace paddle { namespace experimental { @@ -65,5 +64,47 @@ Tensor dot(const Tensor& x, const Tensor& y) { return out; } +Tensor matmul(const Tensor& x, + const Tensor& y, + bool transpose_x, + bool transpose_y) { + // 1. Get kernel signature and kernel + auto kernel_key_set = ParseKernelKeyByInputArgs(x, y); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + "matmul_v2", kernel_key); + + // 2. Get Device Context + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto kernel_context = pten::KernelContext(*dev_ctx); + + // 3. Auto data transform + auto dense_x = std::dynamic_pointer_cast(x.impl()); + auto dense_y = std::dynamic_pointer_cast(y.impl()); + kernel_context.EmplaceBackInput(dense_x); + kernel_context.EmplaceBackInput(dense_y); + kernel_context.EmplaceBackAttr(transpose_x); + kernel_context.EmplaceBackAttr(transpose_y); + // TODO(chenweihang): add transform impl + + // 4. InferShape + auto out_meta = MatmulInferShape( + dense_x->meta(), dense_y->meta(), transpose_x, transpose_y); + + // 5. 
Prepare outputs + const auto allocator = std::make_shared<DefaultAllocator>( + pten::TransToFluidPlace(kernel_key.backend())); + auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta); + kernel_context.EmplaceBackOutput(dense_out); + + Tensor out; + out.set_impl(dense_out); + + // 6. Call kernel + kernel(&kernel_context); + + return out; +} + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/infershape/binary.cc b/paddle/pten/infershape/binary.cc index c2b88c74d847e3..c17e0871581834 100644 --- a/paddle/pten/infershape/binary.cc +++ b/paddle/pten/infershape/binary.cc @@ -59,4 +59,74 @@ DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta, return return_meta; } +DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta, + const DenseTensorMeta& y_meta, + bool trans_x, + bool trans_y) { + std::vector<int64_t> dims_x = paddle::framework::vectorize(x_meta.dims); + std::vector<int64_t> dims_y = paddle::framework::vectorize(y_meta.dims); + auto ndims_x = dims_x.size(); + auto ndims_y = dims_y.size(); + PADDLE_ENFORCE_GT(ndims_x, + 0, + paddle::platform::errors::InvalidArgument( + "The Input(x) dims size must be greater than 0," + " but received dims size is 0. ")); + PADDLE_ENFORCE_GT(ndims_y, + 0, + paddle::platform::errors::InvalidArgument( + "The Input(y) dims size must be greater than 0," + " but received dims size is 0. ")); + + bool x_broadcasted = false, y_broadcasted = false; + if (ndims_x == 1) { + dims_x.insert(dims_x.begin(), 1); + ndims_x = 2; + x_broadcasted = true; + } + + if (ndims_y == 1) { + dims_y.push_back(1); + ndims_y = 2; + y_broadcasted = true; + } + + size_t M, N; + if (trans_x) { + M = dims_x[ndims_x - 1]; + } else { + M = dims_x[ndims_x - 2]; + } + if (trans_y) { + N = dims_y[ndims_y - 2]; + } else { + N = dims_y[ndims_y - 1]; + } + + std::vector<int64_t> new_dims; + if (ndims_x > ndims_y) { + new_dims.assign(dims_x.begin(), dims_x.end() - 2); + } else if (ndims_x < ndims_y) { + new_dims.assign(dims_y.begin(), dims_y.end() - 2); + } else { + new_dims.reserve(ndims_x); + for (size_t i = 0; i < ndims_x - 2; ++i) { + new_dims.push_back(std::max(dims_x[i], dims_y[i])); + } + } + if (!x_broadcasted) { + new_dims.push_back(M); + } + if (!y_broadcasted) { + new_dims.push_back(N); + } + if (x_broadcasted && y_broadcasted) { + new_dims.push_back(1); + } + + auto ddim_out = paddle::framework::make_ddim(new_dims); + + return {x_meta.type, ddim_out, x_meta.layout}; +} + } // namespace pten diff --git a/paddle/pten/infershape/binary.h b/paddle/pten/infershape/binary.h index 613d2f66a6edd4..f58e5503f22a17 100644 --- a/paddle/pten/infershape/binary.h +++ b/paddle/pten/infershape/binary.h @@ -36,4 +36,9 @@ namespace pten { DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta, const DenseTensorMeta& y_meta); +DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta, + const DenseTensorMeta& y_meta, + bool trans_x, + bool trans_y); + } // namespace pten diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc index df401370c881ff..ced13dc41d1ae1 100644 --- a/paddle/pten/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -21,6 +21,8 @@ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/complex.h" +#include "paddle/pten/kernels/functions/math/matmul_func.h" + namespace pten { template <typename T> void Dot(const CPUContext& dev_ctx, } } +template <typename T> +void Matmul(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) {
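+  // Empty inputs are rejected up front: product(x.dims()) == 0 means some
+  // dimension of x is zero, and the BLAS paths inside math::MatMulFunction
+  // assume non-empty operands, so the checks below fail early with an
+  // explicit error instead.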
PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but received dims size is 0. ")); + PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but received dims size is 0. ")); + math::MatMulFunction<CPUContext, T>( + dev_ctx, x, y, out, transpose_x, transpose_y); +} + } // namespace pten PT_REGISTER_MODULE(LinalgCPU); @@ -62,3 +85,7 @@ PT_REGISTER_KERNEL("dot", int64_t, complex64, complex128) {} + +PT_REGISTER_KERNEL( + "matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) { +} diff --git a/paddle/pten/kernels/cpu/linalg.h b/paddle/pten/kernels/cpu/linalg.h index a9447be74934c7..a954033866f177 100644 --- a/paddle/pten/kernels/cpu/linalg.h +++ b/paddle/pten/kernels/cpu/linalg.h @@ -30,7 +30,7 @@ void Dot(const CPUContext& dev_ctx, DenseTensor* out); template <typename T> -void matmul(const CPUContext& dev_ctx, +void Matmul(const CPUContext& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, diff --git a/paddle/pten/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu index 928a09a4edbfff..6811afa8a49ff5 100644 --- a/paddle/pten/kernels/cuda/linalg.cu +++ b/paddle/pten/kernels/cuda/linalg.cu @@ -16,6 +16,7 @@ #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/functions/eigen/dot.h" +#include "paddle/pten/kernels/functions/math/matmul_func.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/complex.h" @@ -30,10 +31,32 @@ void Dot(const CUDAContext& dev_ctx, eigen::Dot<CUDAContext, T>(dev_ctx, x, y, out); } +template <typename T> +void Matmul(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) { + PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but received dims size is 0. ")); + PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but received dims size is 0. ")); + math::MatMulFunction<CUDAContext, T>( + dev_ctx, x, y, out, transpose_x, transpose_y); +} + } // namespace pten PT_REGISTER_MODULE(LinalgCUDA); +using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex<float>; using complex128 = ::paddle::platform::complex<double>; @@ -47,3 +70,13 @@ PT_REGISTER_KERNEL("dot", int64_t, complex64, complex128) {} + +PT_REGISTER_KERNEL("matmul_v2", + CUDA, + ANY, + pten::Matmul, + float, + double, + float16, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cuda/linalg.h b/paddle/pten/kernels/cuda/linalg.h index ad38f71ec080a8..a6489efa72eee2 100644 --- a/paddle/pten/kernels/cuda/linalg.h +++ b/paddle/pten/kernels/cuda/linalg.h @@ -32,6 +32,14 @@ void Dot(const CUDAContext& dev_ctx, const DenseTensor& y, DenseTensor* out); +template <typename T> +void Matmul(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out); + } // namespace pten #endif diff --git a/paddle/pten/kernels/cuda/utils.h b/paddle/pten/kernels/cuda/utils.h index a8a6838f4602a6..0d79f04f2ee5ed 100644 --- a/paddle/pten/kernels/cuda/utils.h +++ b/paddle/pten/kernels/cuda/utils.h @@ -14,6 +14,9 @@ limitations under the License.
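With the CPU and CUDA kernels registered above, the new high-level API can be exercised end to end. Below is a minimal sketch condensed from the CPU unit test added later in this patch; the DefaultAllocator type name and the angle-bracketed template arguments are assumptions filled in here, since the diff text elides them:

#include <memory>

#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/hapi/include/linalg.h"
#include "paddle/pten/hapi/lib/utils/allocator.h"

int main() {
  namespace exp = paddle::experimental;
  // Two 3x3 float32 CPU tensors, filled with 1.0 and 2.0 respectively.
  const auto alloc =
      std::make_shared<exp::DefaultAllocator>(paddle::platform::CPUPlace());
  auto dense_x = std::make_shared<pten::DenseTensor>(
      alloc,
      pten::DenseTensorMeta(pten::DataType::FLOAT32,
                            paddle::framework::make_ddim({3, 3}),
                            pten::DataLayout::NCHW));
  auto dense_y = std::make_shared<pten::DenseTensor>(
      alloc,
      pten::DenseTensorMeta(pten::DataType::FLOAT32,
                            paddle::framework::make_ddim({3, 3}),
                            pten::DataLayout::NCHW));
  float* x_data = dense_x->mutable_data<float>();
  float* y_data = dense_y->mutable_data<float>();
  for (int i = 0; i < 9; ++i) {
    x_data[i] = 1.0f;
    y_data[i] = 2.0f;
  }
  exp::Tensor x(dense_x);
  exp::Tensor y(dense_y);
  // Dispatches to the "matmul_v2" kernel selected by backend/layout/dtype;
  // every element of the 3x3 result should be 6.0.
  auto out = exp::matmul(x, y, /*transpose_x=*/false, /*transpose_y=*/false);
  return out.numel() == 9 ? 0 : 1;
}

For batched operands the same call broadcasts the leading dimensions: per MatmulInferShape above, x of shape [2, 1, 3, 4] and y of shape [1, 5, 4, 6] yield an output of shape [2, 5, 3, 6].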
*/ #pragma once +// CUDA and HIP use same api +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -26,3 +29,5 @@ using CUDAContext = paddle::platform::CUDADeviceContext; void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, DenseTensor* dst); } // namespace pten + +#endif diff --git a/paddle/pten/kernels/functions/math/matmul_func.h b/paddle/pten/kernels/functions/math/matmul_func.h new file mode 100644 index 00000000000000..b5ddd26a95576f --- /dev/null +++ b/paddle/pten/kernels/functions/math/matmul_func.h @@ -0,0 +1,491 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" + +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/functions/eigen/common.h" + +namespace pten { +namespace math { + +static void GetBroadcastFromDims(const int x_ndim, + const std::int64_t* x_dims, + const int y_ndim, + const std::int64_t* y_dims, + std::int64_t* x_bd_dims, + std::int64_t* y_bd_dims, + std::int64_t* out_bd_dims) { + const int ndim = (std::max)(x_ndim, y_ndim); + std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1); + std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1); + std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim); + std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim); + + for (int i = 0; i < ndim; ++i) { + PADDLE_ENFORCE_EQ( + x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1, + true, + paddle::platform::errors::InvalidArgument( + "Input(X) and Input(Y) has error dim." 
+ "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," + "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," + "But received X_broadcast's shape[%s] = [%s]" + "received Y_broadcast's shape[%s] = [%s]", + i, + i, + i, + i, + i, + x_bd_dims[i], + i, + y_bd_dims[i])); + if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { + out_bd_dims[i] = 0; + } else { + out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); + } + } +} + +static int64_t GetIndexMessage(const int n, + const int64_t* dims, + const int64_t* index) { + int64_t sum = 0; + for (int i = 0; i < n; ++i) { + if (dims[i] > 1) { + sum = sum * dims[i] + index[i]; + } + } + return sum; +} + +static void IndexIncreaseFromDims(const int ndim, + const int64_t* dims, + int64_t* index) { + for (int i = ndim - 1; i >= 0; --i) { + ++index[i]; + if (index[i] >= dims[i]) { + index[i] -= dims[i]; + } else { + break; + } + } +} + +template +void MatMulFunction(const DeviceContext& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + const std::vector& x_dims, + const std::vector& y_dims, + DenseTensor* Out, + bool trans_x, + bool trans_y, + bool flag = false) { + const int x_ndim = x_dims.size(); + const int y_ndim = y_dims.size(); + + // Get data ptr + const T* x_data = X.data(); + const T* y_data = Y.data(); + + if (x_ndim == 1 && y_ndim == 1) { + PADDLE_ENFORCE_EQ( + X.numel(), + Y.numel(), + paddle::platform::errors::InvalidArgument( + "X's numbers must be equal to Y's numbers," + "when X/Y's dims =1. But received X has [%d] elements," + "received Y has [%d] elements", + X.numel(), + Y.numel())); + VLOG(3) << "MatMul's case 1"; + Out->Resize({1}); + Out->mutable_data(); + auto out_eigen = EigenScalar::From(*Out); + auto x_eigen = EigenVector::Flatten(X); + auto y_eigen = EigenVector::Flatten(Y); + + auto& dev = *dev_ctx.eigen_device(); + if (flag) { + out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; + } else { + out_eigen.device(dev) = (x_eigen * y_eigen).sum(); + } + return; + } + + auto blas = paddle::operators::math::GetBlas(dev_ctx); + + if (x_ndim == 1) { + const int N = X.numel(); + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], + N, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + N, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], + N, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." 
+ "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + N, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + std::vector out_dims(y_ndim - 1); + if (trans_y) { + std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); + } else { + std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); + out_dims.back() = y_dims.back(); + } + Out->Resize(paddle::framework::make_ddim(out_dims)); + Out->mutable_data(); + if (trans_y) { + const int M = Y.numel() / N; + VLOG(3) << "MatMul's case 2"; + blas.GEMV(false, + M, + N, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + const int M = y_dims[y_ndim - 1]; + const int batch_size = Y.numel() / (M * N); + if (batch_size == 1) { + VLOG(3) << "MatMul's case 3"; + blas.GEMV(true, + N, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 4"; + blas.BatchedGEMM(CblasTrans, + CblasNoTrans, + M, + 1, + N, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data(), + batch_size, + M * N, + 0); + } + } + return; + } + + if (y_ndim == 1) { + const int N = Y.numel(); + if (trans_x) { + PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], + N, + paddle::platform::errors::InvalidArgument( + "Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 2, + N, + x_ndim - 2, + x_dims[x_ndim - 2])); + } else { + PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], + N, + paddle::platform::errors::InvalidArgument( + "Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 1, + N, + x_ndim - 1, + x_dims[x_ndim - 1])); + } + std::vector out_dims(x_ndim - 1); + if (trans_x) { + std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); + out_dims.back() = x_dims.back(); + } else { + std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); + } + Out->Resize(paddle::framework::make_ddim(out_dims)); + Out->mutable_data(); + + if (trans_x) { + const int M = x_dims[x_ndim - 1]; + const int batch_size = X.numel() / (M * N); + if (batch_size == 1) { + VLOG(3) << "MatMul's case 5"; + blas.GEMV(true, + N, + M, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 6"; + blas.BatchedGEMM(CblasTrans, + CblasNoTrans, + M, + 1, + N, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + batch_size, + M * N, + 0); + } + } else { + const int M = X.numel() / N; + VLOG(3) << "MatMul's case 7"; + blas.GEMV(false, + M, + N, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } + return; + } + + const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; + const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], + K, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + K, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], + K, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + K, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + const int N = trans_y ? 
y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; + const int ndim = (std::max)(x_ndim, y_ndim); + std::vector x_broadcast_dims(ndim); + std::vector y_broadcast_dims(ndim); + std::vector out_broadcast_dims(ndim); + + GetBroadcastFromDims(x_ndim - 2, + x_dims.data(), + y_ndim - 2, + y_dims.data(), + x_broadcast_dims.data(), + y_broadcast_dims.data(), + out_broadcast_dims.data()); + out_broadcast_dims[ndim - 2] = M; + out_broadcast_dims[ndim - 1] = N; + + Out->Resize(paddle::framework::make_ddim(out_broadcast_dims)); + Out->mutable_data(); + + const int batch_dim = ndim - 2; + // broadcast message + const bool is_broadcast_dims = + !std::equal(x_broadcast_dims.cbegin(), + x_broadcast_dims.cbegin() + batch_dim, + y_broadcast_dims.cbegin()); + + const std::int64_t x_batch_size = + std::accumulate(x_broadcast_dims.cbegin(), + x_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + const std::int64_t y_batch_size = + std::accumulate(y_broadcast_dims.cbegin(), + y_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + const std::int64_t out_batch_size = + std::accumulate(out_broadcast_dims.cbegin(), + out_broadcast_dims.cbegin() + batch_dim, + 1LL, + std::multiplies()); + if (out_batch_size == 0) return; + if (x_batch_size == 1 && y_batch_size == 1) { + VLOG(3) << "MatMul's case 8"; + blas.GEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else if (x_batch_size == 1) { + if (M == 1 && trans_y) { + VLOG(3) << "MatMul's case 9"; + blas.GEMV(false, + y_batch_size * N, + K, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 10"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + out_batch_size, + 0, + K * N); + } + } else if (y_batch_size == 1) { + if (!trans_x) { + VLOG(3) << "MatMul's case 11"; + blas.GEMM(CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + x_batch_size * M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 12"; + blas.BatchedGEMM(CblasTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + out_batch_size, + M * K, + 0); + } + } else if (!is_broadcast_dims) { + VLOG(3) << "MatMul's case 13"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? 
CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + out_batch_size, + M * K, + K * N); + } else { + // in the case, can't use stridedgemm + std::vector x_ptr(out_batch_size); + std::vector y_ptr(out_batch_size); + std::vector out_ptr(out_batch_size); + std::vector index(batch_dim, 0); + for (std::int64_t i = 0; i < out_batch_size; ++i) { + // using the index to get offset + const std::int64_t x_index = + GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data()); + const std::int64_t y_index = + GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data()); + + x_ptr[i] = x_data + x_index * M * K; + y_ptr[i] = y_data + y_index * K * N; + out_ptr[i] = Out->mutable_data() + i * M * N; + IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data()); + } + VLOG(3) << "MatMul's case 14"; + blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans, + trans_y ? CblasTrans : CblasNoTrans, + M, + N, + K, + static_cast(1), + x_ptr.data(), + y_ptr.data(), + static_cast(flag), + out_ptr.data(), + out_batch_size); + } +} + +template +void MatMulFunction(const DeviceContext& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + DenseTensor* Out, + bool trans_x, + bool trans_y, + bool flag = false) { + const std::vector x_dims = vectorize(X.dims()); + const std::vector y_dims = vectorize(Y.dims()); + MatMulFunction( + dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag); +} + +} // namespace math +} // namespace pten diff --git a/paddle/pten/tests/CMakeLists.txt b/paddle/pten/tests/CMakeLists.txt index 27e76c87c6c0b7..3dc779380527f1 100644 --- a/paddle/pten/tests/CMakeLists.txt +++ b/paddle/pten/tests/CMakeLists.txt @@ -8,6 +8,7 @@ cc_test(dense_tensor_test SRCS dense_tensor_test.cc DEPS dense_tensor) cc_test(kernel_factory_test SRCS kernel_factory_test.cc DEPS kernel_factory) cc_test(test_mean_api SRCS test_mean_api.cc DEPS math_api pten_hapi_utils) cc_test(test_dot_api SRCS test_dot_api.cc DEPS linalg_api pten_hapi_utils) +cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS linalg_api pten_hapi_utils) cc_test(test_fill_api SRCS test_fill_api.cc DEPS creation_api pten_hapi_utils) cc_test(test_copy_api SRCS test_copy_api.cc DEPS utils_cpu pten_hapi_utils) cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS utils_cpu manipulation_api pten_hapi_utils) diff --git a/paddle/pten/tests/test_fill_api.cc b/paddle/pten/tests/test_fill_api.cc index c19d14efaa976b..4f93e03aca2f3d 100644 --- a/paddle/pten/tests/test_fill_api.cc +++ b/paddle/pten/tests/test_fill_api.cc @@ -32,7 +32,6 @@ using DDim = paddle::framework::DDim; // TODO(chenweihang): Remove this test after the API is used in the dygraph TEST(API, full_like) { - // 1. create tensor // 1. create tensor const auto alloc = std::make_shared( paddle::platform::CPUPlace()); diff --git a/paddle/pten/tests/test_matmul_api.cc b/paddle/pten/tests/test_matmul_api.cc new file mode 100644 index 00000000000000..b0579834519aa3 --- /dev/null +++ b/paddle/pten/tests/test_matmul_api.cc @@ -0,0 +1,160 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/hapi/include/linalg.h" + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/hapi/lib/utils/allocator.h" +#include "paddle/pten/kernels/cuda/utils.h" + +PT_DECLARE_MODULE(LinalgCPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_DECLARE_MODULE(LinalgCUDA); +#endif + +namespace framework = paddle::framework; +using DDim = paddle::framework::DDim; + +TEST(API, matmul_cpu) { + // 1. create tensor + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + + auto* dense_x_data = dense_x->mutable_data(); + + auto dense_y = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + auto* dense_y_data = dense_y->mutable_data(); + + for (size_t i = 0; i < 9; ++i) { + dense_x_data[i] = 1.0; + dense_y_data[i] = 2.0; + } + std::vector sum(9, 6.0); + + paddle::experimental::Tensor x(dense_x); + paddle::experimental::Tensor y(dense_y); + + // 2. test API + auto out = paddle::experimental::matmul(x, y, false, false); + + // 3. check result + ASSERT_EQ(out.shape().size(), 2); + ASSERT_EQ(out.shape()[0], 3); + ASSERT_EQ(out.shape()[1], 3); + ASSERT_EQ(out.numel(), 9); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto dense_out = std::dynamic_pointer_cast(out.impl()); + + for (size_t i = 0; i < 9; i++) { + ASSERT_NEAR(sum[i], dense_out->data()[i], 1e-6f); + } +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +TEST(API, matmul_cuda) { + // Prepare CPU Dense Tensor + const auto alloc_cpu = + std::make_shared( + paddle::platform::CPUPlace()); + auto ref_x = std::make_shared( + alloc_cpu, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + + auto* ref_x_data = ref_x->mutable_data(); + + auto ref_y = std::make_shared( + alloc_cpu, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + auto* ref_y_data = ref_y->mutable_data(); + + for (size_t i = 0; i < 9; ++i) { + ref_x_data[i] = 1.0; + ref_y_data[i] = 2.0; + } + std::vector sum(9, 6.0); + + // 1. 
create tensor + const auto alloc_cuda = + std::make_shared( + paddle::platform::CUDAPlace()); + auto dense_x = std::make_shared( + alloc_cuda, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + + auto dense_y = std::make_shared( + alloc_cuda, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({3, 3}), + pten::DataLayout::NCHW)); + + auto& pool = paddle::platform::DeviceContextPool::Instance(); + auto place = paddle::platform::CUDAPlace(); + auto* dev_ctx = pool.GetByPlace(place); + + pten::Copy(*dev_ctx, *ref_x.get(), dense_x.get()); + pten::Copy(*dev_ctx, *ref_y.get(), dense_y.get()); + + paddle::experimental::Tensor x(dense_x); + paddle::experimental::Tensor y(dense_y); + + // 2. test API + auto out = paddle::experimental::matmul(x, y, false, false); + + // 3. check result + ASSERT_EQ(out.shape().size(), 2); + ASSERT_EQ(out.shape()[0], 3); + ASSERT_EQ(out.shape()[1], 3); + ASSERT_EQ(out.numel(), 9); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto dense_out = std::dynamic_pointer_cast(out.impl()); + + auto ref_out = std::make_shared( + alloc_cpu, + pten::DenseTensorMeta( + pten::DataType::FLOAT32, out.shape(), pten::DataLayout::NCHW)); + + pten::Copy(*dev_ctx, *dense_out.get(), ref_out.get()); + + for (size_t i = 0; i < 9; i++) { + ASSERT_NEAR(sum[i], ref_out->data()[i], 1e-6f); + } +} + +#endif From a4c3e038a4f65160f98f22be5f8824c85ff94751 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Tue, 2 Nov 2021 05:09:09 +0100 Subject: [PATCH 2/6] Correct conv2d int8 mkldnn UT (#36711) * Refactor conv2d int8 unit test * Correct according to review and add int8 check --- paddle/fluid/pybind/pybind.cc | 19 +++ .../mkldnn/test_conv2d_int8_mkldnn_op.py | 147 ++++++++---------- 2 files changed, 86 insertions(+), 80 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d79bba7fd2f81e..fdff8310e710ba 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -226,6 +226,23 @@ bool SupportsBfloat16FastPerformance() { #endif } +bool SupportsInt8() { +#ifndef PADDLE_WITH_MKLDNN + return false; +#else + return (platform::MayIUse(platform::cpu_isa_t::avx2) || + platform::MayIUse(platform::cpu_isa_t::avx512f)); +#endif +} + +bool SupportsVNNI() { +#ifndef PADDLE_WITH_MKLDNN + return false; +#else + return platform::MayIUse(platform::cpu_isa_t::avx512_core_vnni); +#endif +} + // According to the input `place` and `dtype`, this function returns a tuple // consists of three sets: // 1) All operators registered in the Paddle framework. @@ -2121,6 +2138,8 @@ All parameter, weight, gradient are variables in Paddle. 
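+  // The capability probes below let Python-side tests (such as the conv2d
+  // int8 unit test later in this series) skip CPUs that lack AVX2/AVX512
+  // int8 support, and detect VNNI, on which OneDNN does not pre-scale s8
+  // weights by 0.5.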
m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS); m.def("supports_bfloat16", SupportsBfloat16); m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance); + m.def("supports_int8", SupportsInt8); + m.def("supports_vnni", SupportsVNNI); m.def("op_supported_infos", OpSupportedInfos); m.def("is_compiled_with_brpc", IsCompiledWithBrpc); m.def("is_compiled_with_dist", IsCompiledWithDIST); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index 2cfb6146f3f55d..7508ecbb2946d2 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -23,13 +23,12 @@ def conv2d_forward_refer(input, filter, group, conv_param): - out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group, - conv_param) + out, _, _, _, _ = conv2d_forward_naive(input, filter, group, conv_param) return out -@unittest.skipIf(not core.supports_bfloat16(), - "place does not support BF16 evaluation") +@unittest.skipIf(not core.supports_int8(), + "place does not support int8 computation") class TestConv2DInt8Op(TestConv2DOp): def setUp(self): self.op_type = "conv2d" @@ -53,73 +52,61 @@ def setUp(self): 'pad': self.pad, 'dilation': self.dilations } - + # This implementation of convolution quantization is based on OneDNN documentation + # https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html#doxid-dev-guide-int8-computations-1dg-i8-comp-s11 + scale_output_shift = (self.scale_out / + (self.scale_in * self.scale_weights[0])) filter = np.random.random(self.filter_size).astype(self.weighttype) - if self.srctype == np.uint8: - input = np.random.randint(0, 10, + + # When the Intel AVX2 or Intel AVX512 Instruction Set is used + # the reorder additionally scales the weights by 0.5 + # to overcome the potential overflow issue. If the processor supports VNNI instructions, + # modification of the weights is not necessary. + avx_scale = 0.5 if not core.supports_vnni( + ) and self.srctype == np.int8 else 1. 
+ filter_int = np.round(filter * self.scale_weights[0] * + avx_scale).astype(np.int32) + scale_output_shift = scale_output_shift / avx_scale + + def conv2d_forward_refer_helper(input_): + return conv2d_forward_refer( + input_.astype(np.int32), filter_int, self.groups, + conv2d_param).astype(np.float32) * scale_output_shift + + def residual_helper(init_low, init_high, output_): + input_residual_ = np.random.randint( + init_low, init_high, + self.input_residual_size).astype(self.srctype) + return (output_ + input_residual_ * + (self.scale_out / self.scale_in_eltwise)), input_residual_ + + if self.srctype == np.int8: + init_low, init_high = (-5, 5) + input = np.random.randint(init_low, init_high, self.input_size).astype(self.srctype) + input_shift = (np.ones(self.input_size) * 128).astype(np.uint8) + + output1 = conv2d_forward_refer_helper( + np.round(input + input_shift).astype(np.int32)) + output2 = conv2d_forward_refer_helper( + np.round(input_shift).astype(np.int32)) + output = output1 - output2 else: - input = np.random.randint(-5, 5, + init_low, init_high = (0, 10) + input = np.random.randint(init_low, init_high, self.input_size).astype(self.srctype) - input_shift = (np.ones(self.input_size) * 128).astype(np.uint8) + output = conv2d_forward_refer_helper(input) - if self.srctype == np.int8: - filter_int = np.round(filter * self.scale_weights[0] * - 0.5).astype(np.int32) - scale_output_shift = self.scale_out / (self.scale_in * - self.scale_weights[0] * 0.5) - output1 = conv2d_forward_refer( - np.round((input.astype(np.int32) + input_shift) * - self.scale_in).astype(np.int32), filter_int, - self.groups, - conv2d_param).astype(np.float32) * scale_output_shift - output2 = conv2d_forward_refer( - np.round((input_shift) * self.scale_in).astype(np.int32), - filter_int, self.groups, - conv2d_param).astype(np.float32) * scale_output_shift - if self.fuse_residual: - input_residual = np.random.randint( - -5, 5, self.input_residual_size).astype(self.srctype) - output_tmp = np.round(output1 - output2 + input_residual.astype( - self.srctype) * (self.scale_out / self.scale_in_eltwise)) - if self.fuse_activation == "relu": - output = np.maximum(output_tmp, 0).astype(self.dsttype) - else: - output = output_tmp.astype(self.dsttype) - else: - if self.fuse_activation == "relu": - output = np.maximum(np.round(output1 - output2), - 0).astype(self.dsttype) - else: - output = np.round(output1 - output2).astype(self.dsttype) + if self.fuse_residual: + output, input_residual = residual_helper(init_low, init_high, + output) - else: - filter_int = np.round(filter * - self.scale_weights[0]).astype(np.int32) - scale_output_shift = self.scale_out / (self.scale_in * - self.scale_weights[0]) - output1 = conv2d_forward_refer( - input.astype(np.int32), filter_int, self.groups, - conv2d_param).astype(np.float32) - output1_tmp = np.round(output1 * ( - self.scale_out / (self.scale_in * self.scale_weights[0]))) - - if self.fuse_residual: - input_residual = np.random.randint( - 0, 10, self.input_residual_size).astype(self.srctype) - output_tmp_res = np.round(output1 * (self.scale_out / ( - self.scale_in * self.scale_weights[ - 0])) + input_residual.astype(np.int32) * ( - self.scale_out / self.scale_in_eltwise)) - if self.fuse_activation == "relu": - output = np.maximum(output_tmp_res, 0).astype(self.dsttype) - else: - output = output_tmp_res.astype(self.dsttype) - else: - if self.fuse_activation == "relu": - output = np.maximum(output1_tmp, 0).astype(self.dsttype) - else: - output = output1_tmp.astype(self.dsttype) + output = 
np.round(output) + + if self.fuse_activation == "relu": + output = np.maximum(output, 0) + + output = output.astype(self.dsttype) self.inputs = { 'Input': @@ -169,7 +156,7 @@ def init_test_case(self): f_c = self.input_size[1] // self.groups self.input_residual_size = [1, 2, 3, 3] self.filter_size = [2, f_c, 3, 3] - self.scale_in = 1.0 + self.scale_in = 0.95 self.scale_out = 0.5 self.scale_weights = [10.0] self.scale_in_eltwise = 0.6 @@ -185,7 +172,7 @@ def init_fuse_residual(self): self.fuse_residual = True -#--------------------test conv2d u8 in and u8 out with residual fuse-------------------- +# --------------------test conv2d u8 in and u8 out with residual fuse-------------------- class TestConv2D(TestConv2DInt8Op): @@ -197,7 +184,7 @@ def init_test_case(self): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] - self.scale_in = 1.0 + self.scale_in = 0.95 self.scale_out = 0.5 self.scale_weights = [10.0] self.scale_in_eltwise = 0.6 @@ -224,7 +211,7 @@ def init_test_case(self): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] - self.scale_in = 1.0 + self.scale_in = 0.95 self.scale_out = 0.8 self.scale_weights = [10.0] self.scale_in_eltwise = 0.5 @@ -240,7 +227,7 @@ def init_test_case(self): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 3, 3] - self.scale_in = 1.0 + self.scale_in = 0.95 self.scale_out = 0.8 self.scale_weights = [10.0] self.scale_in_eltwise = 0.5 @@ -255,7 +242,7 @@ def init_test_case(self): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1] - self.scale_in = 1.0 + self.scale_in = 0.95 self.scale_out = 0.5 self.scale_weights = [12.0] self.scale_in_eltwise = 0.5 @@ -270,7 +257,7 @@ def init_test_case(self): assert np.mod(self.input_size[1], self.groups) == 0 f_c = self.input_size[1] // self.groups self.filter_size = [6, f_c, 1, 1] - self.scale_in = 1.0 + self.scale_in = 0.95 self.scale_out = 0.5 self.scale_weights = [10.0] self.scale_in_eltwise = 0.8 @@ -290,32 +277,32 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual): def create_test_int8_class(parent): - #--------------------test conv2d s8 in and u8 out-------------------- + # --------------------test conv2d s8 in and u8 out-------------------- class TestS8U8Case(parent): def init_data_type(self): init_data_type_with_fusion(self, np.int8, "relu", False) - #--------------------test conv2d s8 in and s8 out-------------------- + # --------------------test conv2d s8 in and s8 out-------------------- class TestS8S8Case(parent): def init_data_type(self): init_data_type_with_fusion(self, np.int8, "", False) - #--------------------test conv2d u8 in and s8 out-------------------- + # --------------------test conv2d u8 in and s8 out-------------------- class TestU8S8Case(parent): def init_data_type(self): init_data_type_with_fusion(self, np.uint8, "", False) - #--------------------test conv2d u8 in and u8 out without residual fuse-------------------- + # --------------------test conv2d u8 in and u8 out without residual fuse-------------------- class TestU8U8Case(parent): def init_data_type(self): init_data_type_with_fusion(self, np.uint8, "relu", False) - #--------------------test conv2d s8 in and s8 out with residual fuse-------------------- + # --------------------test conv2d s8 in and s8 out with 
residual fuse-------------------- class TestS8S8ResCase(parent): def init_data_type(self): init_data_type_with_fusion(self, np.int8, "", True) - #--------------------test conv2d u8 in and s8 out with residual fuse-------------------- + # --------------------test conv2d u8 in and s8 out with residual fuse-------------------- class TestU8S8ResCase(parent): def init_data_type(self): init_data_type_with_fusion(self, np.uint8, "", True) @@ -333,9 +320,9 @@ def init_data_type(self): TestS8S8Case.__name__ = cls_name_s8s8 TestU8S8Case.__name__ = cls_name_u8s8 TestU8U8Case.__name__ = cls_name_u8u8 - TestS8S8ResCase.__name__ = cls_name_s8s8_re_1 TestU8S8ResCase.__name__ = cls_name_u8s8_re_1 + globals()[cls_name_s8u8] = TestS8U8Case globals()[cls_name_s8s8] = TestS8S8Case globals()[cls_name_u8s8] = TestU8S8Case @@ -344,7 +331,7 @@ def init_data_type(self): globals()[cls_name_u8s8_re_1] = TestU8S8ResCase if os.name != 'nt': - #--------------------test conv2d s8 in and u8 out with residual fuse-------------------- + # --------------------test conv2d s8 in and u8 out with residual fuse-------------------- class TestS8U8ResCase(parent): def init_data_type(self): init_data_type_with_fusion(self, np.int8, "relu", True) From dc08c18757620c5cd1c1edb849a6c99df0881ba9 Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Tue, 2 Nov 2021 12:59:08 +0800 Subject: [PATCH 3/6] fix cusparse compile bug in CUDA11.2, test=develop (#36911) --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/platform/dynload/cusparse.h | 4 ++-- .../fluid/tests/unittests/test_sparse_attention_op.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a9e15b5d405f2a..958a3e9fdde030 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -99,7 +99,7 @@ if (WITH_GPU OR WITH_ROCM) endif() op_library(sync_batch_norm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") - if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) ) + if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.3) ) op_library(sparse_attention_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n") endif() diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h index e5be003fadf066..e44e8ed08560f7 100644 --- a/paddle/fluid/platform/dynload/cusparse.h +++ b/paddle/fluid/platform/dynload/cusparse.h @@ -56,8 +56,8 @@ extern void *cusparse_dso_handle; CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); -// APIs available after CUDA 11.2 -#if CUDA_VERSION >= 11020 +// APIs available after CUDA 11.3 +#if CUDA_VERSION >= 11030 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \ __macro(cusparseSDDMM_bufferSize); \ __macro(cusparseSDDMM_preprocess); \ diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py index 5134b885f33072..cce4742f164557 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py @@ -128,8 +128,8 @@ def init_csr_format(batch_size, num_heads, rows, blocksize): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_cuda_version() < 11020, - "core is not compiled with CUDA and cuda 
version need larger than or equal to 11.2" + not core.is_compiled_with_cuda() or get_cuda_version() < 11030, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" ) class TestSparseAttentionOp(OpTest): def config(self): @@ -190,8 +190,8 @@ def config(self): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_cuda_version() < 11020, - "core is not compiled with CUDA and cuda version need larger than or equal to 11.2" + not core.is_compiled_with_cuda() or get_cuda_version() < 11030, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.3" ) class TestSparseAttentionAPI(unittest.TestCase): def setUp(self): From 093c4ec5d4ec53df214f619d59ed30e27664208d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 2 Nov 2021 14:05:31 +0800 Subject: [PATCH 4/6] append test dir into gitignore (#36926) --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 749832c3930cf5..c246a56cf15a4e 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,6 @@ build_* cmake-build-* paddle/fluid/operators/distributed/send_recv.proto model_test + +Testing +tools/__pycache__ From b094110256c31da3ae002ceaefee5e367b9fcaec Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 2 Nov 2021 14:14:12 +0800 Subject: [PATCH 5/6] fix some bug, test=develop (#36888) --- .../framework/new_executor/interpretercore.cc | 7 +++-- .../new_executor/interpretercore_util.cc | 23 +++++++++++---- .../new_executor/new_executor_defs.h | 9 ++++++ .../operators/controlflow/fetch_v2_op.cc | 28 +++++++++++++++++-- 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 3ea8b8d309d45c..a8007c2f26a0e8 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -241,13 +241,14 @@ void InterpreterCore::BuildInplace() { auto& outputs = instr.Outputs(); for (auto& pair : in_to_outs) { auto iter = inputs.find(pair.first); - if (iter != inputs.end()) { + if (iter != inputs.end() && !iter->second.empty()) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { auto iterout = outputs.find(pair.second); - if (iterout != outputs.end()) { + if (iterout != outputs.end() && !iterout->second.empty()) { auto invar = global_scope_->Var(iter->second[0]); auto outvar = global_scope_->Var(iterout->second[0]); - if (invar && outvar) { + if (invar && outvar && invar->IsType() && + outvar->IsType()) { instr.AddInplace(invar, outvar); VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type() << " " << global_scope_->GetNameById(iter->second[0]) diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index a4443b08847269..9de03a435ab586 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -142,8 +142,8 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, if (nullptr == var_scope->FindVar(var_name)) { var_scope->AddVar(var_desc->Name(), var_desc); } else { - auto* var_desc = var_scope->VarDesc(var_name); - if (nullptr == var_desc) { + auto* var_desc_tmp = var_scope->VarDesc(var_name); + if (nullptr == var_desc_tmp) { VLOG(3) << "update var:" << var_name << " desc from nullptr into " << var_desc; var_scope->VarMetaInfo(var_name).vardesc_ = var_desc; @@ -206,9 +206,22 @@ void 
apply_device_guard(const OperatorBase* op_base, VLOG(3) << "Switch into CPUPlace by device_guard."; expected_kernel_key->place_ = platform::CPUPlace(); } else if (op_device.find("gpu") != std::string::npos && - platform::is_gpu_place(place)) { - VLOG(3) << "Switch into " << place << " by device_guard."; - expected_kernel_key->place_ = place; + (platform::is_gpu_place(place) || + platform::is_npu_place(place))) { + // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel + // will be executed and a warning will be given at the same time. + if (op_base->SupportGPU()) { + expected_kernel_key->place_ = place; + } else if (op_base->SupportNPU()) { + expected_kernel_key->place_ = place; + } else { + expected_kernel_key->place_ = platform::CPUPlace(); + LOG_FIRST_N(WARNING, 1) + << "Op(" << op_base->Type() + << ") has no CUDA implementation. It will be assigned to CPUPlace."; + } + VLOG(3) << "Switch into " << expected_kernel_key->place_ + << " by device_guard."; } else { PADDLE_THROW( platform::errors::Fatal("Unsupported current place %s", op_device)); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 58b6c924e23aab..d70243b93fed3f 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -474,6 +474,15 @@ struct VariableMetaInfo { // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope? class VariableScope : public ScopeBase { public: + VariableScope() { + // for @EMPTY@ variable + var_list_.push_back(nullptr); + name2id_[kEmptyVarName] = 0; + VariableMetaInfo info; + info.var_ref_count_ = 0; + info.vardesc_ = nullptr; + vec_meta_info_.push_back(info); + } Variable* FindVar(const std::string& name) const { auto it = name2id_.find(name); if (it != name2id_.end()) { diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc index bf9874c02f6203..0837caf9353a3a 100644 --- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc +++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc @@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel { framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { + if (!tensor.IsInitialized()) { + return expected_kernel_type; + } return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { + auto *fetch_var = ctx.InputVar("X"); + if (fetch_var == nullptr) { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CPUPlace()); + } + + if (fetch_var->IsType()) { + auto &src_item = fetch_var->Get(); + if (!src_item.IsInitialized()) { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CPUPlace()); + } + } else { + auto &src_item = fetch_var->Get(); + if (src_item.empty() || !src_item[0].IsInitialized()) { + return framework::OpKernelType(framework::proto::VarType::FP32, + platform::CPUPlace()); + } + } + return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), platform::CPUPlace()); @@ -127,6 +150,9 @@ class FetchV2Kernel { if (fetch_var->IsType()) { auto &src_item = fetch_var->Get(); + if (!src_item.IsInitialized()) { + return; + } auto *dst_item = &(BOOST_GET(framework::LoDTensor, 
diff --git a/paddle/fluid/operators/controlflow/fetch_v2_op.cc b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
index bf9874c02f6203..0837caf9353a3a 100644
--- a/paddle/fluid/operators/controlflow/fetch_v2_op.cc
+++ b/paddle/fluid/operators/controlflow/fetch_v2_op.cc
@@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
+    if (!tensor.IsInitialized()) {
+      return expected_kernel_type;
+    }
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }

   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
+    auto *fetch_var = ctx.InputVar("X");
+    if (fetch_var == nullptr) {
+      return framework::OpKernelType(framework::proto::VarType::FP32,
+                                     platform::CPUPlace());
+    }
+
+    if (fetch_var->IsType<framework::LoDTensor>()) {
+      auto &src_item = fetch_var->Get<framework::LoDTensor>();
+      if (!src_item.IsInitialized()) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       platform::CPUPlace());
+      }
+    } else {
+      auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
+      if (src_item.empty() || !src_item[0].IsInitialized()) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       platform::CPUPlace());
+      }
+    }
+
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
         platform::CPUPlace());
@@ -127,6 +150,9 @@ class FetchV2Kernel {

     if (fetch_var->IsType<framework::LoDTensor>()) {
       auto &src_item = fetch_var->Get<framework::LoDTensor>();
+      if (!src_item.IsInitialized()) {
+        return;
+      }
       auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
       bool check_place = platform::is_cpu_place(src_item.place()) ||
                          platform::is_cuda_pinned_place(src_item.place());
@@ -173,9 +199,7 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
     AddComment(R"DOC(
 FetchV2 Operator.
-
 It should not be configured by users directly.
-
 )DOC");
   }
 };
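The device_guard change in this patch boils down to a three-way fallback: honor the requested accelerator only if the op actually has a kernel for it, otherwise warn and pin the op to CPU. A self-contained sketch of that decision rule under simplified assumptions; Place, Op, and choose_place are illustrative stand-ins, not Paddle's OperatorBase/OpKernelType API:

#include <iostream>
#include <string>

enum class Place { CPU, GPU, NPU };

struct Op {
  std::string type;
  bool supports_gpu;
  bool supports_npu;
};

// Mirrors the patched apply_device_guard: keep the requested device when the
// op has a kernel for it; otherwise fall back to CPU with a warning.
Place choose_place(const Op& op, Place requested) {
  if (requested == Place::CPU) return requested;
  if (requested == Place::GPU && op.supports_gpu) return requested;
  if (requested == Place::NPU && op.supports_npu) return requested;
  std::cerr << "Op(" << op.type
            << ") has no kernel for the requested device; using CPU.\n";
  return Place::CPU;
}

int main() {
  Op cpu_only{"some_cpu_only_op", /*supports_gpu=*/false,
              /*supports_npu=*/false};
  Place p = choose_place(cpu_only, Place::GPU);
  std::cout << (p == Place::CPU ? "CPU" : "accelerator") << "\n";  // CPU
  return 0;
}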
From 4a7f1a0d840244f41c180a93101bae50bf487879 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Tue, 2 Nov 2021 14:37:17 +0800
Subject: [PATCH 6/6] Add Intermediate Kernel API for refactor Tensor Lib (#36914)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* keep the basic type for attr, instead of using scalar for all
* merge the code
* remove mkldnn tensor & polish details
* use flat_hash_map and small_vector in kernel factory
* Refactor flatten kernel (#12)
* refactor flatten kernel
* update infershape function
* fix compile bugs
* fix bugs when merge
* fix compiler bugs
* fix bugs when run test_flatten_api
* fix bugs when run test
* Revert "use flat_hash_map and small_vector in kernel factory"
  This reverts commit 23091495cfdd3df8cc1be592d30f09ea66a7c72b.
* Move cpu, cuda and other device code into kernels (#15)
* start refactor matmul
* move cpu, cuda and other device modules into kernels
* merge code
* polish code in operator.cc
* Perfect unittests (#16)
* perfect unittest
* update license
* replace with flat_hash_map, small_vector (#19)
* fix small_vector build error on windows platform
* replace with flat_hash_map, small_vector
* remove todo
* Perfect unittests (#20)
* fix bug when run tcmpt_utils_test
* refactor execution adapting impl
* fix insert conflict
* Fix CI bug of test_yolov3 (#21)
* add the tensor base class, test=develop (#17)
* update the tensor base class, test=develop
* remove two funcs, test=develop
* update the error msg, test=develop
  Co-authored-by: Chen Weihang
* [no-verify] commit backend and tensor signature changes
* Rename tcmpt to pten (#23)
* rename tcmpt to pten
* update omitted files for rename to pten
* update omitted file for rename to pten
* remove k of all enum var
* remove kernel_instantiate (#26)
* remove symbols and spatial_tensor
* change common to functions
* readd share tensor impl methods
* add a candidate dense tensor class, test=develop (#28)
* change all Pt to Pten
* resolve conflict with xiaowei
* Op2functor opt1 (#27)
* replace to small vector and change to const &
* add std::move
  Co-authored-by: Chen Weihang
* polish kernel factory and kernel registry
* fix operator test error msg mismatch
* remove tensor signature and backend set member
* move scalar and polish enforce
* revert dtype layout change to fix error
* fix enum operator override error
* Add Intermediate API layer
* add several base unittests
* add pten utils tests
* polish some details
* Dev/op2func refactor 3 (#30)
* remove TensorBase::backend(), test=develop
* remove some ops, test=develop
* cherry-pick the pr of tensor meta, test=develop
* moves the dense tensor and some ops, test=develop
* update the linalg operator, test=develop
* update other operators, test=develop
* fix errors, test=develop
* fix bugs, test=develop
* try to resolve the problem of windows ci, test=develop
* updates codes, test=develop
* fix the tensor_utils.cc, test=develop
* modify the dense tensor, test=develop
* fix the data type, test=develop
  Co-authored-by: shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
* intermediate api adapt to new dense tensor
* add some TODO and delete include header

Co-authored-by: Chen Weihang
Co-authored-by: chentianyu03
Co-authored-by: zyfncg <1370305206@qq.com>
Co-authored-by: 石晓伟 <39303645+Shixiaowei02@users.noreply.github.com>
---
 paddle/fluid/framework/operator.cc           |   1 -
 paddle/fluid/imperative/prepared_operator.cc |   1 -
 paddle/pten/api/include/creation.h           |  21 ++++
 paddle/pten/api/include/linalg.h             |  19 +++
 paddle/pten/api/include/manipulation.h       |  20 ++++
 paddle/pten/api/include/math.h               |  57 +++++++++
 paddle/pten/hapi/lib/utils/tensor_utils.cc   |   1 +
 paddle/pten/kernels/cpu/manipulation.cc      |   5 +-
 paddle/pten/kernels/cuda/manipulation.cu     |   5 +-
 paddle/pten/tests/CMakeLists.txt             |   1 +
 paddle/pten/tests/test_dot_api.cc            |  54 +++++++++
 paddle/pten/tests/test_fill_api.cc           |  37 ++++++
 paddle/pten/tests/test_flatten_api.cc        |  46 ++++++++
 paddle/pten/tests/test_mean_api.cc           |  35 ++++++
 paddle/pten/tests/test_scale_api.cc          | 118 +++++++++++++++++++
 15 files changed, 413 insertions(+), 8 deletions(-)
 create mode 100644 paddle/pten/tests/test_scale_api.cc

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 16e63e433e6403..d317aac8594b4c 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -23,7 +23,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/op_call_stack.h"
-#include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/framework/shape_inference.h"
 #include "paddle/fluid/framework/transfer_scope_cache.h"
 #include "paddle/fluid/framework/unused_var_check.h"

diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index db26c66958140b..b2d55babc7e1c1 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -16,7 +16,6 @@
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
-#include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/utils/small_vector.h"
diff --git a/paddle/pten/api/include/creation.h b/paddle/pten/api/include/creation.h
index d7311e6cd283b4..9795d88f81880b 100644
--- a/paddle/pten/api/include/creation.h
+++ b/paddle/pten/api/include/creation.h
@@ -14,5 +14,26 @@
 #pragma once

+#include "paddle/pten/api/include/infershape.h"
+#include "paddle/pten/hapi/lib/utils/allocator.h"
 #include "paddle/pten/kernels/cpu/creation.h"
 #include "paddle/pten/kernels/cuda/creation.h"
+
+namespace pten {
+
+// TODO(YuanRisheng) This function name should be the same as the user API name.
+// TODO(zyfncg) Automatic code generation
+template <typename T, typename ContextT>
+DenseTensor FillAnyLike(const ContextT& dev_ctx,
+                        const DenseTensor& x,
+                        const Scalar& val) {
+  auto out_meta = UnchangedInferShape(x.meta());
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  FillAnyLike<T>(dev_ctx, x, val, &dense_out);
+  return dense_out;
+}
+
+}  // namespace pten

diff --git a/paddle/pten/api/include/linalg.h b/paddle/pten/api/include/linalg.h
index d9798c3a2e0a81..0d4c7a60fbc145 100644
--- a/paddle/pten/api/include/linalg.h
+++ b/paddle/pten/api/include/linalg.h
@@ -15,5 +15,24 @@
 #pragma once

 // See Note: [ How do we organize the kernel directory ]
+#include "paddle/pten/api/include/infershape.h"
+#include "paddle/pten/hapi/lib/utils/allocator.h"
 #include "paddle/pten/kernels/cpu/linalg.h"
 #include "paddle/pten/kernels/cuda/linalg.h"
+
+namespace pten {
+
+template <typename T, typename ContextT>
+DenseTensor Dot(const ContextT& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y) {
+  auto out_meta = DotInferShape(x.meta(), y.meta());
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  Dot<T>(dev_ctx, x, y, &dense_out);
+  return dense_out;
+}
+
+}  // namespace pten

diff --git a/paddle/pten/api/include/manipulation.h b/paddle/pten/api/include/manipulation.h
index f2acad96499696..1f867686a6eb7e 100644
--- a/paddle/pten/api/include/manipulation.h
+++ b/paddle/pten/api/include/manipulation.h
@@ -15,5 +15,25 @@
 #pragma once

 // See Note: [ How do we organize the kernel directory ]
+#include "paddle/pten/api/include/infershape.h"
+#include "paddle/pten/hapi/lib/utils/allocator.h"
 #include "paddle/pten/kernels/cpu/manipulation.h"
 #include "paddle/pten/kernels/cuda/manipulation.h"
+
+namespace pten {
+
+template <typename T, typename ContextT>
+DenseTensor Flatten(const ContextT& dev_ctx,
+                    const DenseTensor& x,
+                    int start_axis,
+                    int stop_axis) {
+  auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis);
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  Flatten<T>(dev_ctx, x, start_axis, stop_axis, &dense_out);
+  return dense_out;
+}
+
+}  // namespace pten
diff --git a/paddle/pten/api/include/math.h b/paddle/pten/api/include/math.h
index 5145c823a5c6e0..fa512e8d6db0d7 100644
--- a/paddle/pten/api/include/math.h
+++ b/paddle/pten/api/include/math.h
@@ -15,5 +15,62 @@ limitations under the License. */
 #pragma once

 // See Note: [ How do we organize the kernel directory ]
+#include "paddle/pten/api/include/infershape.h"
+#include "paddle/pten/hapi/lib/utils/allocator.h"
 #include "paddle/pten/kernels/cpu/math.h"
 #include "paddle/pten/kernels/cuda/math.h"
+
+namespace pten {
+
+template <typename T, typename ContextT>
+DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) {
+  auto out_meta = UnchangedInferShape(x.meta());
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  Sign<T>(dev_ctx, x, &dense_out);
+  return dense_out;
+}
+
+template <typename T, typename ContextT>
+DenseTensor Mean(const ContextT& dev_ctx, const DenseTensor& x) {
+  auto out_meta = ReductionInferShape(x.meta());
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  Mean<T>(dev_ctx, x, &dense_out);
+  return dense_out;
+}
+
+template <typename T, typename ContextT>
+DenseTensor Scale(const ContextT& dev_ctx,
+                  const DenseTensor& x,
+                  float scale,
+                  float bias,
+                  bool bias_after_scale) {
+  auto out_meta = UnchangedInferShape(x.meta());
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  Scale<T>(dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
+  return dense_out;
+}
+
+template <typename T, typename ContextT>
+DenseTensor Scale(const ContextT& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& scale,
+                  float bias,
+                  bool bias_after_scale) {
+  auto out_meta = UnchangedInferShape(x.meta());
+  const auto allocator =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          dev_ctx.GetPlace());
+  pten::DenseTensor dense_out(allocator, out_meta);
+  ScaleHost<T>(dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
+  return dense_out;
+}
+
+}  // namespace pten
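All four headers above repeat one pattern: derive the output meta via an InferShape helper, allocate a DenseTensor with DefaultAllocator, run the in-place kernel, and return the tensor by value. As a hedged sketch of how a further op could plug into this intermediate layer, here is a hypothetical Relu wrapper; Relu and its in-place kernel overload are placeholders for the pattern, not functions added by this patch:

// Hypothetical example only: "Relu" is not part of this patch; the shape of
// the wrapper is what matters.
#include "paddle/pten/api/include/infershape.h"
#include "paddle/pten/hapi/lib/utils/allocator.h"

namespace pten {

template <typename T, typename ContextT>
DenseTensor Relu(const ContextT& dev_ctx, const DenseTensor& x) {
  // Elementwise op: output meta is unchanged from the input.
  auto out_meta = UnchangedInferShape(x.meta());
  const auto allocator =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          dev_ctx.GetPlace());
  DenseTensor dense_out(allocator, out_meta);
  Relu<T>(dev_ctx, x, &dense_out);  // assumes an in-place kernel overload
  return dense_out;
}

}  // namespace pten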
diff --git a/paddle/pten/hapi/lib/utils/tensor_utils.cc b/paddle/pten/hapi/lib/utils/tensor_utils.cc
index a55c50db761a61..f7641f424f4910 100644
--- a/paddle/pten/hapi/lib/utils/tensor_utils.cc
+++ b/paddle/pten/hapi/lib/utils/tensor_utils.cc
@@ -45,6 +45,7 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
   SetLoD(&meta.lod, src.lod());
   auto shared_storage =
       pten::make_intrusive<SharedStorage>(src.Holder(), src.offset());
+
   return std::make_unique<pten::DenseTensor>(std::move(shared_storage),
                                              std::move(meta));
 }

diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc
index c436e14e0caab7..87c76149f127fe 100644
--- a/paddle/pten/kernels/cpu/manipulation.cc
+++ b/paddle/pten/kernels/cpu/manipulation.cc
@@ -24,10 +24,9 @@ void Flatten(const CPUContext& dev_ctx,
              int start_axis,
              int stop_axis,
              DenseTensor* out) {
-  auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis);
+  auto out_dims = out->dims();
   pten::Copy(dev_ctx, x, out);
-  out->set_lod(out_meta.lod);
-  out->Resize(out_meta.dims);
+  out->Resize(out_dims);
 }

 // TODO(yuanrisheng): this kernel is for training and xshape is a Intermediate

diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu
index 43614f859c58bf..38111f2b8c02fd 100644
--- a/paddle/pten/kernels/cuda/manipulation.cu
+++ b/paddle/pten/kernels/cuda/manipulation.cu
@@ -24,10 +24,9 @@ void Flatten(const CUDAContext& dev_ctx,
              int start_axis,
              int stop_axis,
              DenseTensor* out) {
-  auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis);
+  auto out_dims = out->dims();
   pten::Copy(dev_ctx, x, out);
-  out->set_lod(out_meta.lod);
-  out->Resize(out_meta.dims);
+  out->Resize(out_dims);
 }

 // TODO(yuanrisheng): this kernel is for training and xshape is a Intermediate

diff --git a/paddle/pten/tests/CMakeLists.txt b/paddle/pten/tests/CMakeLists.txt
index 3dc779380527f1..3d2da6a5afdd1a 100644
--- a/paddle/pten/tests/CMakeLists.txt
+++ b/paddle/pten/tests/CMakeLists.txt
@@ -12,3 +12,4 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS linalg_api pten_hapi_utils)
 cc_test(test_fill_api SRCS test_fill_api.cc DEPS creation_api pten_hapi_utils)
 cc_test(test_copy_api SRCS test_copy_api.cc DEPS utils_cpu pten_hapi_utils)
 cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS utils_cpu manipulation_api pten_hapi_utils)
+cc_test(test_scale_api SRCS test_scale_api.cc DEPS math_api pten_hapi_utils)

diff --git a/paddle/pten/tests/test_dot_api.cc b/paddle/pten/tests/test_dot_api.cc
index 69e785904fe3c9..5401b665444739 100644
--- a/paddle/pten/tests/test_dot_api.cc
+++ b/paddle/pten/tests/test_dot_api.cc
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/hapi/lib/utils/allocator.h"

+#include "paddle/pten/api/include/linalg.h"
+
 PT_DECLARE_MODULE(LinalgCPU);

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -82,3 +84,55 @@ TEST(API, dot) {
   ASSERT_NEAR(expect_result[1], actual_result1, 1e-6f);
   ASSERT_NEAR(expect_result[2], actual_result2, 1e-6f);
 }
+
+// TODO(YuanRisheng) This unittest should be created in another file.
+// That keeps compilation decoupled.
+TEST(DEV_API, dot) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(alloc,
+                            pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                  framework::make_ddim({3, 10}),
+                                                  pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x.mutable_data<float>();
+
+  pten::DenseTensor dense_y(alloc,
+                            pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                  framework::make_ddim({3, 10}),
+                                                  pten::DataLayout::NCHW));
+  auto* dense_y_data = dense_y.mutable_data<float>();
+
+  float sum[3] = {0.0, 0.0, 0.0};
+  for (size_t i = 0; i < 3; ++i) {
+    for (size_t j = 0; j < 10; ++j) {
+      dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
+      dense_y_data[i * 10 + j] = (i * 10 + j) * 1.0;
+      sum[i] += (i * 10 + j) * (i * 10 + j) * 1.0;
+    }
+  }
+
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+
+  // 2. test API
+  auto out = pten::Dot<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
+      dense_x,
+      dense_y);
+
+  // 3. check result
+  ASSERT_EQ(out.dims().size(), 2);
+  ASSERT_EQ(out.dims()[0], 3);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  auto expect_result = sum;
+  auto actual_result0 = out.data<float>()[0];
+  auto actual_result1 = out.data<float>()[1];
+  auto actual_result2 = out.data<float>()[2];
+  ASSERT_NEAR(expect_result[0], actual_result0, 1e-6f);
+  ASSERT_NEAR(expect_result[1], actual_result1, 1e-6f);
+  ASSERT_NEAR(expect_result[2], actual_result2, 1e-6f);
+}
diff --git a/paddle/pten/tests/test_fill_api.cc b/paddle/pten/tests/test_fill_api.cc
index 4f93e03aca2f3d..5a788226086dcc 100644
--- a/paddle/pten/tests/test_fill_api.cc
+++ b/paddle/pten/tests/test_fill_api.cc
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/hapi/lib/utils/allocator.h"

+#include "paddle/pten/api/include/creation.h"
+
 PT_DECLARE_MODULE(CreationCPU);

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -131,3 +133,38 @@ TEST(API, ones_like) {
     ASSERT_EQ(actual_result[i], 1);
   }
 }
+
+TEST(DEV_API, fill_any_like) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(alloc,
+                            pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                  framework::make_ddim({3, 2}),
+                                                  pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x.mutable_data<float>();
+  dense_x_data[0] = 0;
+  float val = 1.0;
+
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+
+  // 2. test API
+  auto out = pten::FillAnyLike<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
+      dense_x,
+      val);
+
+  // 3. check result
+  ASSERT_EQ(out.dims().size(), 2);
+  ASSERT_EQ(out.dims()[0], 3);
+  ASSERT_EQ(out.numel(), 6);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  auto* actual_result = out.data<float>();
+  for (auto i = 0; i < 6; i++) {
+    ASSERT_NEAR(actual_result[i], val, 1e-6f);
+  }
+}

diff --git a/paddle/pten/tests/test_flatten_api.cc b/paddle/pten/tests/test_flatten_api.cc
index 48d2205c2ff484..dfb777678a94d0 100644
--- a/paddle/pten/tests/test_flatten_api.cc
+++ b/paddle/pten/tests/test_flatten_api.cc
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/hapi/lib/utils/allocator.h"

+#include "paddle/pten/api/include/manipulation.h"
+
 PT_DECLARE_MODULE(ManipulationCPU);

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -70,3 +72,47 @@ TEST(API, flatten) {
   }
   ASSERT_EQ(value_equal, true);
 }
+
+TEST(DEV_API, flatten) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(
+      alloc,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 2, 2, 3}),
+                            pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x.mutable_data<float>();
+
+  for (int i = 0; i < dense_x.numel(); i++) {
+    dense_x_data[i] = i;
+  }
+  int start_axis = 1, stop_axis = 2;
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+
+  // 2. test API
+  auto out = pten::Flatten<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
+      dense_x,
+      start_axis,
+      stop_axis);
+
+  // 3. check result
+  std::vector<int> expect_shape = {3, 4, 3};
+  ASSERT_EQ(out.dims()[0], expect_shape[0]);
+  ASSERT_EQ(out.dims()[1], expect_shape[1]);
+  ASSERT_EQ(out.dims()[2], expect_shape[2]);
+  ASSERT_EQ(out.numel(), 36);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  bool value_equal = true;
+  auto* dense_out_data = out.data<float>();
+  for (int i = 0; i < dense_x.numel(); i++) {
+    if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f)
+      value_equal = false;
+  }
+  ASSERT_EQ(value_equal, true);
+}

diff --git a/paddle/pten/tests/test_mean_api.cc b/paddle/pten/tests/test_mean_api.cc
index ee8388671b7ebe..b3da90659d005a 100644
--- a/paddle/pten/tests/test_mean_api.cc
+++ b/paddle/pten/tests/test_mean_api.cc
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_registry.h"
 #include "paddle/pten/hapi/lib/utils/allocator.h"

+#include "paddle/pten/api/include/math.h"
+
 PT_DECLARE_MODULE(MathCPU);

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -67,3 +69,36 @@ TEST(API, mean) {
   auto actual_result = dense_out->data<float>()[0];
   ASSERT_NEAR(expect_result, actual_result, 1e-6f);
 }
+
+TEST(DEV_API, mean) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(alloc,
+                            pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                  framework::make_ddim({3, 4}),
+                                                  pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x.mutable_data<float>();
+
+  float sum = 0.0;
+  for (size_t i = 0; i < 12; ++i) {
+    dense_x_data[i] = i * 1.0;
+    sum += i * 1.0;
+  }
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+  // 2. test API
+  auto out = pten::Mean<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)), dense_x);
+
+  // 3. check result
+  ASSERT_EQ(out.dims().size(), 1);
+  ASSERT_EQ(out.numel(), 1);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  auto expect_result = sum / 12;
+  auto actual_result = out.data<float>()[0];
+  ASSERT_NEAR(expect_result, actual_result, 1e-6f);
+}

diff --git a/paddle/pten/tests/test_scale_api.cc b/paddle/pten/tests/test_scale_api.cc
new file mode 100644
index 00000000000000..9f80d6d2cc126b
--- /dev/null
+++ b/paddle/pten/tests/test_scale_api.cc
@@ -0,0 +1,118 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/pten/hapi/include/math.h"
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/hapi/lib/utils/allocator.h"
+
+#include "paddle/pten/api/include/math.h"
+
+PT_DECLARE_MODULE(MathCPU);
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PT_DECLARE_MODULE(MathCUDA);
+#endif
+
+namespace framework = paddle::framework;
+using DDim = paddle::framework::DDim;
+
+TEST(DEV_API, scale) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(alloc,
+                            pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                  framework::make_ddim({3, 4}),
+                                                  pten::DataLayout::NCHW));
+
+  auto* dense_x_data = dense_x.mutable_data<float>();
+  for (size_t i = 0; i < 12; ++i) {
+    dense_x_data[i] = i * 1.0;
+  }
+  float scale = 2;
+  float bias = 1;
+  bool bias_after_scale = true;
+
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+
+  // 2. test API
+  auto out = pten::Scale<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
+      dense_x,
+      scale,
+      bias,
+      bias_after_scale);
+
+  // 3. check result
+  ASSERT_EQ(out.dims().size(), 2);
+  ASSERT_EQ(out.numel(), 12);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  auto expect_result = 23;
+  auto actual_result = out.data<float>()[11];
+  ASSERT_NEAR(expect_result, actual_result, 1e-6f);
+}
+
+TEST(DEV_API, scale_host) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor dense_x(alloc,
+                            pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                  framework::make_ddim({3, 4}),
+                                                  pten::DataLayout::NCHW));
+  auto* dense_x_data = dense_x.mutable_data<float>();
+  for (size_t i = 0; i < 12; ++i) {
+    dense_x_data[i] = i * 1.0;
+  }
+  const auto alloc2 = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  pten::DenseTensor scale(alloc2,
+                          pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                                                framework::make_ddim({1}),
+                                                pten::DataLayout::NCHW));
+  scale.mutable_data<float>()[0] = 2;
+  float bias = 1;
+  bool bias_after_scale = true;
+
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());
+
+  // 2. test API
+  auto out = pten::Scale<float>(
+      *(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
+      dense_x,
+      scale,
+      bias,
+      bias_after_scale);
+
+  // 3. check result
+  ASSERT_EQ(out.dims().size(), 2);
+  ASSERT_EQ(out.numel(), 12);
+  ASSERT_EQ(out.meta().type, pten::DataType::FLOAT32);
+  ASSERT_EQ(out.meta().layout, pten::DataLayout::NCHW);
+
+  auto expect_result = 23;
+  auto actual_result = out.data<float>()[11];
+  ASSERT_NEAR(expect_result, actual_result, 1e-6f);
+}
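As a closing sanity check on the semantics both scale tests assert: scale computes x * scale + bias when bias_after_scale is true, and (x + bias) * scale otherwise, which is why element 11 yields 11 * 2 + 1 = 23 above. A standalone sketch of that rule (plain C++, no Paddle dependencies; scale_op is an illustrative stand-in, not Paddle's kernel):

#include <cassert>
#include <vector>

// Mirrors the documented scale semantics used by the tests above.
std::vector<float> scale_op(const std::vector<float>& x, float scale,
                            float bias, bool bias_after_scale) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = bias_after_scale ? x[i] * scale + bias : (x[i] + bias) * scale;
  }
  return out;
}

int main() {
  std::vector<float> x(12);
  for (size_t i = 0; i < x.size(); ++i) x[i] = static_cast<float>(i);
  auto out = scale_op(x, /*scale=*/2.0f, /*bias=*/1.0f,
                      /*bias_after_scale=*/true);
  assert(out[11] == 23.0f);  // 11 * 2 + 1, the value the tests assert
  return 0;
}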