Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add paddle.linalg.solve OP #35715

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling segment_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
math_library(matrix_inverse)
math_library(segment_pooling)
math_library(matrix_solve)

cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/operators/math/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ class Blas {
template <typename T>
void BatchedMatInv(int n, const T** a, T** a_inv, int* info,
int batch_size) const;

// cuBlas solve
template <typename T>
void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a,
int lda, int* ipiv, T** b, int ldb, int* info,
int batch_size) const;
#endif

private:
Expand Down Expand Up @@ -402,6 +408,12 @@ class BlasT : private Blas<DeviceContext> {
void BatchedMatInv(ARGS... args) const {
Base()->template BatchedMatInv<T>(args...);
}

// solve
template <typename... ARGS>
void BatchedGETRS(ARGS... args) const {
Base()->template BatchedGETRS<T>(args...);
}
#endif

private:
Expand Down
26 changes: 26 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSmatinvBatched(args...));
}

template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSgetrsBatched(args...));
}
};

template <>
Expand Down Expand Up @@ -182,6 +188,12 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDmatinvBatched(args...));
}

template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDgetrsBatched(args...));
}
};

template <>
Expand Down Expand Up @@ -871,6 +883,20 @@ void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
});
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
T **b, int ldb, int *info, int batch_size) const {
// use CUBLAS_OP_C (conjugate transpose) for complex
cublasOperation_t cuTrans =
(trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info,
batch_size);
});
}

} // namespace math
} // namespace operators
} // namespace paddle
13 changes: 13 additions & 0 deletions paddle/fluid/operators/math/blas_impl.hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,19 @@ void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
});
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
T **b, int ldb, int *info, int batch_size) const {
rocblas_operation cuTrans = (trans == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info,
batch_size);
});
}
} // namespace math
} // namespace operators
} // namespace paddle
39 changes: 39 additions & 0 deletions paddle/fluid/operators/math/matrix_solve.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/matrix_solve.h"
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor& a, const framework::Tensor& b,
framework::Tensor* out) {
compute_solve_eigen<platform::CPUDeviceContext, T>(dev_ctx, a, b, out);
}
};

template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
216 changes: 216 additions & 0 deletions paddle/fluid/operators/math/matrix_solve.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace platform {
class CUDADeviceContext;
} // namespace platform
} // namespace paddle

