Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lapack - SVD: fix for unit-test when MKL is enabled #2110

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ namespace Impl {
Kokkos::Profiling::pushRegion("KokkosBlas::gemm[TPL_BLAS," #SCALAR_TYPE \
"]"); \
const bool A_t = (transA[0] != 'N') && (transA[0] != 'n'); \
const int M = C.extent(0); \
const int N = C.extent(1); \
const int K = A.extent(A_t ? 0 : 1); \
const KK_INT M = C.extent(0); \
const KK_INT N = C.extent(1); \
const KK_INT K = A.extent(A_t ? 0 : 1); \
\
bool A_is_lr = std::is_same<Kokkos::LayoutRight, LAYOUTA>::value; \
bool B_is_lr = std::is_same<Kokkos::LayoutRight, LAYOUTB>::value; \
bool C_is_lr = std::is_same<Kokkos::LayoutRight, LAYOUTC>::value; \
\
const int AST = A_is_lr ? A.stride(0) : A.stride(1), \
LDA = AST == 0 ? 1 : AST; \
const int BST = B_is_lr ? B.stride(0) : B.stride(1), \
LDB = BST == 0 ? 1 : BST; \
const int CST = C_is_lr ? C.stride(0) : C.stride(1), \
LDC = CST == 0 ? 1 : CST; \
const KK_INT AST = A_is_lr ? A.stride(0) : A.stride(1), \
LDA = AST == 0 ? 1 : AST; \
const KK_INT BST = B_is_lr ? B.stride(0) : B.stride(1), \
LDB = BST == 0 ? 1 : BST; \
const KK_INT CST = C_is_lr ? C.stride(0) : C.stride(1), \
LDC = CST == 0 ? 1 : CST; \
\
const BASE_SCALAR_TYPE alpha_val = alpha, beta_val = beta; \
if (!A_is_lr && !B_is_lr && !C_is_lr) \
Expand Down
73 changes: 39 additions & 34 deletions blas/tpls/KokkosBlas_Host_tpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

#if defined(KOKKOSKERNELS_ENABLE_TPL_BLAS)

using KokkosBlas::Impl::KK_INT;

/// Fortran headers
extern "C" {

Expand Down Expand Up @@ -339,26 +341,27 @@ void F77_BLAS_MANGLE(ztrsv, ZTRSV)(const char*, const char*, const char*, int*,
/// Gemm
///

void F77_BLAS_MANGLE(sgemm, SGEMM)(const char*, const char*, int*, int*, int*,
const float*, const float*, int*,
const float*, int*, const float*,
/* */ float*, int*);
void F77_BLAS_MANGLE(dgemm, DGEMM)(const char*, const char*, int*, int*, int*,
const double*, const double*, int*,
const double*, int*, const double*,
/* */ double*, int*);
void F77_BLAS_MANGLE(cgemm, CGEMM)(const char*, const char*, int*, int*, int*,
const std::complex<float>*,
const std::complex<float>*, int*,
const std::complex<float>*, int*,
void F77_BLAS_MANGLE(sgemm, SGEMM)(const char*, const char*, KK_INT*, KK_INT*,
KK_INT*, const float*, const float*, KK_INT*,
const float*, KK_INT*, const float*,
/* */ float*, KK_INT*);
void F77_BLAS_MANGLE(dgemm, DGEMM)(const char*, const char*, KK_INT*, KK_INT*,
KK_INT*, const double*, const double*,
KK_INT*, const double*, KK_INT*,
const double*,
/* */ double*, KK_INT*);
void F77_BLAS_MANGLE(cgemm, CGEMM)(const char*, const char*, KK_INT*, KK_INT*,
KK_INT*, const std::complex<float>*,
const std::complex<float>*, KK_INT*,
const std::complex<float>*, KK_INT*,
const std::complex<float>*,
/* */ std::complex<float>*, int*);
void F77_BLAS_MANGLE(zgemm, ZGEMM)(const char*, const char*, int*, int*, int*,
/* */ std::complex<float>*, KK_INT*);
void F77_BLAS_MANGLE(zgemm, ZGEMM)(const char*, const char*, KK_INT*, KK_INT*,
KK_INT*, const std::complex<double>*,
const std::complex<double>*, KK_INT*,
const std::complex<double>*, KK_INT*,
const std::complex<double>*,
const std::complex<double>*, int*,
const std::complex<double>*, int*,
const std::complex<double>*,
/* */ std::complex<double>*, int*);
/* */ std::complex<double>*, KK_INT*);

///
/// Herk
Expand Down Expand Up @@ -632,10 +635,11 @@ void HostBlas<float>::trsv(const char uplo, const char transa, const char diag,
F77_FUNC_STRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb);
}
template <>
void HostBlas<float>::gemm(const char transa, const char transb, int m, int n,
int k, const float alpha, const float* a, int lda,
const float* b, int ldb, const float beta,
/* */ float* c, int ldc) {
void HostBlas<float>::gemm(const char transa, const char transb, KK_INT m,
KK_INT n, KK_INT k, const float alpha,
const float* a, KK_INT lda, const float* b,
KK_INT ldb, const float beta,
/* */ float* c, KK_INT ldc) {
F77_FUNC_SGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta,
c, &ldc);
}
Expand Down Expand Up @@ -750,10 +754,11 @@ void HostBlas<double>::trsv(const char uplo, const char transa, const char diag,
F77_FUNC_DTRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb);
}
template <>
void HostBlas<double>::gemm(const char transa, const char transb, int m, int n,
int k, const double alpha, const double* a, int lda,
const double* b, int ldb, const double beta,
/* */ double* c, int ldc) {
void HostBlas<double>::gemm(const char transa, const char transb, KK_INT m,
KK_INT n, KK_INT k, const double alpha,
const double* a, KK_INT lda, const double* b,
KK_INT ldb, const double beta,
/* */ double* c, KK_INT ldc) {
F77_FUNC_DGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta,
c, &ldc);
}
Expand Down Expand Up @@ -906,10 +911,10 @@ void HostBlas<std::complex<float> >::trsv(const char uplo, const char transa,
}
template <>
void HostBlas<std::complex<float> >::gemm(
const char transa, const char transb, int m, int n, int k,
const std::complex<float> alpha, const std::complex<float>* a, int lda,
const std::complex<float>* b, int ldb, const std::complex<float> beta,
/* */ std::complex<float>* c, int ldc) {
const char transa, const char transb, KK_INT m, KK_INT n, KK_INT k,
const std::complex<float> alpha, const std::complex<float>* a, KK_INT lda,
const std::complex<float>* b, KK_INT ldb, const std::complex<float> beta,
/* */ std::complex<float>* c, KK_INT ldc) {
F77_FUNC_CGEMM(&transa, &transb, &m, &n, &k, &alpha,
(const std::complex<float>*)a, &lda,
(const std::complex<float>*)b, &ldb, &beta,
Expand Down Expand Up @@ -1081,10 +1086,10 @@ void HostBlas<std::complex<double> >::trsv(const char uplo, const char transa,

template <>
void HostBlas<std::complex<double> >::gemm(
const char transa, const char transb, int m, int n, int k,
const std::complex<double> alpha, const std::complex<double>* a, int lda,
const std::complex<double>* b, int ldb, const std::complex<double> beta,
/* */ std::complex<double>* c, int ldc) {
const char transa, const char transb, KK_INT m, KK_INT n, KK_INT k,
const std::complex<double> alpha, const std::complex<double>* a, KK_INT lda,
const std::complex<double>* b, KK_INT ldb, const std::complex<double> beta,
/* */ std::complex<double>* c, KK_INT ldc) {
F77_FUNC_ZGEMM(&transa, &transb, &m, &n, &k, &alpha,
(const std::complex<double>*)a, &lda,
(const std::complex<double>*)b, &ldb, &beta,
Expand Down
17 changes: 13 additions & 4 deletions blas/tpls/KokkosBlas_Host_tpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@
#include "Kokkos_ArithTraits.hpp"

#if defined(KOKKOSKERNELS_ENABLE_TPL_BLAS)
#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
#include "mkl_types.h"
#endif

namespace KokkosBlas {
namespace Impl {

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
using KK_INT = MKL_INT;
#else
using KK_INT = int;
#endif

template <typename T>
struct HostBlas {
typedef Kokkos::ArithTraits<T> ats;
Expand Down Expand Up @@ -97,10 +106,10 @@ struct HostBlas {
const T *a, int lda,
/* */ T *b, int ldb);

static void gemm(const char transa, const char transb, int m, int n, int k,
const T alpha, const T *a, int lda, const T *b, int ldb,
const T beta,
/* */ T *c, int ldc);
static void gemm(const char transa, const char transb, KK_INT m, KK_INT n,
KK_INT k, const T alpha, const T *a, KK_INT lda, const T *b,
KK_INT ldb, const T beta,
/* */ T *c, KK_INT ldc);

static void herk(const char transa, const char transb, int n, int k,
const T alpha, const T *a, int lda, const T beta,
Expand Down
4 changes: 4 additions & 0 deletions lapack/unit_test/Test_Lapack_svd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,10 @@ int test_svd() {
Kokkos::View<ScalarA**, Kokkos::LayoutLeft, Device>;

ret = Test::impl_analytic_2x2_svd<view_type_a_layout_left, Device>();
EXPECT_EQ(ret, 0);

ret = Test::impl_analytic_2x3_svd<view_type_a_layout_left, Device>();
EXPECT_EQ(ret, 0);

ret = Test::impl_test_svd<view_type_a_layout_left, Device>(0, 0);
EXPECT_EQ(ret, 0);
Expand Down Expand Up @@ -558,8 +560,10 @@ int test_svd() {
Kokkos::View<ScalarA**, Kokkos::LayoutRight, Device>;

ret = Test::impl_analytic_2x2_svd<view_type_a_layout_right, Device>();
EXPECT_EQ(ret, 0);

ret = Test::impl_analytic_2x3_svd<view_type_a_layout_right, Device>();
EXPECT_EQ(ret, 0);

ret = Test::impl_test_svd<view_type_a_layout_right, Device>(0, 0);
EXPECT_EQ(ret, 0);
Expand Down
Loading