Skip to content

Commit

Permalink
Add paddle.linalg.solve OP (PaddlePaddle#35715)
Browse files Browse the repository at this point in the history
* Add linalg.solve op, test=develop

* Fix a bug caused by accidental deletion

* updated description and fix a bug: missing a comma

* Add linalg.solve op, test=develop

* updated solve op backward logic

* updated solve op backward logic again

* Add linalg.solve Op, test=develop

* Updated and modified to fit CI requirements

* Fix a bug

* 1)Add more test cases; 2)Fix a wrong usage in reduces operation; 3)Remove redundant code

* Remove redundant comments

* 1)Removed redundant code; 2)Updated to enhance code robustness

* Removed redundant code

* Updated API documents
  • Loading branch information
veyron95 committed Sep 24, 2021
1 parent e9c0414 commit b4821f6
Show file tree
Hide file tree
Showing 19 changed files with 1,945 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling segment_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
math_library(matrix_inverse)
math_library(segment_pooling)
math_library(matrix_solve)

cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/operators/math/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,12 @@ class Blas {
template <typename T>
void BatchedMatInv(int n, const T** a, T** a_inv, int* info,
int batch_size) const;

// Batched linear solve that reuses a previously computed LU factorization.
// The CUDA specialization forwards to cublas<S/D>getrsBatched.
//   trans      - CblasNoTrans solves with each matrix as stored; any other
//                value is mapped to the transposed system.
//   n, nrhs    - forwarded unchanged to the backend (presumably the matrix
//                order and the number of right-hand-side columns - confirm
//                against the cuBLAS getrsBatched reference).
//   a, lda     - array of batch_size matrix pointers and their leading
//                dimension.
//   ipiv       - pivot indices produced by the factorization step.
//   b, ldb     - array of batch_size right-hand-side pointers and their
//                leading dimension; presumably overwritten with the solution.
//   info       - receives the backend status code.
//   batch_size - number of systems to solve.
template <typename T>
void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a,
int lda, int* ipiv, T** b, int ldb, int* info,
int batch_size) const;
#endif

private:
Expand Down Expand Up @@ -402,6 +408,12 @@ class BlasT : private Blas<DeviceContext> {
void BatchedMatInv(ARGS... args) const {
Base()->template BatchedMatInv<T>(args...);
}

// Batched solve: forwards every argument untouched to
// Blas<DeviceContext>::BatchedGETRS, pinning the element type to this
// wrapper's T so callers do not have to spell it out.
template <typename... ARGS>
void BatchedGETRS(ARGS... args) const {
Base()->template BatchedGETRS<T>(args...);
}
#endif

private:
Expand Down
26 changes: 26 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSmatinvBatched(args...));
}

// Single-precision batched solve: passes all arguments straight through to
// cublasSgetrsBatched and lets PADDLE_ENFORCE_CUDA_SUCCESS check the
// returned status.
template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSgetrsBatched(args...));
}
};

template <>
Expand Down Expand Up @@ -182,6 +188,12 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDmatinvBatched(args...));
}

// Double-precision batched solve: passes all arguments straight through to
// cublasDgetrsBatched and lets PADDLE_ENFORCE_CUDA_SUCCESS check the
// returned status.
template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDgetrsBatched(args...));
}
};

template <>
Expand Down Expand Up @@ -871,6 +883,20 @@ void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
});
}

// cuBLAS implementation of the batched solve.
// Translates the CBLAS transpose flag into its cuBLAS counterpart and then
// invokes CUBlas<T>::GETRS_BATCH with the handle supplied by
// context_.CublasCall; all remaining arguments are forwarded unchanged.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
    CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
    T **b, int ldb, int *info, int batch_size) const {
  // Anything other than CblasNoTrans is treated as a plain transpose.
  // use CUBLAS_OP_C (conjugate transpose) for complex
  cublasOperation_t op = CUBLAS_OP_T;
  if (trans == CblasNoTrans) {
    op = CUBLAS_OP_N;
  }

  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRS_BATCH(handle, op, n, nrhs, a, lda, ipiv, b, ldb, info,
                           batch_size);
  });
}

} // namespace math
} // namespace operators
} // namespace paddle
13 changes: 13 additions & 0 deletions paddle/fluid/operators/math/blas_impl.hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,19 @@ void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
});
}