namespace paddle {
namespace operators {
namespace math {

template <typename DeviceContext, typename T>
class MatrixSolveFunctor;

// for TransposeNormal, transpose the last two dimmentions
std::vector<int> getNewAxis(const int b_rank) {
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<int> axis_1 = {0};
std::vector<int> axis_2 = {1, 0};
std::vector<int> axis_3 = {0, 2, 1};
std::vector<int> axis_4 = {0, 1, 3, 2};
std::vector<int> axis_5 = {0, 1, 2, 4, 3};
std::vector<int> axis_6 = {0, 1, 2, 3, 5, 4};
std::vector<int> axis_7 = {0, 1, 2, 3, 4, 6, 5};
std::vector<int> axis_8 = {0, 1, 2, 3, 4, 5, 7, 6};
std::vector<int> axis_9 = {0, 1, 2, 3, 4, 5, 6, 8, 7};
switch (b_rank) {
case 1:
return axis_1;
break;
case 2:
return axis_2;
break;
case 3:
return axis_3;
break;
case 4:
return axis_4;
break;
case 5:
return axis_5;
break;
case 6:
return axis_6;
break;
case 7:
return axis_7;
break;
case 8:
return axis_8;
break;
default:
return axis_9;
}
}

// for Resize
std::vector<int64_t> getNewDimsVec(const DDim& b_dims) {
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
std::vector<int64_t> b_dims_vec = paddle::framework::vectorize(b_dims);
int size = b_dims_vec.size();
if (b_dims_vec.size() >= 2) {
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
// swap the last 2 elements in b_dims_vec
int64_t temp = b_dims_vec[size - 1];
b_dims_vec[size - 1] = b_dims_vec[size - 2];
b_dims_vec[size - 2] = temp;
return b_dims_vec;
}
// if b_dims_vec.size() == 1, just retun original vec
return b_dims_vec;
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
}

template <typename T>
class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& a, const framework::Tensor& b,
framework::Tensor* out) {
#ifndef PADDLE_WITH_HIP

veyron95 marked this conversation as resolved.
Show resolved Hide resolved
const auto& a_dims = a.dims();
const int a_rank = a_dims.size();
int n = a_dims[a_rank - 1];
int lda = n;
int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;

const auto& b_dims = b.dims();
const int b_rank = b_dims.size();
int nrhs = b_dims[b_rank - 1];
int ldb = b_dims[b_rank - 2];

// make sure the out dims is right
out->Resize(b_dims);
out->mutable_data<T>(context.GetPlace());

// copy input A to a temporary tensor tmp_a,
// LU factorization, written back to original matrix A, so in the beginning,
// it's necessary to create a temporary tensor tmp_a.
Tensor tmp_a(a.type());
tmp_a.Resize(a.dims());
tmp_a.mutable_data<T>(context.GetPlace());
TensorCopy(a, context.GetPlace(), &tmp_a);

// copy input B to a temporary tensor tmp_b, and transpose tmp_b,
// because cuBlas assumes column-major while Paddle uses row-majar.
Tensor tmp_b(b.type());
const auto& new_dims_vec = getNewDimsVec(b_dims);
tmp_b.Resize(framework::make_ddim(new_dims_vec));
tmp_b.mutable_data<T>(context.GetPlace());
math::TransposeNormal<platform::CUDADeviceContext, T> trans;
std::vector<int> new_axis = getNewAxis(b_rank);
trans(context, b, &tmp_b, new_axis);

memory::allocation::AllocationPtr tmp_a_data_in_gpu;
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
const T* a_data_in_gpu = tmp_a.data<T>();

std::vector<const T*> cpu_ptrs(batch_size * 2);
for (int i = 0; i < batch_size; ++i) {
cpu_ptrs[i] = a_data_in_gpu + i * n * n;
cpu_ptrs[i + batch_size] = tmp_b.data<T>() + i * n * nrhs;
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
}

// Copy the addresses of A and tmp_b from host to device.
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(),
static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*), context.stream());

T** gpu_tmp_b_ptrs =
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;

// Allocate device memory for BatchedGETRF's info and pivots.
int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
memory::allocation::AllocationPtr tmp_gpu_info_data =
memory::Alloc(context, num_ints * sizeof(int));
int* gpu_info_ptr = reinterpret_cast<int*>(tmp_gpu_info_data->ptr());

auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);

// only for singular checking
std::vector<int> info;
veyron95 marked this conversation as resolved.
Show resolved Hide resolved
info.resize(batch_size);

int* gpu_pivot_ptr =
reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;

// This function performs the LU factorization of each matrix A by the
// equation A = L * U. L and U are written back to original matrix A,
// and diagonal elements of L are discarded.
blas.BatchedGETRF(n, reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
gpu_pivot_ptr, gpu_info_ptr, batch_size);

// check whether BatchedGETRF is executed successfully or not
memory::Copy(platform::CPUPlace(), info.data(),
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
gpu_info_ptr, sizeof(int) * batch_size, context.stream());
for (int i = 0; i < batch_size; ++i) {
PADDLE_ENFORCE_EQ(info[i], 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: U(%d, %d) is zero, singular U. "
"Please check the matrix value and change it to a "
"non-singular matrix",
i, info[i], info[i]));
}

// hold the result code from BatchedGETRS
int host_info = 0;

// to solve the equation after LU factorization
CBLAS_TRANSPOSE transA = CblasTrans;
blas.BatchedGETRS(
transA, n, nrhs, reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
lda, gpu_pivot_ptr, gpu_tmp_b_ptrs, ldb, &host_info, batch_size);

// check whether BatchedGETRS is executed successfully or not
PADDLE_ENFORCE_EQ(host_info, 0,
platform::errors::InvalidArgument(
"The [%d]'th argument to cublas*getrsBatched had "
"an illegal value.",
-host_info));

// transpose tmp_b to get the final result in row-major form.
math::TransposeNormal<platform::CUDADeviceContext, T> trans2;
trans2(context, tmp_b, out, new_axis);

#else
compute_solve_eigen<platform::CUDADeviceContext, T>(context, a, b, out);
#endif
}
};

template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
Loading