
add matmul kernel in pten
zyfncg committed Oct 28, 2021
1 parent be9df70 commit 0fb60d0
Showing 10 changed files with 786 additions and 2 deletions.
5 changes: 5 additions & 0 deletions paddle/pten/hapi/include/linalg.h
@@ -21,5 +21,10 @@ namespace experimental {

Tensor dot(const Tensor& x, const Tensor& y);

Tensor matmul(const Tensor& x,
              const Tensor& y,
              bool transpose_x,
              bool transpose_y);

} // namespace experimental
} // namespace paddle
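
The new high-level entry point mirrors `dot` directly above it. A minimal usage sketch follows (the signature and header path come from this diff; the wrapper function and its inputs are assumptions, not part of the commit):

// Sketch only: calling the matmul API added in this commit.
// run() and its inputs are hypothetical.
#include "paddle/pten/hapi/include/linalg.h"

using paddle::experimental::Tensor;

Tensor run(const Tensor& x, const Tensor& y) {
  // x: [M, K], y: [K, N] -> out: [M, N]; no transposes applied.
  return paddle::experimental::matmul(x, y, /*transpose_x=*/false,
                                      /*transpose_y=*/false);
}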
107 changes: 107 additions & 0 deletions paddle/pten/hapi/lib/creation copy.cc
@@ -0,0 +1,107 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/hapi/include/creation.h"

#include <memory>

#include "glog/logging.h"

#include "paddle/pten/api/include/core.h"
#include "paddle/pten/api/include/infershape.h"
#include "paddle/pten/hapi/lib/kernel_dispatch.h"

namespace paddle {
namespace experimental {

Tensor full_like(const Tensor& x,
                 const Scalar& value,
                 paddle::experimental::DataType dtype) {
  // 1. Get kernel signature and kernel
  CustomKernelKeyParser kernel_key_parser;
  Backend kernel_backend = kernel_key_parser.ParseBackend(x);
  DataLayout kernel_layout = kernel_key_parser.ParseLayout(x);
  DataType kernel_data_type = kernel_key_parser.ParseDataType(dtype);
  if (kernel_data_type == DataType::UNDEFINED) {
    kernel_data_type = kernel_key_parser.ParseDataType(x);
  }

  if (kernel_backend == Backend::UNDEFINED ||
      kernel_layout == DataLayout::UNDEFINED ||
      kernel_data_type == DataType::UNDEFINED) {
    auto kernel_key_set = ParseKernelKeyByInputArgs(x);
    auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
    if (kernel_backend == Backend::UNDEFINED) {
      kernel_backend = kernel_key.backend();
    }
    if (kernel_layout == DataLayout::UNDEFINED) {
      kernel_layout = kernel_key.layout();
    }
    if (kernel_data_type == DataType::UNDEFINED) {
      kernel_data_type = kernel_key.dtype();
    }
  }

  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
      "fill_any_like", {kernel_backend, kernel_layout, kernel_data_type});

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
  auto kernel_context = pten::KernelContext(*dev_ctx);

  // 3. Auto data transform
  auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
  kernel_context.EmplaceBackInput(dense_x);
  kernel_context.EmplaceBackAttr(value);

  // 4. InferShape
  auto out_meta = UnchangedInferShape(dense_x->meta());

  // 5. Prepare outputs
  Tensor out;
  // InferDataType
  if (dtype != pten::DataType::UNDEFINED) {
    out_meta.type = dtype;
  }
  auto dense_out =
      std::make_shared<pten::DenseTensor>(out_meta, pten::TensorStatus());
  kernel_context.EmplaceBackOutput(dense_out);
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

Tensor ones_like(const Tensor& x, DataType dtype) {
  return full_like(x, 1, dtype);
}

Tensor zeros_like(const Tensor& x, DataType dtype) {
  return full_like(x, 0, dtype);
}

} // namespace experimental
} // namespace paddle
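
A short usage sketch for the factory functions in this file (the signatures come from the code above; the `FLOAT32` enumerator spelling and the calling code are assumptions):

// Sketch only: full_like and friends as defined above.
using paddle::experimental::DataType;
using paddle::experimental::Tensor;

Tensor make_ones(const Tensor& x) {
  // Same shape as x, filled with 1s, as float32 (assumed enumerator).
  Tensor ones = paddle::experimental::ones_like(x, DataType::FLOAT32);
  // DataType::UNDEFINED keeps x's dtype, per the dtype check in full_like.
  Tensor zeros = paddle::experimental::zeros_like(x, DataType::UNDEFINED);
  (void)zeros;  // shown only for the dtype-forwarding behavior
  return ones;
}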
41 changes: 40 additions & 1 deletion paddle/pten/hapi/lib/linalg.cc
@@ -25,7 +25,6 @@ limitations under the License. */
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/hapi/lib/kernel_dispatch.h"
#include "paddle/pten/hapi/lib/utils/allocator.h"
#include "paddle/pten/infershape/binary.h"

namespace paddle {
namespace experimental {
@@ -65,5 +64,45 @@ Tensor dot(const Tensor& x, const Tensor& y) {
  return out;
}

Tensor matmul(const Tensor& x,
              const Tensor& y,
              bool transpose_x,
              bool transpose_y) {
  // 1. Get kernel signature and kernel
  auto kernel_key_set = ParseKernelKeyByInputArgs(x, y);
  auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
      "matmul_v2", kernel_key);

  // 2. Get Device Context
  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
  auto kernel_context = pten::KernelContext(*dev_ctx);

  // 3. Auto data transform
  auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
  auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
  kernel_context.EmplaceBackInput(dense_x);
  kernel_context.EmplaceBackInput(dense_y);
  // TODO(chenweihang): add transform impl

  // 4. InferShape
  auto out_meta = MatmulInferShape(
      dense_x->meta(), dense_y->meta(), transpose_x, transpose_y);

  // 5. Prepare outputs
  const auto allocator = std::make_shared<DefaultAllocator>(
      pten::TransToFluidPlace(kernel_key.backend()));
  auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
  kernel_context.EmplaceBackOutput(dense_out);

  Tensor out;
  out.set_impl(dense_out);

  // 6. Call kernel
  kernel(&kernel_context);

  return out;
}

} // namespace experimental
} // namespace paddle
70 changes: 70 additions & 0 deletions paddle/pten/infershape/binary.cc
@@ -59,4 +59,74 @@ DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
  return return_meta;
}

DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
                                 const DenseTensorMeta& y_meta,
                                 bool trans_x,
                                 bool trans_y) {
  std::vector<int64_t> dims_x = paddle::framework::vectorize(x_meta.dims);
  std::vector<int64_t> dims_y = paddle::framework::vectorize(y_meta.dims);
  auto ndims_x = dims_x.size();
  auto ndims_y = dims_y.size();
  PADDLE_ENFORCE_GT(ndims_x,
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(x) dims size must be greater than 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_GT(ndims_y,
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(y) dims size must be greater than 0,"
                        " but received dims size is 0. "));

  bool x_broadcasted = false, y_broadcasted = false;
  if (ndims_x == 1) {
    dims_x.insert(dims_x.begin(), 1);
    ndims_x = 2;
    x_broadcasted = true;
  }

  if (ndims_y == 1) {
    dims_y.push_back(1);
    ndims_y = 2;
    y_broadcasted = true;
  }

  size_t M, N;
  if (trans_x) {
    M = dims_x[ndims_x - 1];
  } else {
    M = dims_x[ndims_x - 2];
  }
  if (trans_y) {
    N = dims_y[ndims_y - 2];
  } else {
    N = dims_y[ndims_y - 1];
  }

  std::vector<int64_t> new_dims;
  if (ndims_x > ndims_y) {
    new_dims.assign(dims_x.begin(), dims_x.end() - 2);
  } else if (ndims_x < ndims_y) {
    new_dims.assign(dims_y.begin(), dims_y.end() - 2);
  } else {
    new_dims.reserve(ndims_x);
    for (size_t i = 0; i < ndims_x - 2; ++i) {
      new_dims.push_back(std::max(dims_x[i], dims_y[i]));
    }
  }
  if (!x_broadcasted) {
    new_dims.push_back(M);
  }
  if (!y_broadcasted) {
    new_dims.push_back(N);
  }
  if (x_broadcasted && y_broadcasted) {
    new_dims.push_back(1);
  }

  auto ddim_out = paddle::framework::make_ddim(new_dims);

  return {x_meta.type, ddim_out, x_meta.layout};
}

} // namespace pten
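
To make the broadcasting rules above concrete, here is a self-contained sketch (illustration only, not part of the commit) that applies the same shape arithmetic to plain shape vectors; it assumes non-empty input shapes, as the enforces above guarantee:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors MatmulInferShape's shape rule on plain vectors (illustration only).
std::vector<int64_t> MatmulOutShape(std::vector<int64_t> x,
                                    std::vector<int64_t> y,
                                    bool trans_x,
                                    bool trans_y) {
  bool x_bc = false, y_bc = false;
  if (x.size() == 1) { x.insert(x.begin(), 1); x_bc = true; }  // [K] -> [1, K]
  if (y.size() == 1) { y.push_back(1); y_bc = true; }          // [K] -> [K, 1]
  int64_t M = trans_x ? x.back() : x[x.size() - 2];
  int64_t N = trans_y ? y[y.size() - 2] : y.back();
  std::vector<int64_t> out;
  if (x.size() > y.size()) {
    out.assign(x.begin(), x.end() - 2);        // batch dims come from x
  } else if (x.size() < y.size()) {
    out.assign(y.begin(), y.end() - 2);        // batch dims come from y
  } else {
    for (size_t i = 0; i + 2 < x.size(); ++i)
      out.push_back(std::max(x[i], y[i]));     // elementwise batch broadcast
  }
  if (!x_bc) out.push_back(M);
  if (!y_bc) out.push_back(N);
  if (x_bc && y_bc) out.push_back(1);          // vector dot -> shape [1]
  return out;
}

int main() {
  // x: [2, 1, 3, 4], y: [2, 5, 4, 6] -> out: [2, 5, 3, 6]
  for (int64_t d : MatmulOutShape({2, 1, 3, 4}, {2, 5, 4, 6}, false, false))
    std::cout << d << ' ';  // prints: 2 5 3 6
}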
5 changes: 5 additions & 0 deletions paddle/pten/infershape/binary.h
@@ -36,4 +36,9 @@ namespace pten {
DenseTensorMeta DotInferShape(const DenseTensorMeta& x_meta,
                              const DenseTensorMeta& y_meta);

DenseTensorMeta MatmulInferShape(const DenseTensorMeta& x_meta,
                                 const DenseTensorMeta& y_meta,
                                 bool trans_x,
                                 bool trans_y);

} // namespace pten
27 changes: 27 additions & 0 deletions paddle/pten/kernels/cpu/linalg.cc
@@ -21,6 +21,8 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/complex.h"

#include "paddle/pten/kernels/functions/math/matmul_func.h"

namespace pten {

template <typename T>
@@ -45,6 +47,27 @@ void Dot(const CPUContext& dev_ctx,
  }
}

template <typename T>
void Matmul(const CPUContext& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            bool transpose_x,
            bool transpose_y,
            DenseTensor* out) {
  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(X) dims size must not be equal to 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(Y) dims size must not be equal to 0,"
                        " but received dims size is 0. "));
  math::MatMulFunction<CPUContext, T>(
      dev_ctx, x, y, out, transpose_x, transpose_y);
}

} // namespace pten

PT_REGISTER_MODULE(LinalgCPU);
@@ -62,3 +85,7 @@ PT_REGISTER_KERNEL("dot",
                   int64_t,
                   complex64,
                   complex128) {}

PT_REGISTER_KERNEL(
    "matmul_v2", CPU, ANY, pten::Matmul, float, double, complex64, complex128) {
}
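
Once registered, a kernel is resolvable by name and key, as the `full_like`/`matmul` dispatch code earlier in this commit does. A hedged lookup sketch follows (the factory call is taken from this diff; the exact enumerator spellings and namespacing are assumptions):

// Sketch only: fetching the CPU matmul kernel registered above.
// Backend/DataLayout/DataType enumerator spellings are assumed.
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
    "matmul_v2",
    {pten::Backend::CPU, pten::DataLayout::ANY, pten::DataType::FLOAT32});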
2 changes: 1 addition & 1 deletion paddle/pten/kernels/cpu/linalg.h
@@ -30,7 +30,7 @@ void Dot(const CPUContext& dev_ctx,
         DenseTensor* out);

template <typename T>
void matmul(const CPUContext& dev_ctx,
void Matmul(const CPUContext& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            bool transpose_x,
31 changes: 31 additions & 0 deletions paddle/pten/kernels/cuda/linalg.cu
@@ -16,6 +16,7 @@

#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/functions/eigen/dot.h"
#include "paddle/pten/kernels/functions/math/matmul_func.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/complex.h"
@@ -30,6 +31,27 @@ void Dot(const CUDAContext& dev_ctx,
  eigen::Dot<CUDAContext, T>(dev_ctx, x, y, out);
}

template <typename T>
void Matmul(const CUDAContext& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            bool transpose_x,
            bool transpose_y,
            DenseTensor* out) {
  PADDLE_ENFORCE_NE(paddle::framework::product(x.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(X) dims size must not be equal to 0,"
                        " but received dims size is 0. "));
  PADDLE_ENFORCE_NE(paddle::framework::product(y.dims()),
                    0,
                    paddle::platform::errors::InvalidArgument(
                        "The Input(Y) dims size must not be equal to 0,"
                        " but received dims size is 0. "));
  math::MatMulFunction<CUDAContext, T>(
      dev_ctx, x, y, out, transpose_x, transpose_y);
}

} // namespace pten

PT_REGISTER_MODULE(LinalgCUDA);
@@ -47,3 +69,12 @@ PT_REGISTER_KERNEL("dot",
                   int64_t,
                   complex64,
                   complex128) {}

PT_REGISTER_KERNEL("matmul_v2",
CUDA,
ANY,
pten::Matmul,
float,
double,
complex64,
complex128) {}
8 changes: 8 additions & 0 deletions paddle/pten/kernels/cuda/linalg.h
@@ -32,6 +32,14 @@ void Dot(const CUDAContext& dev_ctx,
         const DenseTensor& y,
         DenseTensor* out);

template <typename T>
void Matmul(const CUDAContext& dev_ctx,
            const DenseTensor& x,
            const DenseTensor& y,
            bool transpose_x,
            bool transpose_y,
            DenseTensor* out);

} // namespace pten

#endif