// rocBLAS implementation of the batched solve.
// Translates the CBLAS transpose flag into its rocBLAS counterpart and then
// invokes CUBlas<T>::GETRS_BATCH with the handle supplied by
// context_.CublasCall; all remaining arguments are forwarded unchanged.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
    CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
    T **b, int ldb, int *info, int batch_size) const {
  // Anything other than CblasNoTrans is treated as a plain transpose.
  rocblas_operation op = rocblas_operation_transpose;
  if (trans == CblasNoTrans) {
    op = rocblas_operation_none;
  }

  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRS_BATCH(handle, op, n, nrhs, a, lda, ipiv, b, ldb, info,
                           batch_size);
  });
}
} // namespace math
} // namespace operators
} // namespace paddle
39 changes: 39 additions & 0 deletions paddle/fluid/operators/math/matrix_solve.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/matrix_solve.h"
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/operators/math/blas.h"

namespace paddle {
namespace operators {
namespace math {

// CPU specialization of the linear-system solver functor.
// The call operator does no work of its own: it hands the device context,
// both input tensors and the output tensor to compute_solve_eigen, which is
// declared outside this file.
template <typename T>
class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
 public:
  // dev_ctx - CPU device context passed through to the solver.
  // a, b    - input tensors of the system (coefficients and right-hand side).
  // out     - tensor that receives the result.
  void operator()(const platform::CPUDeviceContext& dev_ctx,
                  const framework::Tensor& a, const framework::Tensor& b,
                  framework::Tensor* out) {
    using CPUContext = platform::CPUDeviceContext;
    compute_solve_eigen<CPUContext, T>(dev_ctx, a, b, out);
  }
};

template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
168 changes: 168 additions & 0 deletions paddle/fluid/operators/math/matrix_solve.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace platform {
class CUDADeviceContext;
} // namespace platform
} // namespace paddle

namespace paddle {
namespace operators {
namespace math {

template <typename DeviceContext, typename T>
class MatrixSolveFunctor;

// GPU specialization of the linear-system solver functor (solves A * X = B).
//
// On CUDA builds every n x n matrix in `a` is LU-factorized with
// BatchedGETRF and the factorization is then applied to the matching
// right-hand side in `b` with BatchedGETRS. On HIP builds the work is
// delegated to compute_solve_eigen instead.
//
//   context - CUDA device context; provides the place, stream and BLAS handle.
//   a       - coefficient tensor; its last dimension is taken as n and, for
//             rank > 2, numel / (n * n) matrices are processed as a batch.
//   b       - right-hand-side tensor; its last two dimensions give ldb and
//             nrhs respectively.
//   out     - resized to b's dims and filled with the solution.
//
// Raises PreconditionNotMet when the factorization reports a non-zero info
// for any batch entry (singular matrix), and InvalidArgument when the solve
// step reports a non-zero info.
template <typename T>
class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& a, const framework::Tensor& b,
                  framework::Tensor* out) {
#ifndef PADDLE_WITH_HIP

    // solve the equation: Ax = B,
    // use cuBlas cublas<S/D>getrfBatched function to perform the LU
    // factorization of each matrix A,
    // and then use cuBlas cublas<S/D>getrsBatched function to solve the
    // equation after LU factorization.
    // ref:
    // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
    const auto& a_dims = a.dims();
    const int a_rank = a_dims.size();
    int n = a_dims[a_rank - 1];
    int lda = n;
    int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;

    const auto& b_dims = b.dims();
    const int b_rank = b_dims.size();
    int nrhs = b_dims[b_rank - 1];
    int ldb = b_dims[b_rank - 2];

    // make sure the out dims is right
    out->Resize(b_dims);
    out->mutable_data<T>(context.GetPlace());

    // copy input A to a temporary tensor tmp_a,
    // LU factorization, written back to original matrix A, so in the
    // beginning, it's necessary to create a temporary tensor tmp_a.
    Tensor tmp_a(a.type());
    tmp_a.Resize(a.dims());
    tmp_a.mutable_data<T>(context.GetPlace());
    TensorCopy(a, context.GetPlace(), &tmp_a);

    // copy input B to a temporary tensor tmp_b, and transpose tmp_b,
    // because cuBlas assumes column-major while Paddle uses row-major.
    Tensor tmp_b(b.type());
    const auto& new_dims_vec = getNewDimsVec(b_dims);
    tmp_b.Resize(framework::make_ddim(new_dims_vec));
    tmp_b.mutable_data<T>(context.GetPlace());
    math::TransposeNormal<platform::CUDADeviceContext, T> trans;
    std::vector<int> new_axis = getNewAxis(b_rank);
    trans(context, b, &tmp_b, new_axis);

    const T* a_data_in_gpu = tmp_a.data<T>();
    const T* b_data_in_gpu = tmp_b.data<T>();

    // Host-side table of per-batch device pointers: the first batch_size
    // entries address the A matrices, the next batch_size entries address the
    // (transposed) right-hand sides.
    std::vector<const T*> cpu_ptrs(batch_size * 2);
    for (int i = 0; i < batch_size; ++i) {
      cpu_ptrs[i] = a_data_in_gpu + i * n * n;
      cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs;
    }

    // Copy the addresses of A and tmp_b from host to device.
    memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
        memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                 tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(),
                 static_cast<void*>(cpu_ptrs.data()),
                 cpu_ptrs.size() * sizeof(T*), context.stream());

