diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index d317aac8594b4..16e63e433e640 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/op_call_stack.h"
+#include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/framework/shape_inference.h"
 #include "paddle/fluid/framework/transfer_scope_cache.h"
 #include "paddle/fluid/framework/unused_var_check.h"
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index b2d55babc7e1c..db26c66958140 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -16,6 +16,7 @@
 
 #include "paddle/fluid/framework/data_type_transform.h"
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
+#include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/utils/small_vector.h"
diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index 2ba82243fe6b5..1d361ae45dad9 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -25,6 +25,11 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
 
+// can only include headers from the paddle/pten/api dirs
+#include "paddle/pten/api/include/core.h"
+#include "paddle/pten/api/include/linalg.h"
+#include "paddle/pten/hapi/lib/utils/tensor_utils.h"
+
 #if defined(__NVCC__) || defined(__HIPCC__)
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 #endif
@@ -380,15 +385,17 @@ class MatMulV2Kernel : public framework::OpKernel<T> {
     auto* Out = ctx.Output<Tensor>("Out");
     bool trans_x = ctx.Attr<bool>("trans_x");
     bool trans_y = ctx.Attr<bool>("trans_y");
-    PADDLE_ENFORCE_NE(framework::product(X->dims()), 0,
-                      platform::errors::InvalidArgument(
-                          "The Input(X) dims size must not be equal 0,"
-                          " but reviced dims size is 0. "));
-    PADDLE_ENFORCE_NE(framework::product(Y->dims()), 0,
-                      platform::errors::InvalidArgument(
-                          "The Input(Y) dims size must not be equal 0,"
-                          " but reviced dims size is 0. "));
-    MatMulFunction<DeviceContext, T>(X, Y, Out, trans_x, trans_y, ctx);
+
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+    Out->mutable_data<T>(X->place());
+
+    auto pt_x = paddle::experimental::MakePtenDenseTensor(*X);
+    auto pt_y = paddle::experimental::MakePtenDenseTensor(*Y);
+    auto pt_out = paddle::experimental::MakePtenDenseTensor(*Out);
+
+    // call new kernel
+    pten::Matmul<T>(dev_ctx, *pt_x.get(), *pt_y.get(), trans_x, trans_y,
+                    pt_out.get());
   }
 };
 
diff --git a/paddle/pten/hapi/include/linalg.h b/paddle/pten/hapi/include/linalg.h
index fd628ea19334e..6e78b50af11c3 100644
--- a/paddle/pten/hapi/include/linalg.h
+++ b/paddle/pten/hapi/include/linalg.h
@@ -21,5 +21,10 @@ namespace experimental {
 
 Tensor dot(const Tensor& x, const Tensor& y);
 
+Tensor matmul(const Tensor& x,
+              const Tensor& y,
+              bool transpose_x,
+              bool transpose_y);
+
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/pten/hapi/lib/linalg.cc b/paddle/pten/hapi/lib/linalg.cc
index 54829feb43a24..3f13f546ee25e 100644
--- a/paddle/pten/hapi/lib/linalg.cc
+++ b/paddle/pten/hapi/lib/linalg.cc
@@ -25,7 +25,6 @@ limitations under the License. */
 #include "paddle/pten/core/kernel_context.h"
 #include "paddle/pten/hapi/lib/kernel_dispatch.h"
 #include "paddle/pten/hapi/lib/utils/allocator.h"
-#include "paddle/pten/infershape/binary.h"
 
 namespace paddle {
 namespace experimental {
@@ -65,5 +64,47 @@ Tensor dot(const Tensor& x, const Tensor& y) {
   return out;
 }
 
+Tensor matmul(const Tensor& x,
+              const Tensor& y,
+              bool transpose_x,
+              bool transpose_y) {
+  // 1. Get kernel signature and kernel
+  auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
+  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
+  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
+      "matmul_v2", kernel_key);
+
+  // 2. Get Device Context
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
+  auto kernel_context = pten::KernelContext(*dev_ctx);
+
+  // 3. Auto data transform
+  auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
+  auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
+  kernel_context.EmplaceBackInput(dense_x);
+  kernel_context.EmplaceBackInput(dense_y);
+  kernel_context.EmplaceBackAttr(transpose_x);
+  kernel_context.EmplaceBackAttr(transpose_y);
+  // TODO(chenweihang): add transform impl
+
+  // 4. InferShape
+  auto out_meta = MatmulInferShape(
+      dense_x->meta(), dense_y->meta(), transpose_x, transpose_y);
+
+  // 5. Prepare outputs
+  const auto allocator = std::make_shared<DefaultAllocator>(
+      pten::TransToFluidPlace(kernel_key.backend()));
+  auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
+  kernel_context.EmplaceBackOutput(dense_out);
+
+  Tensor out;
+  out.set_impl(dense_out);
+
+  // 6. Call kernel
+  kernel(&kernel_context);
+
+  return out;
+}
+
 }  // namespace experimental
 }  // namespace paddle
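Annotation (not part of the patch): a minimal sketch of how the new dynamic-dispatch C++ API above can be exercised once the pten kernel modules are linked in. The setup mirrors the unit test added at the end of this diff; `dense_x` / `dense_y` are assumed to be `std::shared_ptr<pten::DenseTensor>` built with a `DefaultAllocator`:

    #include "paddle/pten/hapi/include/linalg.h"

    paddle::experimental::Tensor x(dense_x);
    paddle::experimental::Tensor y(dense_y);
    // Selects the "matmul_v2" kernel from the inputs' backend/layout/dtype,
    // infers the output meta, runs the kernel, and returns the result.
    auto out = paddle::experimental::matmul(
        x, y, /*transpose_x=*/false, /*transpose_y=*/false);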
")); + + bool x_broadcasted = false, y_broadcasted = false; + if (ndims_x == 1) { + dims_x.insert(dims_x.begin(), 1); + ndims_x = 2; + x_broadcasted = true; + } + + if (ndims_y == 1) { + dims_y.push_back(1); + ndims_y = 2; + y_broadcasted = true; + } + + size_t M, N; + if (trans_x) { + M = dims_x[ndims_x - 1]; + } else { + M = dims_x[ndims_x - 2]; + } + if (trans_y) { + N = dims_y[ndims_y - 2]; + } else { + N = dims_y[ndims_y - 1]; + } + + std::vector new_dims; + if (ndims_x > ndims_y) { + new_dims.assign(dims_x.begin(), dims_x.end() - 2); + } else if (ndims_x < ndims_y) { + new_dims.assign(dims_y.begin(), dims_y.end() - 2); + } else { + new_dims.reserve(ndims_x); + for (size_t i = 0; i < ndims_x - 2; ++i) { + new_dims.push_back(std::max(dims_x[i], dims_y[i])); + } + } + if (!x_broadcasted) { + new_dims.push_back(M); + } + if (!y_broadcasted) { + new_dims.push_back(N); + } + if (x_broadcasted && y_broadcasted) { + new_dims.push_back(1); + } + + auto ddim_out = paddle::framework::make_ddim(new_dims); + + return {x_meta.type, ddim_out, x_meta.layout}; +} + } // namespace pten diff --git a/paddle/pten/infershape/binary.h b/paddle/pten/infershape/binary.h index 613d2f66a6edd..f58e5503f22a1 100644 --- a/paddle/pten/infershape/binary.h +++ b/paddle/pten/infershape/binary.h @@ -36,4 +36,9 @@ namespace pten { DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta, const DenseTensorMeta& y_meta); +DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta, + const DenseTensorMeta& y_meta, + bool trans_x, + bool trans_y); + } // namespace pten diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc index df401370c881f..ced13dc41d1ae 100644 --- a/paddle/pten/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -21,6 +21,8 @@ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/complex.h" +#include "paddle/pten/kernels/functions/math/matmul_func.h" + namespace pten { template @@ -45,6 +47,27 @@ void Dot(const CPUContext& dev_ctx, } } +template +void Matmul(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) { + PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but reviced dims size is 0. 
")); + math::MatMulFunction( + dev_ctx, x, y, out, transpose_x, transpose_y); +} + } // namespace pten PT_REGISTER_MODULE(LinalgCPU); @@ -62,3 +85,7 @@ PT_REGISTER_KERNEL("dot", int64_t, complex64, complex128) {} + +PT_REGISTER_KERNEL( + "matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) { +} diff --git a/paddle/pten/kernels/cpu/linalg.h b/paddle/pten/kernels/cpu/linalg.h index a9447be74934c..a954033866f17 100644 --- a/paddle/pten/kernels/cpu/linalg.h +++ b/paddle/pten/kernels/cpu/linalg.h @@ -30,7 +30,7 @@ void Dot(const CPUContext& dev_ctx, DenseTensor* out); template -void matmul(const CPUContext& dev_ctx, +void Matmul(const CPUContext& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, diff --git a/paddle/pten/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu index 928a09a4edbff..6811afa8a49ff 100644 --- a/paddle/pten/kernels/cuda/linalg.cu +++ b/paddle/pten/kernels/cuda/linalg.cu @@ -16,6 +16,7 @@ #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/kernels/functions/eigen/dot.h" +#include "paddle/pten/kernels/functions/math/matmul_func.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/complex.h" @@ -30,10 +31,32 @@ void Dot(const CUDAContext& dev_ctx, eigen::Dot(dev_ctx, x, y, out); } +template +void Matmul(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) { + PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(X) dims size must not be equal 0," + " but reviced dims size is 0. ")); + PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()), + 0, + paddle::platform::errors::InvalidArgument( + "The Input(Y) dims size must not be equal 0," + " but reviced dims size is 0. ")); + math::MatMulFunction( + dev_ctx, x, y, out, transpose_x, transpose_y); +} + } // namespace pten PT_REGISTER_MODULE(LinalgCUDA); +using float16 = paddle::platform::float16; using complex64 = ::paddle::platform::complex; using complex128 = ::paddle::platform::complex; @@ -47,3 +70,13 @@ PT_REGISTER_KERNEL("dot", int64_t, complex64, complex128) {} + +PT_REGISTER_KERNEL("matmul_v2", + CUDA, + ANY, + pten::Matmul, + float, + double, + float16, + complex64, + complex128) {} diff --git a/paddle/pten/kernels/cuda/linalg.h b/paddle/pten/kernels/cuda/linalg.h index ad38f71ec080a..a6489efa72eee 100644 --- a/paddle/pten/kernels/cuda/linalg.h +++ b/paddle/pten/kernels/cuda/linalg.h @@ -32,6 +32,14 @@ void Dot(const CUDAContext& dev_ctx, const DenseTensor& y, DenseTensor* out); +template +void Matmul(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out); + } // namespace pten #endif diff --git a/paddle/pten/kernels/cuda/utils.h b/paddle/pten/kernels/cuda/utils.h index a8a6838f4602a..0d79f04f2ee5e 100644 --- a/paddle/pten/kernels/cuda/utils.h +++ b/paddle/pten/kernels/cuda/utils.h @@ -14,6 +14,9 @@ limitations under the License. 
diff --git a/paddle/pten/kernels/cuda/utils.h b/paddle/pten/kernels/cuda/utils.h
index a8a6838f4602a..0d79f04f2ee5e 100644
--- a/paddle/pten/kernels/cuda/utils.h
+++ b/paddle/pten/kernels/cuda/utils.h
@@ -14,6 +14,9 @@ limitations under the License. */
 
 #pragma once
 
+// CUDA and HIP use the same API
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/kernel_registry.h"
 
@@ -26,3 +29,5 @@ using CUDAContext = paddle::platform::CUDADeviceContext;
 void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, DenseTensor* dst);
 
 }  // namespace pten
+
+#endif
diff --git a/paddle/pten/kernels/functions/math/matmul_func.h b/paddle/pten/kernels/functions/math/matmul_func.h
new file mode 100644
index 0000000000000..b5ddd26a95576
--- /dev/null
+++ b/paddle/pten/kernels/functions/math/matmul_func.h
@@ -0,0 +1,491 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+
+#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/kernels/functions/eigen/common.h"
+
+namespace pten {
+namespace math {
+
+static void GetBroadcastFromDims(const int x_ndim,
+                                 const std::int64_t* x_dims,
+                                 const int y_ndim,
+                                 const std::int64_t* y_dims,
+                                 std::int64_t* x_bd_dims,
+                                 std::int64_t* y_bd_dims,
+                                 std::int64_t* out_bd_dims) {
+  const int ndim = (std::max)(x_ndim, y_ndim);
+  std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1);
+  std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1);
+  std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim);
+  std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim);
+
+  for (int i = 0; i < ndim; ++i) {
+    PADDLE_ENFORCE_EQ(
+        x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1,
+        true,
+        paddle::platform::errors::InvalidArgument(
+            "Input(X) and Input(Y) have wrong dims."
+            "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s],"
+            "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1,"
+            "But received X_broadcast's shape[%s] = [%s]"
+            "received Y_broadcast's shape[%s] = [%s]",
+            i,
+            i,
+            i,
+            i,
+            i,
+            x_bd_dims[i],
+            i,
+            y_bd_dims[i]));
+    if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) {
+      out_bd_dims[i] = 0;
+    } else {
+      out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]);
+    }
+  }
+}
+
+ "X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s]," + "or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1," + "But received X_broadcast's shape[%s] = [%s]" + "received Y_broadcast's shape[%s] = [%s]", + i, + i, + i, + i, + i, + x_bd_dims[i], + i, + y_bd_dims[i])); + if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) { + out_bd_dims[i] = 0; + } else { + out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]); + } + } +} + +static int64_t GetIndexMessage(const int n, + const int64_t* dims, + const int64_t* index) { + int64_t sum = 0; + for (int i = 0; i < n; ++i) { + if (dims[i] > 1) { + sum = sum * dims[i] + index[i]; + } + } + return sum; +} + +static void IndexIncreaseFromDims(const int ndim, + const int64_t* dims, + int64_t* index) { + for (int i = ndim - 1; i >= 0; --i) { + ++index[i]; + if (index[i] >= dims[i]) { + index[i] -= dims[i]; + } else { + break; + } + } +} + +template +void MatMulFunction(const DeviceContext& dev_ctx, + const DenseTensor& X, + const DenseTensor& Y, + const std::vector& x_dims, + const std::vector& y_dims, + DenseTensor* Out, + bool trans_x, + bool trans_y, + bool flag = false) { + const int x_ndim = x_dims.size(); + const int y_ndim = y_dims.size(); + + // Get data ptr + const T* x_data = X.data(); + const T* y_data = Y.data(); + + if (x_ndim == 1 && y_ndim == 1) { + PADDLE_ENFORCE_EQ( + X.numel(), + Y.numel(), + paddle::platform::errors::InvalidArgument( + "X's numbers must be equal to Y's numbers," + "when X/Y's dims =1. But received X has [%d] elements," + "received Y has [%d] elements", + X.numel(), + Y.numel())); + VLOG(3) << "MatMul's case 1"; + Out->Resize({1}); + Out->mutable_data(); + auto out_eigen = EigenScalar::From(*Out); + auto x_eigen = EigenVector::Flatten(X); + auto y_eigen = EigenVector::Flatten(Y); + + auto& dev = *dev_ctx.eigen_device(); + if (flag) { + out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen; + } else { + out_eigen.device(dev) = (x_eigen * y_eigen).sum(); + } + return; + } + + auto blas = paddle::operators::math::GetBlas(dev_ctx); + + if (x_ndim == 1) { + const int N = X.numel(); + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], + N, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + N, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], + N, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." 
+ "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + N, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + std::vector out_dims(y_ndim - 1); + if (trans_y) { + std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin()); + } else { + std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin()); + out_dims.back() = y_dims.back(); + } + Out->Resize(paddle::framework::make_ddim(out_dims)); + Out->mutable_data(); + if (trans_y) { + const int M = Y.numel() / N; + VLOG(3) << "MatMul's case 2"; + blas.GEMV(false, + M, + N, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + const int M = y_dims[y_ndim - 1]; + const int batch_size = Y.numel() / (M * N); + if (batch_size == 1) { + VLOG(3) << "MatMul's case 3"; + blas.GEMV(true, + N, + M, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 4"; + blas.BatchedGEMM(CblasTrans, + CblasNoTrans, + M, + 1, + N, + static_cast(1), + y_data, + x_data, + static_cast(flag), + Out->mutable_data(), + batch_size, + M * N, + 0); + } + } + return; + } + + if (y_ndim == 1) { + const int N = Y.numel(); + if (trans_x) { + PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], + N, + paddle::platform::errors::InvalidArgument( + "Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 2, + N, + x_ndim - 2, + x_dims[x_ndim - 2])); + } else { + PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], + N, + paddle::platform::errors::InvalidArgument( + "Input(X) has error dim." + "X'dims[%d] must be equal to %d" + "But received X'dims[%d] is %d", + x_ndim - 1, + N, + x_ndim - 1, + x_dims[x_ndim - 1])); + } + std::vector out_dims(x_ndim - 1); + if (trans_x) { + std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin()); + out_dims.back() = x_dims.back(); + } else { + std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin()); + } + Out->Resize(paddle::framework::make_ddim(out_dims)); + Out->mutable_data(); + + if (trans_x) { + const int M = x_dims[x_ndim - 1]; + const int batch_size = X.numel() / (M * N); + if (batch_size == 1) { + VLOG(3) << "MatMul's case 5"; + blas.GEMV(true, + N, + M, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } else { + VLOG(3) << "MatMul's case 6"; + blas.BatchedGEMM(CblasTrans, + CblasNoTrans, + M, + 1, + N, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data(), + batch_size, + M * N, + 0); + } + } else { + const int M = X.numel() / N; + VLOG(3) << "MatMul's case 7"; + blas.GEMV(false, + M, + N, + static_cast(1), + x_data, + y_data, + static_cast(flag), + Out->mutable_data()); + } + return; + } + + const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; + const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; + if (trans_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], + K, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, + K, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], + K, + paddle::platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, + K, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + const int N = trans_y ? 
+  const int ndim = (std::max)(x_ndim, y_ndim);
+  std::vector<std::int64_t> x_broadcast_dims(ndim);
+  std::vector<std::int64_t> y_broadcast_dims(ndim);
+  std::vector<std::int64_t> out_broadcast_dims(ndim);
+
+  GetBroadcastFromDims(x_ndim - 2,
+                       x_dims.data(),
+                       y_ndim - 2,
+                       y_dims.data(),
+                       x_broadcast_dims.data(),
+                       y_broadcast_dims.data(),
+                       out_broadcast_dims.data());
+  out_broadcast_dims[ndim - 2] = M;
+  out_broadcast_dims[ndim - 1] = N;
+
+  Out->Resize(paddle::framework::make_ddim(out_broadcast_dims));
+  Out->mutable_data<T>();
+
+  const int batch_dim = ndim - 2;
+  // whether the batch dims need broadcasting
+  const bool is_broadcast_dims =
+      !std::equal(x_broadcast_dims.cbegin(),
+                  x_broadcast_dims.cbegin() + batch_dim,
+                  y_broadcast_dims.cbegin());
+
+  const std::int64_t x_batch_size =
+      std::accumulate(x_broadcast_dims.cbegin(),
+                      x_broadcast_dims.cbegin() + batch_dim,
+                      1LL,
+                      std::multiplies<std::int64_t>());
+  const std::int64_t y_batch_size =
+      std::accumulate(y_broadcast_dims.cbegin(),
+                      y_broadcast_dims.cbegin() + batch_dim,
+                      1LL,
+                      std::multiplies<std::int64_t>());
+  const std::int64_t out_batch_size =
+      std::accumulate(out_broadcast_dims.cbegin(),
+                      out_broadcast_dims.cbegin() + batch_dim,
+                      1LL,
+                      std::multiplies<std::int64_t>());
+  if (out_batch_size == 0) return;
+  if (x_batch_size == 1 && y_batch_size == 1) {
+    VLOG(3) << "MatMul's case 8";
+    blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
+              trans_y ? CblasTrans : CblasNoTrans,
+              M,
+              N,
+              K,
+              static_cast<T>(1),
+              x_data,
+              y_data,
+              static_cast<T>(flag),
+              Out->mutable_data<T>());
+  } else if (x_batch_size == 1) {
+    if (M == 1 && trans_y) {
+      VLOG(3) << "MatMul's case 9";
+      blas.GEMV(false,
+                y_batch_size * N,
+                K,
+                static_cast<T>(1),
+                y_data,
+                x_data,
+                static_cast<T>(flag),
+                Out->mutable_data<T>());
+    } else {
+      VLOG(3) << "MatMul's case 10";
+      blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
+                       trans_y ? CblasTrans : CblasNoTrans,
+                       M,
+                       N,
+                       K,
+                       static_cast<T>(1),
+                       x_data,
+                       y_data,
+                       static_cast<T>(flag),
+                       Out->mutable_data<T>(),
+                       out_batch_size,
+                       0,
+                       K * N);
+    }
+  } else if (y_batch_size == 1) {
+    if (!trans_x) {
+      VLOG(3) << "MatMul's case 11";
+      blas.GEMM(CblasNoTrans,
+                trans_y ? CblasTrans : CblasNoTrans,
+                x_batch_size * M,
+                N,
+                K,
+                static_cast<T>(1),
+                x_data,
+                y_data,
+                static_cast<T>(flag),
+                Out->mutable_data<T>());
+    } else {
+      VLOG(3) << "MatMul's case 12";
+      blas.BatchedGEMM(CblasTrans,
+                       trans_y ? CblasTrans : CblasNoTrans,
+                       M,
+                       N,
+                       K,
+                       static_cast<T>(1),
+                       x_data,
+                       y_data,
+                       static_cast<T>(flag),
+                       Out->mutable_data<T>(),
+                       out_batch_size,
+                       M * K,
+                       0);
+    }
+  } else if (!is_broadcast_dims) {
+    VLOG(3) << "MatMul's case 13";
+    blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
+                     trans_y ? CblasTrans : CblasNoTrans,
+                     M,
+                     N,
+                     K,
+                     static_cast<T>(1),
+                     x_data,
+                     y_data,
+                     static_cast<T>(flag),
+                     Out->mutable_data<T>(),
+                     out_batch_size,
+                     M * K,
+                     K * N);
+  } else {
+    // in this case, stridedgemm cannot be used
+    std::vector<const T*> x_ptr(out_batch_size);
+    std::vector<const T*> y_ptr(out_batch_size);
+    std::vector<T*> out_ptr(out_batch_size);
+    std::vector<std::int64_t> index(batch_dim, 0);
+    for (std::int64_t i = 0; i < out_batch_size; ++i) {
+      // use the index to compute the batch offsets
+      const std::int64_t x_index =
+          GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data());
+      const std::int64_t y_index =
+          GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data());
+
+      x_ptr[i] = x_data + x_index * M * K;
+      y_ptr[i] = y_data + y_index * K * N;
+      out_ptr[i] = Out->mutable_data<T>() + i * M * N;
+      IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
+    }
+    VLOG(3) << "MatMul's case 14";
+    blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
+                     trans_y ? CblasTrans : CblasNoTrans,
+                     M,
+                     N,
+                     K,
+                     static_cast<T>(1),
+                     x_ptr.data(),
+                     y_ptr.data(),
+                     static_cast<T>(flag),
+                     out_ptr.data(),
+                     out_batch_size);
+  }
+}
+
+template <typename DeviceContext, typename T>
+void MatMulFunction(const DeviceContext& dev_ctx,
+                    const DenseTensor& X,
+                    const DenseTensor& Y,
+                    DenseTensor* Out,
+                    bool trans_x,
+                    bool trans_y,
+                    bool flag = false) {
+  const std::vector<std::int64_t> x_dims = vectorize(X.dims());
+  const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
+  MatMulFunction<DeviceContext, T>(
+      dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
+}
+
+}  // namespace math
+}  // namespace pten
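Annotation (not part of the patch): the broadcast bookkeeping in case 14 is the subtlest part of MatMulFunction. The self-contained sketch below reproduces the GetIndexMessage / IndexIncreaseFromDims logic for x batch dims [2, 1] and y batch dims [1, 3] (out batch dims [2, 3]); it is a conceptual re-implementation for illustration, not code from this patch:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<std::int64_t> x_bd{2, 1}, y_bd{1, 3}, out_bd{2, 3};
      std::vector<std::int64_t> index(2, 0);
      for (std::int64_t i = 0; i < 6; ++i) {
        // GetIndexMessage: dims equal to 1 are skipped, so an operand
        // re-reads the same batch wherever it is broadcast.
        auto offset = [&](const std::vector<std::int64_t>& dims) {
          std::int64_t sum = 0;
          for (std::size_t d = 0; d < dims.size(); ++d) {
            if (dims[d] > 1) sum = sum * dims[d] + index[d];
          }
          return sum;
        };
        std::printf("out batch %lld <- x batch %lld, y batch %lld\n",
                    static_cast<long long>(i),
                    static_cast<long long>(offset(x_bd)),
                    static_cast<long long>(offset(y_bd)));
        // IndexIncreaseFromDims: odometer-style increment over out_bd.
        for (int d = 1; d >= 0; --d) {
          if (++index[d] < out_bd[d]) break;
          index[d] = 0;
        }
      }
      return 0;
    }

    // Prints x batch offsets 0,0,0,1,1,1 and y batch offsets 0,1,2,0,1,2:
    // each GEMM in the batch reads the operand slice that broadcasting
    // dictates, which is exactly how x_ptr / y_ptr are filled above.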
diff --git a/paddle/pten/tests/CMakeLists.txt b/paddle/pten/tests/CMakeLists.txt
index 27e76c87c6c0b..3dc779380527f 100644
--- a/paddle/pten/tests/CMakeLists.txt
+++ b/paddle/pten/tests/CMakeLists.txt
@@ -8,6 +8,7 @@ cc_test(dense_tensor_test SRCS dense_tensor_test.cc DEPS dense_tensor)
 cc_test(kernel_factory_test SRCS kernel_factory_test.cc DEPS kernel_factory)
 cc_test(test_mean_api SRCS test_mean_api.cc DEPS math_api pten_hapi_utils)
 cc_test(test_dot_api SRCS test_dot_api.cc DEPS linalg_api pten_hapi_utils)
+cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS linalg_api pten_hapi_utils)
 cc_test(test_fill_api SRCS test_fill_api.cc DEPS creation_api pten_hapi_utils)
 cc_test(test_copy_api SRCS test_copy_api.cc DEPS utils_cpu pten_hapi_utils)
 cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS utils_cpu manipulation_api pten_hapi_utils)
diff --git a/paddle/pten/tests/test_fill_api.cc b/paddle/pten/tests/test_fill_api.cc
index c19d14efaa976..4f93e03aca2f3 100644
--- a/paddle/pten/tests/test_fill_api.cc
+++ b/paddle/pten/tests/test_fill_api.cc
@@ -32,7 +32,6 @@ using DDim = paddle::framework::DDim;
 
 // TODO(chenweihang): Remove this test after the API is used in the dygraph
 TEST(API, full_like) {
-  // 1. create tensor
   // 1. create tensor
   const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
       paddle::platform::CPUPlace());
diff --git a/paddle/pten/tests/test_matmul_api.cc b/paddle/pten/tests/test_matmul_api.cc
new file mode 100644
index 0000000000000..b0579834519aa
--- /dev/null
+++ b/paddle/pten/tests/test_matmul_api.cc
@@ -0,0 +1,160 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <memory>
+
+#include "paddle/pten/hapi/include/linalg.h"
+
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/hapi/lib/utils/allocator.h"
+#include "paddle/pten/kernels/cuda/utils.h"
+
+PT_DECLARE_MODULE(LinalgCPU);
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PT_DECLARE_MODULE(LinalgCUDA);
+#endif
+
+namespace framework = paddle::framework;
+using DDim = paddle::framework::DDim;
+
+TEST(API, matmul_cpu) {
+  // 1. create tensor
+  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
+      paddle::platform::CPUPlace());
+  auto dense_x = std::make_shared<pten::DenseTensor>(
+      alloc,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+
+  auto* dense_x_data = dense_x->mutable_data<float>();
+
+  auto dense_y = std::make_shared<pten::DenseTensor>(
+      alloc,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+  auto* dense_y_data = dense_y->mutable_data<float>();
+
+  for (size_t i = 0; i < 9; ++i) {
+    dense_x_data[i] = 1.0;
+    dense_y_data[i] = 2.0;
+  }
+  std::vector<float> sum(9, 6.0);
+
+  paddle::experimental::Tensor x(dense_x);
+  paddle::experimental::Tensor y(dense_y);
+
+  // 2. test API
+  auto out = paddle::experimental::matmul(x, y, false, false);
+
+  // 3. check result
+  ASSERT_EQ(out.shape().size(), 2);
+  ASSERT_EQ(out.shape()[0], 3);
+  ASSERT_EQ(out.shape()[1], 3);
+  ASSERT_EQ(out.numel(), 9);
+  ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
+  ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
+  ASSERT_EQ(out.initialized(), true);
+
+  auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
+
+  for (size_t i = 0; i < 9; i++) {
+    ASSERT_NEAR(sum[i], dense_out->data<float>()[i], 1e-6f);
+  }
+}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+TEST(API, matmul_cuda) {
+  // Prepare CPU Dense Tensor
+  const auto alloc_cpu =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CPUPlace());
+  auto ref_x = std::make_shared<pten::DenseTensor>(
+      alloc_cpu,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+
+  auto* ref_x_data = ref_x->mutable_data<float>();
+
+  auto ref_y = std::make_shared<pten::DenseTensor>(
+      alloc_cpu,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+  auto* ref_y_data = ref_y->mutable_data<float>();
+
+  for (size_t i = 0; i < 9; ++i) {
+    ref_x_data[i] = 1.0;
+    ref_y_data[i] = 2.0;
+  }
+  std::vector<float> sum(9, 6.0);
+
+  // 1. create tensor
+  const auto alloc_cuda =
+      std::make_shared<paddle::experimental::DefaultAllocator>(
+          paddle::platform::CUDAPlace());
+  auto dense_x = std::make_shared<pten::DenseTensor>(
+      alloc_cuda,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+
+  auto dense_y = std::make_shared<pten::DenseTensor>(
+      alloc_cuda,
+      pten::DenseTensorMeta(pten::DataType::FLOAT32,
+                            framework::make_ddim({3, 3}),
+                            pten::DataLayout::NCHW));
+
+  auto& pool = paddle::platform::DeviceContextPool::Instance();
+  auto place = paddle::platform::CUDAPlace();
+  auto* dev_ctx = pool.GetByPlace(place);
+
+  pten::Copy(*dev_ctx, *ref_x.get(), dense_x.get());
+  pten::Copy(*dev_ctx, *ref_y.get(), dense_y.get());
+
+  paddle::experimental::Tensor x(dense_x);
+  paddle::experimental::Tensor y(dense_y);
+
+  // 2. test API
+  auto out = paddle::experimental::matmul(x, y, false, false);
+
+  // 3. check result
+  ASSERT_EQ(out.shape().size(), 2);
+  ASSERT_EQ(out.shape()[0], 3);
+  ASSERT_EQ(out.shape()[1], 3);
+  ASSERT_EQ(out.numel(), 9);
+  ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
+  ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
+  ASSERT_EQ(out.initialized(), true);
+
+  auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
+
+  auto ref_out = std::make_shared<pten::DenseTensor>(
+      alloc_cpu,
+      pten::DenseTensorMeta(
+          pten::DataType::FLOAT32, out.shape(), pten::DataLayout::NCHW));
+
+  pten::Copy(*dev_ctx, *dense_out.get(), ref_out.get());
+
+  for (size_t i = 0; i < 9; i++) {
+    ASSERT_NEAR(sum[i], ref_out->data<float>()[i], 1e-6f);
+  }
+}
+
+#endif
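Annotation (not part of the patch): with the cc_test target registered in paddle/pten/tests/CMakeLists.txt above, the new test should be runnable from the build directory after a rebuild, e.g. via `ctest -R test_matmul_api`; the exact invocation depends on the local CMake/ctest configuration.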