add new API/OP: paddle.linalg.triangular_solve #36714

Merged 6 commits on Oct 29, 2021
12 changes: 12 additions & 0 deletions paddle/fluid/operators/math/blas.h
@@ -253,6 +253,12 @@ class Blas {
void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a,
int lda, int* ipiv, T** b, int ldb, int* info,
int batch_size) const;

// cuBLAS batched triangular solve
template <typename T>
void BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA,
CBLAS_DIAG diag, int M, int N, T alpha, const T** a, int lda,
T** b, int ldb, int batch_size) const;
#endif

private:
@@ -414,6 +420,12 @@ class BlasT : private Blas<DeviceContext> {
void BatchedGETRS(ARGS... args) const {
Base()->template BatchedGETRS<T>(args...);
}

// triangular_solve
template <typename... ARGS>
void BatchedTRSM(ARGS... args) const {
Base()->template BatchedTRSM<T>(args...);
}
#endif

private:
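`BlasT` is a thin typed facade over the type-erased `Blas<DeviceContext>`: it pins the scalar type `T` once and forwards the remaining arguments, which is why the new `BatchedTRSM` needs only a three-line forwarder. Below is a minimal standalone sketch of that pattern; `FakeDevice` and the simplified signature are illustrative placeholders, not Paddle types.

#include <cstdio>

// Type-erased backend: one method template per BLAS routine.
template <typename Device>
struct Blas {
  template <typename T>
  void BatchedTRSM(int m, int n, T alpha) const {
    std::printf("typed solve: %d x %d, alpha = %g\n", m, n,
                static_cast<double>(alpha));
  }
};

// Typed facade: fixes T and forwards everything else, as in blas.h above.
template <typename Device, typename T>
struct BlasT : private Blas<Device> {
  template <typename... Args>
  void BatchedTRSM(Args... args) const {
    Base()->template BatchedTRSM<T>(args...);
  }

 private:
  const Blas<Device>* Base() const {
    return static_cast<const Blas<Device>*>(this);
  }
};

struct FakeDevice {};

int main() {
  BlasT<FakeDevice, float> blas;
  blas.BatchedTRSM(2, 3, 1.0f);  // T = float is supplied by the facade
  return 0;
}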
88 changes: 88 additions & 0 deletions paddle/fluid/operators/math/blas_impl.cu.h
@@ -120,6 +120,11 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSgetrsBatched(args...));
}

template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsmBatched(args...));
}
};

template <>
@@ -194,6 +199,11 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDgetrsBatched(args...));
}

template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsmBatched(args...));
}
};

template <>
@@ -339,6 +349,19 @@ struct CUBlas<platform::complex<float>> {
reinterpret_cast<cuFloatComplex *>(C), ldc));
}

static void TRSM(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<float> *alpha,
const paddle::platform::complex<float> *A, int lda,
paddle::platform::complex<float> *B, int ldb) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCtrsm(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(A), lda,
reinterpret_cast<cuFloatComplex *>(B), ldb));
}

// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
@@ -370,6 +393,20 @@ struct CUBlas<platform::complex<float>> {
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}

static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<float> *alpha,
const paddle::platform::complex<float> **A, int lda,
paddle::platform::complex<float> **B, int ldb,
int batch_size) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCtrsmBatched(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex **>(A), lda,
reinterpret_cast<cuFloatComplex **>(B), ldb, batch_size));
}
};

template <>
@@ -440,6 +477,33 @@ struct CUBlas<platform::complex<double>> {
reinterpret_cast<cuDoubleComplex *>(C), ldc));
}

static void TRSM(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<double> *alpha,
const paddle::platform::complex<double> *A, int lda,
paddle::platform::complex<double> *B, int ldb) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZtrsm(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(A), lda,
reinterpret_cast<cuDoubleComplex *>(B), ldb));
}

static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<double> *alpha,
const paddle::platform::complex<double> **A, int lda,
paddle::platform::complex<double> **B, int ldb,
int batch_size) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZtrsmBatched(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex **>(A), lda,
reinterpret_cast<cuDoubleComplex **>(B), ldb, batch_size));
}

// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
@@ -897,6 +961,30 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
});
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedTRSM(
CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
int M, int N, T alpha, const T **A, int lda, T **B, int ldb,
int batch_size) const {
// Solve the row-major system `op(A) X = α B` by treating it as the
// column-major system `X' op(A') = α B'`, where ' denotes transpose.
cublasSideMode_t cuSide =
(side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
cublasFillMode_t cuUplo =
(uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
// note: complex types would need CUBLAS_OP_C (conjugate transpose)
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasDiagType_t cuDiag =
(diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M,
&alpha, A, lda, B, ldb, batch_size);
});
}

} // namespace math
} // namespace operators
} // namespace paddle
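A short derivation of the side/uplo flip and the `N, M` argument order above: transposing the target system gives

op(A) X = α B  ⇔  X' op(A)' = α B'  (' = transpose)

and a row-major buffer handed to column-major cuBLAS is read as its transpose. The left-side, lower-triangular solve with an M×N right-hand side in row-major storage is therefore exactly a right-side, upper-triangular solve with an N×M right-hand side in column-major storage, on the same memory — hence `CUBLAS_SIDE_RIGHT`, `CUBLAS_FILL_MODE_UPPER`, and the swapped `N, M` in the call. For `float`/`double` the mapping is exact; complex types would additionally need `CUBLAS_OP_T` replaced by `CUBLAS_OP_C`, as the comment in the body notes.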
40 changes: 40 additions & 0 deletions paddle/fluid/operators/math/blas_impl.h
@@ -434,6 +434,17 @@ struct CBlas<platform::complex<float>> {
a_, lda, b_, ldb, &beta, c_, ldc);
}

static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N,
paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, int lda,
paddle::platform::complex<float> *B, int ldb) {
const void *a_ = static_cast<const void *>(A);
void *b_ = static_cast<void *>(B);
platform::dynload::cblas_ctrsm(layout, side, uplo, trans_a, diag, M, N,
&alpha, a_, lda, b_, ldb);
}

template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
@@ -562,6 +573,17 @@ struct CBlas<platform::complex<double>> {
a_, lda, b_, ldb, &beta, c_, ldc);
}

static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N,
paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, int lda,
paddle::platform::complex<double> *B, int ldb) {
const void *a_ = static_cast<const void *>(A);
void *b_ = static_cast<void *>(B);
platform::dynload::cblas_ztrsm(layout, side, uplo, trans_a, diag, M, N,
&alpha, a_, lda, b_, ldb);
}

template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
@@ -682,6 +704,15 @@ struct CBlas<platform::complex<float>> {
cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}

static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side,
const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA,
const CBLAS_DIAG diag, const int M, const int N,
const paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, const int lda,
paddle::platform::complex<float> *B, const int ldb) {
cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
};

template <>
@@ -720,6 +751,15 @@ struct CBlas<platform::complex<double>> {
cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}

static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side,
const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA,
const CBLAS_DIAG diag, const int M, const int N,
const paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, const int lda,
paddle::platform::complex<double> *B, const int ldb) {
cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
};

#endif
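For reference, these wrappers all share the TRSM contract: the routine solves `op(A) · X = α · B` (or `X · op(A) = α · B` for the right side) and overwrites `B` with `X` in place. A self-contained sketch against a generic CBLAS follows; it assumes OpenBLAS or another CBLAS implementation is linked, and the 2×2 values are illustrative only.

#include <cblas.h>
#include <cstdio>

int main() {
  // Row-major 2x2 lower-triangular A and right-hand side B:
  //   A = [2 0; 1 4],  B = [2 4; 5 10]  =>  X = [1 2; 1 2]
  float A[] = {2.0f, 0.0f, 1.0f, 4.0f};
  float B[] = {2.0f, 4.0f, 5.0f, 10.0f};  // overwritten with X in place

  // Solve A * X = 1.0 * B. For real types alpha is passed by value; the
  // complex variants (ctrsm/ztrsm, as in the wrappers above) take it by
  // pointer, which is why those wrappers pass &alpha.
  cblas_strsm(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans,
              CblasNonUnit, /*M=*/2, /*N=*/2, /*alpha=*/1.0f,
              A, /*lda=*/2, B, /*ldb=*/2);

  std::printf("X = [%g %g; %g %g]\n", B[0], B[1], B[2], B[3]);
  return 0;
}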
38 changes: 38 additions & 0 deletions paddle/fluid/operators/math/blas_impl.hip.h
@@ -90,6 +90,12 @@ struct CUBlas<float> {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSmatinvBatched is not supported on HIP platform."));
}

template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasStrsmBatched is not supported on HIP platform."));
}
};

template <>
@@ -153,6 +159,12 @@ struct CUBlas<double> {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDmatinvBatched is not supported on HIP platform."));
}

template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDtrsmBatched is not supported on HIP platform."));
}
};

template <>
@@ -730,6 +742,32 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
batch_size);
});
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedTRSM(
CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
int M, int N, T alpha, const T **A, int lda, T **B, int ldb,
int batch_size) const {
// Solve the row-major system `op(A) X = α B` by treating it as the
// column-major system `X' op(A') = α B'`, where ' denotes transpose.
rocblas_side cuSide =
(side == CblasLeft) ? rocblas_side_right : rocblas_side_left;
rocblas_fill cuUplo =
(uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower;
// note: complex types would need rocblas_operation_conjugate_transpose
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_diagonal cuDiag =
(diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;

context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M,
&alpha, A, lda, B, ldb, batch_size);
});
}

} // namespace math
} // namespace operators
} // namespace paddle
39 changes: 39 additions & 0 deletions paddle/fluid/operators/math/matrix_solve.cc
@@ -34,6 +34,45 @@ class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;

template <typename T>
class TriangularSolveFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor* a, framework::Tensor* b, bool left,
bool upper, bool transpose, bool unitriangular) {
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;

const T* a_data = a->data<T>();
T* b_data = b->mutable_data<T>(context.GetPlace());

int a_dim_size = a->dims().size();
int b_dim_size = b->dims().size();

int M = static_cast<int>(b->dims()[b_dim_size - 2]);
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
auto lda = left ? std::max(1, M) : std::max(1, N);
auto ldb = std::max(1, N);

int batch_size = 1;
auto& a_dim = a->dims();
for (int i = 0; i < a_dim_size - 2; i++) {
batch_size *= a_dim[i];
}

auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < batch_size; i++) {
blas.TRSM(side, uplo, transA, diag, M, N, T(1), a_data + i * M * M, lda,
b_data + i * N * M, ldb);
}
}
};

template class TriangularSolveFunctor<platform::CPUDeviceContext, float>;
template class TriangularSolveFunctor<platform::CPUDeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
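As a reference for what each iteration of the batch loop above computes (the lower-triangular, no-transpose case), here is a plain C++ stand-in for the `cblas_strsm` calls with the same row-major layout and the same per-batch strides. It is a minimal sketch for illustration, not the production path.

#include <cstdio>
#include <vector>

// For each batch entry, solve the lower-triangular system A x = b by
// forward substitution, in place in B.
void ForwardSubstBatch(const float* A, float* B, int M, int N, int batch) {
  for (int k = 0; k < batch; ++k) {
    const float* a = A + k * M * M;  // same A stride as the functor: M * M
    float* b = B + k * M * N;        // same B stride as the functor: M * N
    for (int col = 0; col < N; ++col) {
      for (int i = 0; i < M; ++i) {
        float s = b[i * N + col];
        for (int j = 0; j < i; ++j) s -= a[i * M + j] * b[j * N + col];
        b[i * N + col] = s / a[i * M + i];
      }
    }
  }
}

int main() {
  // One batch entry: A = [2 0; 1 4], B = [2; 5]  =>  x = [1; 1]
  std::vector<float> A = {2, 0, 1, 4};
  std::vector<float> B = {2, 5};
  ForwardSubstBatch(A.data(), B.data(), /*M=*/2, /*N=*/1, /*batch=*/1);
  std::printf("x = [%g, %g]\n", B[0], B[1]);
  return 0;
}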
62 changes: 62 additions & 0 deletions paddle/fluid/operators/math/matrix_solve.cu.cc
@@ -163,6 +163,68 @@ class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;

template <typename T>
class TriangularSolveFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const Tensor* a,
Tensor* b, bool left, bool upper, bool transpose,
bool unitriangular) {
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;

const T* a_data = a->data<T>();
T* b_data = b->mutable_data<T>(context.GetPlace());

int a_dim_size = a->dims().size();
int b_dim_size = b->dims().size();

int M = static_cast<int>(b->dims()[b_dim_size - 2]);
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
auto lda = left ? std::max(1, M) : std::max(1, N);
auto ldb = std::max(1, N);

int batch_size = 1;
auto& a_dim = a->dims();
for (int i = 0; i < a_dim_size - 2; i++) {
batch_size *= a_dim[i];
}

auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
if (batch_size <= 8 && M >= 64) {
for (auto i = 0; i < batch_size; i++) {
blas.TRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
a_data + i * M * M, lda, b_data + i * N * M, ldb);
}
} else {
std::vector<const T*> cpu_ptrs(batch_size * 2);
for (int i = 0; i < batch_size; ++i) {
cpu_ptrs[i] = a_data + i * M * M;
cpu_ptrs[i + batch_size] = b_data + i * M * N;
}

// Copy the addresses of A and B from host to device.
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(),
static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*), context.stream());

const T** gpu_a_ptrs =
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr());
T** gpu_b_ptrs =
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
blas.BatchedTRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
gpu_a_ptrs, lda, gpu_b_ptrs, ldb, batch_size);
}
}
};

template class TriangularSolveFunctor<platform::CUDADeviceContext, float>;
template class TriangularSolveFunctor<platform::CUDADeviceContext, double>;

} // namespace math
} // namespace operators
} // namespace paddle
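The `else` branch above uses the standard pointer-table pattern for batched cuBLAS: the batched routines take arrays of per-matrix device pointers, so the host builds one table covering both A and B and ships it in a single copy. A standalone sketch under those assumptions follows (column-major for simplicity; error handling trimmed; `d_A`/`d_B` are assumed to be pre-filled device buffers).

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <vector>

// Solve A_i * X_i = B_i for every batch entry i, overwriting B in place.
// d_A holds `batch` lower-triangular m x m matrices back to back; d_B holds
// `batch` m x n right-hand sides back to back (both column-major).
void BatchedTrsmDemo(float* d_A, float* d_B, int m, int n, int batch) {
  // Host-side pointer table: first half indexes A, second half indexes B,
  // mirroring the single cpu_ptrs vector in the functor above.
  std::vector<float*> h_ptrs(2 * batch);
  for (int i = 0; i < batch; ++i) {
    h_ptrs[i] = d_A + static_cast<size_t>(i) * m * m;
    h_ptrs[batch + i] = d_B + static_cast<size_t>(i) * m * n;
  }

  // One allocation and one copy cover both tables.
  float** d_ptrs = nullptr;
  cudaMalloc(&d_ptrs, h_ptrs.size() * sizeof(float*));
  cudaMemcpy(d_ptrs, h_ptrs.data(), h_ptrs.size() * sizeof(float*),
             cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f;
  cublasStrsmBatched(handle, CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
                     CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m, n, &alpha,
                     d_ptrs, /*lda=*/m, d_ptrs + batch, /*ldb=*/m, batch);
  cublasDestroy(handle);
  cudaFree(d_ptrs);
}

The `batch_size <= 8 && M >= 64` heuristic above reflects the usual trade-off: a few large matrices keep the GPU busy with per-matrix `trsm` calls, while many small matrices amortize launch overhead better through the batched kernel.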