    T** gpu_tmp_b_ptrs =
        reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;

    // Allocate device memory for BatchedGETRF's info and pivots.
    // Layout: [batch_size info ints][batch_size * n pivot ints].
    // The pivot array is handed to BatchedGETRF for every n, so the pivot
    // region must always be reserved; sizing the buffer to batch_size alone
    // for small n would leave gpu_pivot_ptr pointing past the allocation.
    int num_ints = batch_size * (n + 1);
    memory::allocation::AllocationPtr tmp_gpu_info_data =
        memory::Alloc(context, num_ints * sizeof(int));
    int* gpu_info_ptr = reinterpret_cast<int*>(tmp_gpu_info_data->ptr());

    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);

    // only for singular checking
    std::vector<int> info;
    info.resize(batch_size);

    int* gpu_pivot_ptr =
        reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;

    // This function performs the LU factorization of each matrix A by the
    // equation A = L * U. L and U are written back to original matrix A,
    // and diagonal elements of L are discarded.
    blas.BatchedGETRF(n, reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
                      gpu_pivot_ptr, gpu_info_ptr, batch_size);

    // check whether BatchedGETRF is executed successfully or not
    // NOTE(review): the copy is issued on context.stream(); confirm it has
    // completed before `info` is inspected on the host.
    memory::Copy(platform::CPUPlace(), info.data(),
                 BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                 gpu_info_ptr, sizeof(int) * batch_size, context.stream());
    for (int i = 0; i < batch_size; ++i) {
      PADDLE_ENFORCE_EQ(info[i], 0,
                        platform::errors::PreconditionNotMet(
                            "For batch [%d]: U(%d, %d) is zero, singular U. "
                            "Please check the matrix value and change it to a "
                            "non-singular matrix",
                            i, info[i], info[i]));
    }

    // hold the result code from BatchedGETRS
    int host_info = 0;

    // to solve the equation after LU factorization
    // tmp_a holds row-major data that cuBLAS reads as column-major, i.e. as
    // the transpose of each A, so the transposed solve is requested to
    // recover the solution of the original system.
    CBLAS_TRANSPOSE transA = CblasTrans;
    blas.BatchedGETRS(
        transA, n, nrhs, reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
        lda, gpu_pivot_ptr, gpu_tmp_b_ptrs, ldb, &host_info, batch_size);

    // check whether BatchedGETRS is executed successfully or not
    PADDLE_ENFORCE_EQ(host_info, 0,
                      platform::errors::InvalidArgument(
                          "The [%d]'th argument to cublas*getrsBatched had "
                          "an illegal value.",
                          -host_info));

    // transpose tmp_b to get the final result in row-major form.
    math::TransposeNormal<platform::CUDADeviceContext, T> trans2;
    trans2(context, tmp_b, out, new_axis);

#else
    compute_solve_eigen<platform::CUDADeviceContext, T>(context, a, b, out);
#endif
  }
};

template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
Loading

0 comments on commit b4821f6

Please sign in to comment.