From 9202f86be315135075dd2017143d9588892043f0 Mon Sep 17 00:00:00 2001 From: Luc Berger-Vergiat Date: Fri, 16 Feb 2024 11:03:19 -0700 Subject: [PATCH] BLAS - MKL: fixing HostBlas calls to handle MKL_INT type MKL redefines the BLAS interface based on how MKL_INT is defined we need to wrap that definition with our own Kokkos Kernels INT type to make both compatible with regular BLAS. applying clang-format --- blas/tpls/KokkosBlas_Host_tpl.cpp | 822 +++++++++++++++--------------- blas/tpls/KokkosBlas_Host_tpl.hpp | 93 ++-- 2 files changed, 464 insertions(+), 451 deletions(-) diff --git a/blas/tpls/KokkosBlas_Host_tpl.cpp b/blas/tpls/KokkosBlas_Host_tpl.cpp index 68f2810907..50aab57c73 100644 --- a/blas/tpls/KokkosBlas_Host_tpl.cpp +++ b/blas/tpls/KokkosBlas_Host_tpl.cpp @@ -30,66 +30,68 @@ extern "C" { /// /// scal /// -void F77_BLAS_MANGLE(sscal, SSCAL)(const int* N, const float* alpha, - /* */ float* x, const int* x_inc); -void F77_BLAS_MANGLE(dscal, DSCAL)(const int* N, const double* alpha, - /* */ double* x, const int* x_inc); +void F77_BLAS_MANGLE(sscal, SSCAL)(const KK_INT* N, const float* alpha, + /* */ float* x, const KK_INT* x_inc); +void F77_BLAS_MANGLE(dscal, DSCAL)(const KK_INT* N, const double* alpha, + /* */ double* x, const KK_INT* x_inc); void F77_BLAS_MANGLE(cscal, - CSCAL)(const int* N, const std::complex* alpha, - /* */ std::complex* x, const int* x_inc); + CSCAL)(const KK_INT* N, const std::complex* alpha, + /* */ std::complex* x, const KK_INT* x_inc); void F77_BLAS_MANGLE(zscal, - ZSCAL)(const int* N, const std::complex* alpha, - /* */ std::complex* x, const int* x_inc); + ZSCAL)(const KK_INT* N, const std::complex* alpha, + /* */ std::complex* x, const KK_INT* x_inc); /// /// max /// -int F77_BLAS_MANGLE(isamax, ISAMAX)(const int* N, const float* x, - const int* x_inc); -int F77_BLAS_MANGLE(idamax, IDAMAX)(const int* N, const double* x, - const int* x_inc); -int F77_BLAS_MANGLE(icamax, ICAMAX)(const int* N, const std::complex* x, - const int* x_inc); -int F77_BLAS_MANGLE(izamax, IZAMAX)(const int* N, const std::complex* x, - const int* x_inc); +KK_INT F77_BLAS_MANGLE(isamax, ISAMAX)(const KK_INT* N, const float* x, + const KK_INT* x_inc); +KK_INT F77_BLAS_MANGLE(idamax, IDAMAX)(const KK_INT* N, const double* x, + const KK_INT* x_inc); +KK_INT F77_BLAS_MANGLE(icamax, ICAMAX)(const KK_INT* N, + const std::complex* x, + const KK_INT* x_inc); +KK_INT F77_BLAS_MANGLE(izamax, IZAMAX)(const KK_INT* N, + const std::complex* x, + const KK_INT* x_inc); /// /// nrm2 /// -float F77_BLAS_MANGLE(snrm2, SNRM2)(const int* N, const float* x, - const int* x_inc); -double F77_BLAS_MANGLE(dnrm2, DNRM2)(const int* N, const double* x, - const int* x_inc); -float F77_BLAS_MANGLE(scnrm2, SCNRM2)(const int* N, +float F77_BLAS_MANGLE(snrm2, SNRM2)(const KK_INT* N, const float* x, + const KK_INT* x_inc); +double F77_BLAS_MANGLE(dnrm2, DNRM2)(const KK_INT* N, const double* x, + const KK_INT* x_inc); +float F77_BLAS_MANGLE(scnrm2, SCNRM2)(const KK_INT* N, const std::complex* x, - const int* x_inc); -double F77_BLAS_MANGLE(dznrm2, DZNRM2)(const int* N, + const KK_INT* x_inc); +double F77_BLAS_MANGLE(dznrm2, DZNRM2)(const KK_INT* N, const std::complex* x, - const int* x_inc); + const KK_INT* x_inc); /// /// sum /// -float F77_BLAS_MANGLE(sasum, SASUM)(const int* N, const float* x, - const int* x_inc); -double F77_BLAS_MANGLE(dasum, DASUM)(const int* N, const double* x, - const int* x_inc); -float F77_BLAS_MANGLE(scasum, SCASUM)(const int* N, +float F77_BLAS_MANGLE(sasum, SASUM)(const KK_INT* N, const float* x, + const KK_INT* x_inc); +double F77_BLAS_MANGLE(dasum, DASUM)(const KK_INT* N, const double* x, + const KK_INT* x_inc); +float F77_BLAS_MANGLE(scasum, SCASUM)(const KK_INT* N, const std::complex* x, - const int* x_inc); -double F77_BLAS_MANGLE(dzasum, DZASUM)(const int* N, + const KK_INT* x_inc); +double F77_BLAS_MANGLE(dzasum, DZASUM)(const KK_INT* N, const std::complex* x, - const int* x_inc); + const KK_INT* x_inc); /// /// dot /// -float F77_BLAS_MANGLE(sdot, SDOT)(const int* N, const float* x, - const int* x_inc, const float* y, - const int* y_inc); -double F77_BLAS_MANGLE(ddot, DDOT)(const int* N, const double* x, - const int* x_inc, const double* y, - const int* y_inc); +float F77_BLAS_MANGLE(sdot, SDOT)(const KK_INT* N, const float* x, + const KK_INT* x_inc, const float* y, + const KK_INT* y_inc); +double F77_BLAS_MANGLE(ddot, DDOT)(const KK_INT* N, const double* x, + const KK_INT* x_inc, const double* y, + const KK_INT* y_inc); #if defined(KOKKOSKERNELS_TPL_BLAS_RETURN_COMPLEX) // clang-format off // For the return type, don't use std::complex, otherwise compiler will complain @@ -104,77 +106,78 @@ typedef struct { double vals[2]; } _kk_double2; -_kk_float2 F77_BLAS_MANGLE(cdotu, CDOTU)(const int* N, +_kk_float2 F77_BLAS_MANGLE(cdotu, CDOTU)(const KK_INT* N, const std::complex* x, - const int* x_inc, + const KK_INT* x_inc, const std::complex* y, - const int* y_inc); -_kk_double2 F77_BLAS_MANGLE(zdotu, ZDOTU)(const int* N, + const KK_INT* y_inc); +_kk_double2 F77_BLAS_MANGLE(zdotu, ZDOTU)(const KK_INT* N, const std::complex* x, - const int* x_inc, + const KK_INT* x_inc, const std::complex* y, - const int* y_inc); -_kk_float2 F77_BLAS_MANGLE(cdotc, CDOTC)(const int* N, + const KK_INT* y_inc); +_kk_float2 F77_BLAS_MANGLE(cdotc, CDOTC)(const KK_INT* N, const std::complex* x, - const int* x_inc, + const KK_INT* x_inc, const std::complex* y, - const int* y_inc); -_kk_double2 F77_BLAS_MANGLE(zdotc, ZDOTC)(const int* N, + const KK_INT* y_inc); +_kk_double2 F77_BLAS_MANGLE(zdotc, ZDOTC)(const KK_INT* N, const std::complex* x, - const int* x_inc, + const KK_INT* x_inc, const std::complex* y, - const int* y_inc); + const KK_INT* y_inc); #else void F77_BLAS_MANGLE(cdotu, - CDOTU)(std::complex* res, const int* N, - const std::complex* x, const int* x_inc, - const std::complex* y, const int* y_inc); + CDOTU)(std::complex* res, const KK_INT* N, + const std::complex* x, const KK_INT* x_inc, + const std::complex* y, const KK_INT* y_inc); void F77_BLAS_MANGLE(zdotu, - ZDOTU)(std::complex* res, const int* N, - const std::complex* x, const int* x_inc, - const std::complex* y, const int* y_inc); + ZDOTU)(std::complex* res, const KK_INT* N, + const std::complex* x, const KK_INT* x_inc, + const std::complex* y, const KK_INT* y_inc); void F77_BLAS_MANGLE(cdotc, - CDOTC)(std::complex* res, const int* N, - const std::complex* x, const int* x_inc, - const std::complex* y, const int* y_inc); + CDOTC)(std::complex* res, const KK_INT* N, + const std::complex* x, const KK_INT* x_inc, + const std::complex* y, const KK_INT* y_inc); void F77_BLAS_MANGLE(zdotc, - ZDOTC)(std::complex* res, const int* N, - const std::complex* x, const int* x_inc, - const std::complex* y, const int* y_inc); + ZDOTC)(std::complex* res, const KK_INT* N, + const std::complex* x, const KK_INT* x_inc, + const std::complex* y, const KK_INT* y_inc); #endif /// /// axpy /// -void F77_BLAS_MANGLE(saxpy, SAXPY)(const int* N, const float* alpha, - const float* x, const int* x_inc, - /* */ float* y, const int* y_inc); -void F77_BLAS_MANGLE(daxpy, DAXPY)(const int* N, const double* alpha, - const double* x, const int* x_inc, - /* */ double* y, const int* y_inc); +void F77_BLAS_MANGLE(saxpy, SAXPY)(const KK_INT* N, const float* alpha, + const float* x, const KK_INT* x_inc, + /* */ float* y, const KK_INT* y_inc); +void F77_BLAS_MANGLE(daxpy, DAXPY)(const KK_INT* N, const double* alpha, + const double* x, const KK_INT* x_inc, + /* */ double* y, const KK_INT* y_inc); void F77_BLAS_MANGLE(caxpy, - CAXPY)(const int* N, const std::complex* alpha, - const std::complex* x, const int* x_inc, - /* */ std::complex* y, const int* y_inc); + CAXPY)(const KK_INT* N, const std::complex* alpha, + const std::complex* x, const KK_INT* x_inc, + /* */ std::complex* y, const KK_INT* y_inc); void F77_BLAS_MANGLE(zaxpy, - ZAXPY)(const int* N, const std::complex* alpha, - const std::complex* x, const int* x_inc, - /* */ std::complex* y, const int* y_inc); + ZAXPY)(const KK_INT* N, const std::complex* alpha, + const std::complex* x, const KK_INT* x_inc, + /* */ std::complex* y, const KK_INT* y_inc); /// /// rot /// -void F77_BLAS_MANGLE(srot, SROT)(int const* N, float* X, int const* incx, - float* Y, int const* incy, float* c, float* s); -void F77_BLAS_MANGLE(drot, DROT)(int const* N, double* X, int const* incx, - double* Y, int const* incy, double* c, +void F77_BLAS_MANGLE(srot, SROT)(KK_INT const* N, float* X, KK_INT const* incx, + float* Y, KK_INT const* incy, float* c, + float* s); +void F77_BLAS_MANGLE(drot, DROT)(KK_INT const* N, double* X, KK_INT const* incx, + double* Y, KK_INT const* incy, double* c, double* s); -void F77_BLAS_MANGLE(crot, CROT)(int const* N, std::complex* X, - int const* incx, std::complex* Y, - int const* incy, float* c, float* s); -void F77_BLAS_MANGLE(zrot, ZROT)(int const* N, std::complex* X, - int const* incx, std::complex* Y, - int const* incy, double* c, double* s); +void F77_BLAS_MANGLE(crot, CROT)(KK_INT const* N, std::complex* X, + KK_INT const* incx, std::complex* Y, + KK_INT const* incy, float* c, float* s); +void F77_BLAS_MANGLE(zrot, ZROT)(KK_INT const* N, std::complex* X, + KK_INT const* incx, std::complex* Y, + KK_INT const* incy, double* c, double* s); /// /// rotg @@ -191,12 +194,12 @@ void F77_BLAS_MANGLE(zrotg, ZROTG)(std::complex* a, /// /// rotm /// -void F77_BLAS_MANGLE(srotm, SROTM)(const int* n, float* X, const int* incx, - float* Y, const int* incy, - float const* param); -void F77_BLAS_MANGLE(drotm, DROTM)(const int* n, double* X, const int* incx, - double* Y, const int* incy, - double const* param); +void F77_BLAS_MANGLE(srotm, SROTM)(const KK_INT* n, float* X, + const KK_INT* incx, float* Y, + const KK_INT* incy, float const* param); +void F77_BLAS_MANGLE(drotm, DROTM)(const KK_INT* n, double* X, + const KK_INT* incx, double* Y, + const KK_INT* incy, double const* param); /// /// rotmg @@ -209,72 +212,78 @@ void F77_BLAS_MANGLE(drotmg, DROTMG)(double* d1, double* d2, double* x1, /// /// swap /// -void F77_BLAS_MANGLE(sswap, SSWAP)(int const* N, float* X, int const* incx, - float* Y, int const* incy); -void F77_BLAS_MANGLE(dswap, DSWAP)(int const* N, double* X, int const* incx, - double* Y, int const* incy); -void F77_BLAS_MANGLE(cswap, CSWAP)(int const* N, std::complex* X, - int const* incx, std::complex* Y, - int const* incy); -void F77_BLAS_MANGLE(zswap, ZSWAP)(int const* N, std::complex* X, - int const* incx, std::complex* Y, - int const* incy); +void F77_BLAS_MANGLE(sswap, SSWAP)(KK_INT const* N, float* X, + KK_INT const* incx, float* Y, + KK_INT const* incy); +void F77_BLAS_MANGLE(dswap, DSWAP)(KK_INT const* N, double* X, + KK_INT const* incx, double* Y, + KK_INT const* incy); +void F77_BLAS_MANGLE(cswap, CSWAP)(KK_INT const* N, std::complex* X, + KK_INT const* incx, std::complex* Y, + KK_INT const* incy); +void F77_BLAS_MANGLE(zswap, ZSWAP)(KK_INT const* N, std::complex* X, + KK_INT const* incx, std::complex* Y, + KK_INT const* incy); /// /// Gemv /// -void F77_BLAS_MANGLE(sgemv, SGEMV)(const char*, int*, int*, const float*, - const float*, int*, const float*, int*, +void F77_BLAS_MANGLE(sgemv, SGEMV)(const char*, KK_INT*, KK_INT*, const float*, + const float*, KK_INT*, const float*, KK_INT*, const float*, - /* */ float*, int*); -void F77_BLAS_MANGLE(dgemv, DGEMV)(const char*, int*, int*, const double*, - const double*, int*, const double*, int*, - const double*, - /* */ double*, int*); -void F77_BLAS_MANGLE(cgemv, CGEMV)(const char*, int*, int*, + /* */ float*, KK_INT*); +void F77_BLAS_MANGLE(dgemv, DGEMV)(const char*, KK_INT*, KK_INT*, const double*, + const double*, KK_INT*, const double*, + KK_INT*, const double*, + /* */ double*, KK_INT*); +void F77_BLAS_MANGLE(cgemv, CGEMV)(const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, const std::complex*, - /* */ std::complex*, int*); -void F77_BLAS_MANGLE(zgemv, ZGEMV)(const char*, int*, int*, + /* */ std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zgemv, ZGEMV)(const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, const std::complex*, - /* */ std::complex*, int*); + /* */ std::complex*, KK_INT*); /// /// Ger /// -void F77_BLAS_MANGLE(sger, SGER)(int*, int*, const float*, const float*, int*, - const float*, int*, float*, int*); -void F77_BLAS_MANGLE(dger, DGER)(int*, int*, const double*, const double*, int*, - const double*, int*, double*, int*); -void F77_BLAS_MANGLE(cgeru, CGERU)(int*, int*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, - std::complex*, int*); -void F77_BLAS_MANGLE(zgeru, ZGERU)(int*, int*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, - std::complex*, int*); -void F77_BLAS_MANGLE(cgerc, CGERC)(int*, int*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, - std::complex*, int*); -void F77_BLAS_MANGLE(zgerc, ZGERC)(int*, int*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, - std::complex*, int*); +void F77_BLAS_MANGLE(sger, SGER)(KK_INT*, KK_INT*, const float*, const float*, + KK_INT*, const float*, KK_INT*, float*, + KK_INT*); +void F77_BLAS_MANGLE(dger, DGER)(KK_INT*, KK_INT*, const double*, const double*, + KK_INT*, const double*, KK_INT*, double*, + KK_INT*); +void F77_BLAS_MANGLE(cgeru, CGERU)(KK_INT*, KK_INT*, const std::complex*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zgeru, ZGERU)(KK_INT*, KK_INT*, + const std::complex*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); +void F77_BLAS_MANGLE(cgerc, CGERC)(KK_INT*, KK_INT*, const std::complex*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zgerc, ZGERC)(KK_INT*, KK_INT*, + const std::complex*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); /// /// Syr /// -void F77_BLAS_MANGLE(ssyr, SSYR)(const char*, int*, const float*, const float*, - int*, float*, int*); -void F77_BLAS_MANGLE(dsyr, DSYR)(const char*, int*, const double*, - const double*, int*, double*, int*); +void F77_BLAS_MANGLE(ssyr, SSYR)(const char*, KK_INT*, const float*, + const float*, KK_INT*, float*, KK_INT*); +void F77_BLAS_MANGLE(dsyr, DSYR)(const char*, KK_INT*, const double*, + const double*, KK_INT*, double*, KK_INT*); // Although there is a cgeru, there is no csyru // Although there is a zgeru, there is no zsyru // Although there is a cgerc, there is no csyrc, but there is cher (see below) @@ -284,22 +293,22 @@ void F77_BLAS_MANGLE(dsyr, DSYR)(const char*, int*, const double*, /// Her /// -void F77_BLAS_MANGLE(cher, CHER)(const char*, int*, const float*, - const std::complex*, int*, - std::complex*, int*); -void F77_BLAS_MANGLE(zher, ZHER)(const char*, int*, const double*, - const std::complex*, int*, - std::complex*, int*); +void F77_BLAS_MANGLE(cher, CHER)(const char*, KK_INT*, const float*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zher, ZHER)(const char*, KK_INT*, const double*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); /// /// Syr2 /// -void F77_BLAS_MANGLE(ssyr2, SSYR2)(const char*, int*, const float*, - const float*, const int*, const float*, int*, - float*, int*); -void F77_BLAS_MANGLE(dsyr2, DSYR2)(const char*, int*, const double*, - const double*, const int*, const double*, - int*, double*, int*); +void F77_BLAS_MANGLE(ssyr2, SSYR2)(const char*, KK_INT*, const float*, + const float*, const KK_INT*, const float*, + KK_INT*, float*, KK_INT*); +void F77_BLAS_MANGLE(dsyr2, DSYR2)(const char*, KK_INT*, const double*, + const double*, const KK_INT*, const double*, + KK_INT*, double*, KK_INT*); // Although there is a cgeru, there is no csyr2u // Although there is a zgeru, there is no zsyr2u // Although there is a cgerc, there is no csyr2c, but there is cher2 (see below) @@ -309,33 +318,34 @@ void F77_BLAS_MANGLE(dsyr2, DSYR2)(const char*, int*, const double*, /// Her2 /// -void F77_BLAS_MANGLE(cher2, CHER2)(const char*, int*, +void F77_BLAS_MANGLE(cher2, CHER2)(const char*, KK_INT*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, - std::complex*, int*); -void F77_BLAS_MANGLE(zher2, ZHER2)(const char*, int*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zher2, ZHER2)(const char*, KK_INT*, const std::complex*, - const std::complex*, int*, - const std::complex*, int*, - std::complex*, int*); + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, + std::complex*, KK_INT*); /// /// Trsv /// -void F77_BLAS_MANGLE(strsv, STRSV)(const char*, const char*, const char*, int*, - const float*, int*, - /* */ float*, int*); -void F77_BLAS_MANGLE(dtrsv, DTRSV)(const char*, const char*, const char*, int*, - const double*, int*, - /* */ double*, int*); -void F77_BLAS_MANGLE(ctrsv, CTRSV)(const char*, const char*, const char*, int*, - const std::complex*, int*, - /* */ std::complex*, int*); -void F77_BLAS_MANGLE(ztrsv, ZTRSV)(const char*, const char*, const char*, int*, - const std::complex*, int*, - /* */ std::complex*, int*); +void F77_BLAS_MANGLE(strsv, STRSV)(const char*, const char*, const char*, + KK_INT*, const float*, KK_INT*, + /* */ float*, KK_INT*); +void F77_BLAS_MANGLE(dtrsv, DTRSV)(const char*, const char*, const char*, + KK_INT*, const double*, KK_INT*, + /* */ double*, KK_INT*); +void F77_BLAS_MANGLE(ctrsv, CTRSV)(const char*, const char*, const char*, + KK_INT*, const std::complex*, KK_INT*, + /* */ std::complex*, KK_INT*); +void F77_BLAS_MANGLE(ztrsv, ZTRSV)(const char*, const char*, const char*, + KK_INT*, const std::complex*, + KK_INT*, + /* */ std::complex*, KK_INT*); /// /// Gemm @@ -367,82 +377,82 @@ void F77_BLAS_MANGLE(zgemm, ZGEMM)(const char*, const char*, KK_INT*, KK_INT*, /// Herk /// -void F77_BLAS_MANGLE(ssyrk, SSYRK)(const char*, const char*, int*, int*, - const float*, const float*, int*, +void F77_BLAS_MANGLE(ssyrk, SSYRK)(const char*, const char*, KK_INT*, KK_INT*, + const float*, const float*, KK_INT*, const float*, - /* */ float*, int*); -void F77_BLAS_MANGLE(dsyrk, DSYRK)(const char*, const char*, int*, int*, - const double*, const double*, int*, + /* */ float*, KK_INT*); +void F77_BLAS_MANGLE(dsyrk, DSYRK)(const char*, const char*, KK_INT*, KK_INT*, + const double*, const double*, KK_INT*, const double*, - /* */ double*, int*); -void F77_BLAS_MANGLE(cherk, CHERK)(const char*, const char*, int*, int*, + /* */ double*, KK_INT*); +void F77_BLAS_MANGLE(cherk, CHERK)(const char*, const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, + const std::complex*, KK_INT*, const std::complex*, - /* */ std::complex*, int*); -void F77_BLAS_MANGLE(zherk, ZHERK)(const char*, const char*, int*, int*, + /* */ std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zherk, ZHERK)(const char*, const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, + const std::complex*, KK_INT*, const std::complex*, - /* */ std::complex*, int*); + /* */ std::complex*, KK_INT*); /// /// Trmm /// void F77_BLAS_MANGLE(strmm, STRMM)(const char*, const char*, const char*, - const char*, int*, int*, const float*, - const float*, int*, - /* */ float*, int*); + const char*, KK_INT*, KK_INT*, const float*, + const float*, KK_INT*, + /* */ float*, KK_INT*); void F77_BLAS_MANGLE(dtrmm, DTRMM)(const char*, const char*, const char*, - const char*, int*, int*, const double*, - const double*, int*, - /* */ double*, int*); + const char*, KK_INT*, KK_INT*, const double*, + const double*, KK_INT*, + /* */ double*, KK_INT*); void F77_BLAS_MANGLE(ctrmm, CTRMM)(const char*, const char*, const char*, - const char*, int*, int*, + const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, - /* */ std::complex*, int*); + const std::complex*, KK_INT*, + /* */ std::complex*, KK_INT*); void F77_BLAS_MANGLE(ztrmm, ZTRMM)(const char*, const char*, const char*, - const char*, int*, int*, + const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, - /* */ std::complex*, int*); + const std::complex*, KK_INT*, + /* */ std::complex*, KK_INT*); /// /// Trsm /// void F77_BLAS_MANGLE(strsm, STRSM)(const char*, const char*, const char*, - const char*, int*, int*, const float*, - const float*, int*, - /* */ float*, int*); + const char*, KK_INT*, KK_INT*, const float*, + const float*, KK_INT*, + /* */ float*, KK_INT*); void F77_BLAS_MANGLE(dtrsm, DTRSM)(const char*, const char*, const char*, - const char*, int*, int*, const double*, - const double*, int*, - /* */ double*, int*); + const char*, KK_INT*, KK_INT*, const double*, + const double*, KK_INT*, + /* */ double*, KK_INT*); void F77_BLAS_MANGLE(ctrsm, CTRSM)(const char*, const char*, const char*, - const char*, int*, int*, + const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, - /* */ std::complex*, int*); + const std::complex*, KK_INT*, + /* */ std::complex*, KK_INT*); void F77_BLAS_MANGLE(ztrsm, ZTRSM)(const char*, const char*, const char*, - const char*, int*, int*, + const char*, KK_INT*, KK_INT*, const std::complex*, - const std::complex*, int*, - /* */ std::complex*, int*); + const std::complex*, KK_INT*, + /* */ std::complex*, KK_INT*); } -void F77_BLAS_MANGLE(sscal, SSCAL)(const int* N, const float* alpha, - /* */ float* x, const int* x_inc); -void F77_BLAS_MANGLE(dscal, DSCAL)(const int* N, const double* alpha, - /* */ double* x, const int* x_inc); +void F77_BLAS_MANGLE(sscal, SSCAL)(const KK_INT* N, const float* alpha, + /* */ float* x, const KK_INT* x_inc); +void F77_BLAS_MANGLE(dscal, DSCAL)(const KK_INT* N, const double* alpha, + /* */ double* x, const KK_INT* x_inc); void F77_BLAS_MANGLE(cscal, - CSCAL)(const int* N, const std::complex* alpha, - /* */ std::complex* x, const int* x_inc); + CSCAL)(const KK_INT* N, const std::complex* alpha, + /* */ std::complex* x, const KK_INT* x_inc); void F77_BLAS_MANGLE(zscal, - ZSCAL)(const int* N, const std::complex* alpha, - /* */ std::complex* x, const int* x_inc); + ZSCAL)(const KK_INT* N, const std::complex* alpha, + /* */ std::complex* x, const KK_INT* x_inc); #define F77_FUNC_SSCAL F77_BLAS_MANGLE(sscal, SSCAL) #define F77_FUNC_DSCAL F77_BLAS_MANGLE(dscal, DSCAL) @@ -554,35 +564,36 @@ namespace Impl { /// template <> -void HostBlas::scal(int n, const float alpha, - /* */ float* x, int x_inc) { +void HostBlas::scal(KK_INT n, const float alpha, + /* */ float* x, KK_INT x_inc) { F77_FUNC_SSCAL(&n, &alpha, x, &x_inc); } template <> -int HostBlas::iamax(int n, const float* x, int x_inc) { +KK_INT HostBlas::iamax(KK_INT n, const float* x, KK_INT x_inc) { return F77_FUNC_ISAMAX(&n, x, &x_inc); } template <> -float HostBlas::nrm2(int n, const float* x, int x_inc) { +float HostBlas::nrm2(KK_INT n, const float* x, KK_INT x_inc) { return F77_FUNC_SNRM2(&n, x, &x_inc); } template <> -float HostBlas::asum(int n, const float* x, int x_inc) { +float HostBlas::asum(KK_INT n, const float* x, KK_INT x_inc) { return F77_FUNC_SASUM(&n, x, &x_inc); } template <> -float HostBlas::dot(int n, const float* x, int x_inc, const float* y, - int y_inc) { +float HostBlas::dot(KK_INT n, const float* x, KK_INT x_inc, + const float* y, KK_INT y_inc) { return F77_FUNC_SDOT(&n, x, &x_inc, y, &y_inc); } template <> -void HostBlas::axpy(int n, const float alpha, const float* x, int x_inc, - /* */ float* y, int y_inc) { +void HostBlas::axpy(KK_INT n, const float alpha, const float* x, + KK_INT x_inc, + /* */ float* y, KK_INT y_inc) { F77_FUNC_SAXPY(&n, &alpha, x, &x_inc, y, &y_inc); } template <> -void HostBlas::rot(int const N, float* X, int const incx, float* Y, - int const incy, float* c, float* s) { +void HostBlas::rot(KK_INT const N, float* X, KK_INT const incx, float* Y, + KK_INT const incy, float* c, float* s) { F77_FUNC_SROT(&N, X, &incx, Y, &incy, c, s); } template <> @@ -590,8 +601,8 @@ void HostBlas::rotg(float* a, float* b, float* c, float* s) { F77_FUNC_SROTG(a, b, c, s); } template <> -void HostBlas::rotm(const int n, float* X, const int incx, float* Y, - const int incy, const float* param) { +void HostBlas::rotm(const KK_INT n, float* X, const KK_INT incx, + float* Y, const KK_INT incy, const float* param) { F77_FUNC_SROTM(&n, X, &incx, Y, &incy, param); } template <> @@ -600,38 +611,38 @@ void HostBlas::rotmg(float* d1, float* d2, float* x1, const float* y1, F77_FUNC_SROTMG(d1, d2, x1, y1, param); } template <> -void HostBlas::swap(int const N, float* X, int const incx, float* Y, - int const incy) { +void HostBlas::swap(KK_INT const N, float* X, KK_INT const incx, + float* Y, KK_INT const incy) { F77_FUNC_SSWAP(&N, X, &incx, Y, &incy); } template <> -void HostBlas::gemv(const char trans, int m, int n, const float alpha, - const float* a, int lda, const float* b, int ldb, - const float beta, - /* */ float* c, int ldc) { +void HostBlas::gemv(const char trans, KK_INT m, KK_INT n, + const float alpha, const float* a, KK_INT lda, + const float* b, KK_INT ldb, const float beta, + /* */ float* c, KK_INT ldc) { F77_FUNC_SGEMV(&trans, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } template <> -void HostBlas::ger(int m, int n, const float alpha, const float* x, - int incx, const float* y, int incy, float* a, - int lda) { +void HostBlas::ger(KK_INT m, KK_INT n, const float alpha, const float* x, + KK_INT incx, const float* y, KK_INT incy, float* a, + KK_INT lda) { F77_FUNC_SGER(&m, &n, &alpha, x, &incx, y, &incy, a, &lda); } template <> -void HostBlas::syr(const char uplo, int n, const float alpha, - const float* x, int incx, float* a, int lda) { +void HostBlas::syr(const char uplo, KK_INT n, const float alpha, + const float* x, KK_INT incx, float* a, KK_INT lda) { F77_FUNC_SSYR(&uplo, &n, &alpha, x, &incx, a, &lda); } template <> -void HostBlas::syr2(const char uplo, int n, const float alpha, - const float* x, int incx, const float* y, int incy, - float* a, int lda) { +void HostBlas::syr2(const char uplo, KK_INT n, const float alpha, + const float* x, KK_INT incx, const float* y, + KK_INT incy, float* a, KK_INT lda) { F77_FUNC_SSYR2(&uplo, &n, &alpha, x, &incx, y, &incy, a, &lda); } template <> void HostBlas::trsv(const char uplo, const char transa, const char diag, - int m, const float* a, int lda, - /* */ float* b, int ldb) { + KK_INT m, const float* a, KK_INT lda, + /* */ float* b, KK_INT ldb) { F77_FUNC_STRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb); } template <> @@ -644,25 +655,25 @@ void HostBlas::gemm(const char transa, const char transb, KK_INT m, c, &ldc); } template <> -void HostBlas::herk(const char transa, const char transb, int n, int k, - const float alpha, const float* a, int lda, - const float beta, - /* */ float* c, int ldc) { +void HostBlas::herk(const char transa, const char transb, KK_INT n, + KK_INT k, const float alpha, const float* a, + KK_INT lda, const float beta, + /* */ float* c, KK_INT ldc) { F77_FUNC_SSYRK(&transa, &transb, &n, &k, &alpha, a, &lda, &beta, c, &ldc); } template <> void HostBlas::trmm(const char side, const char uplo, const char transa, - const char diag, int m, int n, const float alpha, - const float* a, int lda, - /* */ float* b, int ldb) { + const char diag, KK_INT m, KK_INT n, + const float alpha, const float* a, KK_INT lda, + /* */ float* b, KK_INT ldb) { F77_FUNC_STRMM(&side, &uplo, &transa, &diag, &m, &n, &alpha, a, &lda, b, &ldb); } template <> void HostBlas::trsm(const char side, const char uplo, const char transa, - const char diag, int m, int n, const float alpha, - const float* a, int lda, - /* */ float* b, int ldb) { + const char diag, KK_INT m, KK_INT n, + const float alpha, const float* a, KK_INT lda, + /* */ float* b, KK_INT ldb) { F77_FUNC_STRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, a, &lda, b, &ldb); } @@ -672,36 +683,36 @@ void HostBlas::trsm(const char side, const char uplo, const char transa, /// template <> -void HostBlas::scal(int n, const double alpha, - /* */ double* x, int x_inc) { +void HostBlas::scal(KK_INT n, const double alpha, + /* */ double* x, KK_INT x_inc) { F77_FUNC_DSCAL(&n, &alpha, x, &x_inc); } template <> -int HostBlas::iamax(int n, const double* x, int x_inc) { +KK_INT HostBlas::iamax(KK_INT n, const double* x, KK_INT x_inc) { return F77_FUNC_IDAMAX(&n, x, &x_inc); } template <> -double HostBlas::nrm2(int n, const double* x, int x_inc) { +double HostBlas::nrm2(KK_INT n, const double* x, KK_INT x_inc) { return F77_FUNC_DNRM2(&n, x, &x_inc); } template <> -double HostBlas::asum(int n, const double* x, int x_inc) { +double HostBlas::asum(KK_INT n, const double* x, KK_INT x_inc) { return F77_FUNC_DASUM(&n, x, &x_inc); } template <> -double HostBlas::dot(int n, const double* x, int x_inc, const double* y, - int y_inc) { +double HostBlas::dot(KK_INT n, const double* x, KK_INT x_inc, + const double* y, KK_INT y_inc) { return F77_FUNC_DDOT(&n, x, &x_inc, y, &y_inc); } template <> -void HostBlas::axpy(int n, const double alpha, const double* x, - int x_inc, - /* */ double* y, int y_inc) { +void HostBlas::axpy(KK_INT n, const double alpha, const double* x, + KK_INT x_inc, + /* */ double* y, KK_INT y_inc) { F77_FUNC_DAXPY(&n, &alpha, x, &x_inc, y, &y_inc); } template <> -void HostBlas::rot(int const N, double* X, int const incx, double* Y, - int const incy, double* c, double* s) { +void HostBlas::rot(KK_INT const N, double* X, KK_INT const incx, + double* Y, KK_INT const incy, double* c, double* s) { F77_FUNC_DROT(&N, X, &incx, Y, &incy, c, s); } template <> @@ -709,8 +720,8 @@ void HostBlas::rotg(double* a, double* b, double* c, double* s) { F77_FUNC_DROTG(a, b, c, s); } template <> -void HostBlas::rotm(const int n, double* X, const int incx, double* Y, - const int incy, const double* param) { +void HostBlas::rotm(const KK_INT n, double* X, const KK_INT incx, + double* Y, const KK_INT incy, const double* param) { F77_FUNC_DROTM(&n, X, &incx, Y, &incy, param); } template <> @@ -719,38 +730,39 @@ void HostBlas::rotmg(double* d1, double* d2, double* x1, F77_FUNC_DROTMG(d1, d2, x1, y1, param); } template <> -void HostBlas::swap(int const N, double* X, int const incx, double* Y, - int const incy) { +void HostBlas::swap(KK_INT const N, double* X, KK_INT const incx, + double* Y, KK_INT const incy) { F77_FUNC_DSWAP(&N, X, &incx, Y, &incy); } template <> -void HostBlas::gemv(const char trans, int m, int n, const double alpha, - const double* a, int lda, const double* b, int ldb, - const double beta, - /* */ double* c, int ldc) { +void HostBlas::gemv(const char trans, KK_INT m, KK_INT n, + const double alpha, const double* a, KK_INT lda, + const double* b, KK_INT ldb, const double beta, + /* */ double* c, KK_INT ldc) { F77_FUNC_DGEMV(&trans, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } template <> -void HostBlas::ger(int m, int n, const double alpha, const double* x, - int incx, const double* y, int incy, double* a, - int lda) { +void HostBlas::ger(KK_INT m, KK_INT n, const double alpha, + const double* x, KK_INT incx, const double* y, + KK_INT incy, double* a, KK_INT lda) { F77_FUNC_DGER(&m, &n, &alpha, x, &incx, y, &incy, a, &lda); } template <> -void HostBlas::syr(const char uplo, int n, const double alpha, - const double* x, int incx, double* a, int lda) { +void HostBlas::syr(const char uplo, KK_INT n, const double alpha, + const double* x, KK_INT incx, double* a, + KK_INT lda) { F77_FUNC_DSYR(&uplo, &n, &alpha, x, &incx, a, &lda); } template <> -void HostBlas::syr2(const char uplo, int n, const double alpha, - const double* x, int incx, const double* y, - int incy, double* a, int lda) { +void HostBlas::syr2(const char uplo, KK_INT n, const double alpha, + const double* x, KK_INT incx, const double* y, + KK_INT incy, double* a, KK_INT lda) { F77_FUNC_DSYR2(&uplo, &n, &alpha, x, &incx, y, &incy, a, &lda); } template <> void HostBlas::trsv(const char uplo, const char transa, const char diag, - int m, const double* a, int lda, - /* */ double* b, int ldb) { + KK_INT m, const double* a, KK_INT lda, + /* */ double* b, KK_INT ldb) { F77_FUNC_DTRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb); } template <> @@ -763,25 +775,25 @@ void HostBlas::gemm(const char transa, const char transb, KK_INT m, c, &ldc); } template <> -void HostBlas::herk(const char transa, const char transb, int n, int k, - const double alpha, const double* a, int lda, - const double beta, - /* */ double* c, int ldc) { +void HostBlas::herk(const char transa, const char transb, KK_INT n, + KK_INT k, const double alpha, const double* a, + KK_INT lda, const double beta, + /* */ double* c, KK_INT ldc) { F77_FUNC_DSYRK(&transa, &transb, &n, &k, &alpha, a, &lda, &beta, c, &ldc); } template <> void HostBlas::trmm(const char side, const char uplo, const char transa, - const char diag, int m, int n, const double alpha, - const double* a, int lda, - /* */ double* b, int ldb) { + const char diag, KK_INT m, KK_INT n, + const double alpha, const double* a, KK_INT lda, + /* */ double* b, KK_INT ldb) { F77_FUNC_DTRMM(&side, &uplo, &transa, &diag, &m, &n, &alpha, a, &lda, b, &ldb); } template <> void HostBlas::trsm(const char side, const char uplo, const char transa, - const char diag, int m, int n, const double alpha, - const double* a, int lda, - /* */ double* b, int ldb) { + const char diag, KK_INT m, KK_INT n, + const double alpha, const double* a, KK_INT lda, + /* */ double* b, KK_INT ldb) { F77_FUNC_DTRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, a, &lda, b, &ldb); } @@ -791,31 +803,34 @@ void HostBlas::trsm(const char side, const char uplo, const char transa, /// template <> -void HostBlas >::scal(int n, +void HostBlas >::scal(KK_INT n, const std::complex alpha, /* */ std::complex* x, - int x_inc) { + KK_INT x_inc) { F77_FUNC_CSCAL(&n, &alpha, x, &x_inc); } template <> -int HostBlas >::iamax(int n, const std::complex* x, - int x_inc) { +KK_INT HostBlas >::iamax(KK_INT n, + const std::complex* x, + KK_INT x_inc) { return F77_FUNC_ICAMAX(&n, x, &x_inc); } template <> -float HostBlas >::nrm2(int n, const std::complex* x, - int x_inc) { +float HostBlas >::nrm2(KK_INT n, + const std::complex* x, + KK_INT x_inc) { return F77_FUNC_SCNRM2(&n, x, &x_inc); } template <> -float HostBlas >::asum(int n, const std::complex* x, - int x_inc) { +float HostBlas >::asum(KK_INT n, + const std::complex* x, + KK_INT x_inc) { return F77_FUNC_SCASUM(&n, x, &x_inc); } template <> std::complex HostBlas >::dot( - int n, const std::complex* x, int x_inc, - const std::complex* y, int y_inc) { + KK_INT n, const std::complex* x, KK_INT x_inc, + const std::complex* y, KK_INT y_inc) { #if defined(KOKKOSKERNELS_TPL_BLAS_RETURN_COMPLEX) _kk_float2 res = F77_FUNC_CDOTC(&n, x, &x_inc, y, &y_inc); return std::complex(res.vals[0], res.vals[1]); @@ -826,18 +841,20 @@ std::complex HostBlas >::dot( #endif } template <> -void HostBlas >::axpy(int n, +void HostBlas >::axpy(KK_INT n, const std::complex alpha, const std::complex* x, - int x_inc, + KK_INT x_inc, /* */ std::complex* y, - int y_inc) { + KK_INT y_inc) { F77_FUNC_CAXPY(&n, &alpha, x, &x_inc, y, &y_inc); } template <> -void HostBlas >::rot(int const N, std::complex* X, - int const incx, std::complex* Y, - int const incy, float* c, float* s) { +void HostBlas >::rot(KK_INT const N, std::complex* X, + KK_INT const incx, + std::complex* Y, + KK_INT const incy, float* c, + float* s) { F77_FUNC_CROT(&N, X, &incx, Y, &incy, c, s); } template <> @@ -847,38 +864,37 @@ void HostBlas >::rotg(std::complex* a, F77_FUNC_CROTG(a, b, c, s); } template <> -void HostBlas >::swap(int const N, std::complex* X, - int const incx, +void HostBlas >::swap(KK_INT const N, + std::complex* X, + KK_INT const incx, std::complex* Y, - int const incy) { + KK_INT const incy) { F77_FUNC_CSWAP(&N, X, &incx, Y, &incy); } template <> -void HostBlas >::gemv(const char trans, int m, int n, - const std::complex alpha, - const std::complex* a, int lda, - const std::complex* b, int ldb, - const std::complex beta, - /* */ std::complex* c, - int ldc) { +void HostBlas >::gemv( + const char trans, KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* a, KK_INT lda, const std::complex* b, + KK_INT ldb, const std::complex beta, + /* */ std::complex* c, KK_INT ldc) { F77_FUNC_CGEMV(&trans, &m, &n, &alpha, (const std::complex*)a, &lda, (const std::complex*)b, &ldb, &beta, (std::complex*)c, &ldc); } template <> void HostBlas >::geru( - int m, int n, const std::complex alpha, const std::complex* x, - int incx, const std::complex* y, int incy, std::complex* a, - int lda) { + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* x, KK_INT incx, const std::complex* y, + KK_INT incy, std::complex* a, KK_INT lda) { F77_FUNC_CGERU(&m, &n, &alpha, (const std::complex*)x, &incx, (const std::complex*)y, &incy, (std::complex*)a, &lda); } template <> void HostBlas >::gerc( - int m, int n, const std::complex alpha, const std::complex* x, - int incx, const std::complex* y, int incy, std::complex* a, - int lda) { + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* x, KK_INT incx, const std::complex* y, + KK_INT incy, std::complex* a, KK_INT lda) { F77_FUNC_CGERC(&m, &n, &alpha, (const std::complex*)x, &incx, (const std::complex*)y, &incy, (std::complex*)a, &lda); @@ -886,26 +902,27 @@ void HostBlas >::gerc( template <> template <> void HostBlas >::cher( - const char uplo, int n, const float alpha, const std::complex* x, - int incx, std::complex* a, int lda) { + const char uplo, KK_INT n, const float alpha, const std::complex* x, + KK_INT incx, std::complex* a, KK_INT lda) { F77_FUNC_CHER(&uplo, &n, &alpha, (const std::complex*)x, &incx, (std::complex*)a, &lda); } template <> void HostBlas >::cher2( - const char uplo, int n, const std::complex alpha, - const std::complex* x, int incx, const std::complex* y, - int incy, std::complex* a, int lda) { + const char uplo, KK_INT n, const std::complex alpha, + const std::complex* x, KK_INT incx, const std::complex* y, + KK_INT incy, std::complex* a, KK_INT lda) { F77_FUNC_CHER2(&uplo, &n, &alpha, (const std::complex*)x, &incx, (const std::complex*)y, &incy, (std::complex*)a, &lda); } template <> void HostBlas >::trsv(const char uplo, const char transa, - const char diag, int m, - const std::complex* a, int lda, + const char diag, KK_INT m, + const std::complex* a, + KK_INT lda, /* */ std::complex* b, - int ldb) { + KK_INT ldb) { F77_FUNC_CTRSV(&uplo, &transa, &diag, &m, (const std::complex*)a, &lda, (std::complex*)b, &ldb); } @@ -921,37 +938,31 @@ void HostBlas >::gemm( (std::complex*)c, &ldc); } template <> -void HostBlas >::herk(const char transa, const char transb, - int n, int k, - const std::complex alpha, - const std::complex* a, int lda, - const std::complex beta, - /* */ std::complex* c, - int ldc) { +void HostBlas >::herk( + const char transa, const char transb, KK_INT n, KK_INT k, + const std::complex alpha, const std::complex* a, KK_INT lda, + const std::complex beta, + /* */ std::complex* c, KK_INT ldc) { F77_FUNC_CHERK(&transa, &transb, &n, &k, &alpha, (const std::complex*)a, &lda, &beta, (std::complex*)c, &ldc); } template <> -void HostBlas >::trmm(const char side, const char uplo, - const char transa, const char diag, - int m, int n, - const std::complex alpha, - const std::complex* a, int lda, - /* */ std::complex* b, - int ldb) { +void HostBlas >::trmm( + const char side, const char uplo, const char transa, const char diag, + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* a, KK_INT lda, + /* */ std::complex* b, KK_INT ldb) { F77_FUNC_CTRMM(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const std::complex*)a, &lda, (std::complex*)b, &ldb); } template <> -void HostBlas >::trsm(const char side, const char uplo, - const char transa, const char diag, - int m, int n, - const std::complex alpha, - const std::complex* a, int lda, - /* */ std::complex* b, - int ldb) { +void HostBlas >::trsm( + const char side, const char uplo, const char transa, const char diag, + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* a, KK_INT lda, + /* */ std::complex* b, KK_INT ldb) { F77_FUNC_CTRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const std::complex*)a, &lda, (std::complex*)b, &ldb); @@ -962,33 +973,34 @@ void HostBlas >::trsm(const char side, const char uplo, /// template <> -void HostBlas >::scal(int n, +void HostBlas >::scal(KK_INT n, const std::complex alpha, /* */ std::complex* x, - int x_inc) { + KK_INT x_inc) { F77_FUNC_ZSCAL(&n, &alpha, x, &x_inc); } template <> -int HostBlas >::iamax(int n, const std::complex* x, - int x_inc) { +KK_INT HostBlas >::iamax(KK_INT n, + const std::complex* x, + KK_INT x_inc) { return F77_FUNC_IZAMAX(&n, x, &x_inc); } template <> -double HostBlas >::nrm2(int n, +double HostBlas >::nrm2(KK_INT n, const std::complex* x, - int x_inc) { + KK_INT x_inc) { return F77_FUNC_DZNRM2(&n, x, &x_inc); } template <> -double HostBlas >::asum(int n, +double HostBlas >::asum(KK_INT n, const std::complex* x, - int x_inc) { + KK_INT x_inc) { return F77_FUNC_DZASUM(&n, x, &x_inc); } template <> std::complex HostBlas >::dot( - int n, const std::complex* x, int x_inc, - const std::complex* y, int y_inc) { + KK_INT n, const std::complex* x, KK_INT x_inc, + const std::complex* y, KK_INT y_inc) { #if defined(KOKKOSKERNELS_TPL_BLAS_RETURN_COMPLEX) _kk_double2 res = F77_FUNC_ZDOTC(&n, x, &x_inc, y, &y_inc); return std::complex(res.vals[0], res.vals[1]); @@ -999,20 +1011,18 @@ std::complex HostBlas >::dot( #endif } template <> -void HostBlas >::axpy(int n, +void HostBlas >::axpy(KK_INT n, const std::complex alpha, const std::complex* x, - int x_inc, + KK_INT x_inc, /* */ std::complex* y, - int y_inc) { + KK_INT y_inc) { F77_FUNC_ZAXPY(&n, &alpha, x, &x_inc, y, &y_inc); } template <> -void HostBlas >::rot(int const N, std::complex* X, - int const incx, - std::complex* Y, - int const incy, double* c, - double* s) { +void HostBlas >::rot( + KK_INT const N, std::complex* X, KK_INT const incx, + std::complex* Y, KK_INT const incy, double* c, double* s) { F77_FUNC_ZROT(&N, X, &incx, Y, &incy, c, s); } template <> @@ -1022,36 +1032,37 @@ void HostBlas >::rotg(std::complex* a, F77_FUNC_ZROTG(a, b, c, s); } template <> -void HostBlas >::swap(int const N, std::complex* X, - int const incx, +void HostBlas >::swap(KK_INT const N, + std::complex* X, + KK_INT const incx, std::complex* Y, - int const incy) { + KK_INT const incy) { F77_FUNC_ZSWAP(&N, X, &incx, Y, &incy); } template <> void HostBlas >::gemv( - const char trans, int m, int n, const std::complex alpha, - const std::complex* a, int lda, const std::complex* b, - int ldb, const std::complex beta, - /* */ std::complex* c, int ldc) { + const char trans, KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* a, KK_INT lda, const std::complex* b, + KK_INT ldb, const std::complex beta, + /* */ std::complex* c, KK_INT ldc) { F77_FUNC_ZGEMV(&trans, &m, &n, &alpha, (const std::complex*)a, &lda, (const std::complex*)b, &ldb, &beta, (std::complex*)c, &ldc); } template <> void HostBlas >::geru( - int m, int n, const std::complex alpha, - const std::complex* x, int incx, const std::complex* y, - int incy, std::complex* a, int lda) { + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* x, KK_INT incx, const std::complex* y, + KK_INT incy, std::complex* a, KK_INT lda) { F77_FUNC_ZGERU(&m, &n, &alpha, (const std::complex*)x, &incx, (const std::complex*)y, &incy, (std::complex*)a, &lda); } template <> void HostBlas >::gerc( - int m, int n, const std::complex alpha, - const std::complex* x, int incx, const std::complex* y, - int incy, std::complex* a, int lda) { + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* x, KK_INT incx, const std::complex* y, + KK_INT incy, std::complex* a, KK_INT lda) { F77_FUNC_ZGERC(&m, &n, &alpha, (const std::complex*)x, &incx, (const std::complex*)y, &incy, (std::complex*)a, &lda); @@ -1059,27 +1070,28 @@ void HostBlas >::gerc( template <> template <> void HostBlas >::zher( - const char uplo, int n, const double alpha, const std::complex* x, - int incx, std::complex* a, int lda) { + const char uplo, KK_INT n, const double alpha, + const std::complex* x, KK_INT incx, std::complex* a, + KK_INT lda) { F77_FUNC_ZHER(&uplo, &n, &alpha, (const std::complex*)x, &incx, (std::complex*)a, &lda); } template <> void HostBlas >::zher2( - const char uplo, int n, const std::complex alpha, - const std::complex* x, int incx, const std::complex* y, - int incy, std::complex* a, int lda) { + const char uplo, KK_INT n, const std::complex alpha, + const std::complex* x, KK_INT incx, const std::complex* y, + KK_INT incy, std::complex* a, KK_INT lda) { F77_FUNC_ZHER2(&uplo, &n, &alpha, (const std::complex*)x, &incx, (const std::complex*)y, &incy, (std::complex*)a, &lda); } template <> void HostBlas >::trsv(const char uplo, const char transa, - const char diag, int m, + const char diag, KK_INT m, const std::complex* a, - int lda, + KK_INT lda, /* */ std::complex* b, - int ldb) { + KK_INT ldb) { F77_FUNC_ZTRSV(&uplo, &transa, &diag, &m, (const std::complex*)a, &lda, (std::complex*)b, &ldb); } @@ -1097,30 +1109,30 @@ void HostBlas >::gemm( } template <> void HostBlas >::herk( - const char transa, const char transb, int n, int k, - const std::complex alpha, const std::complex* a, int lda, + const char transa, const char transb, KK_INT n, KK_INT k, + const std::complex alpha, const std::complex* a, KK_INT lda, const std::complex beta, - /* */ std::complex* c, int ldc) { + /* */ std::complex* c, KK_INT ldc) { F77_FUNC_ZHERK(&transa, &transb, &n, &k, &alpha, (const std::complex*)a, &lda, &beta, (std::complex*)c, &ldc); } template <> void HostBlas >::trmm( - const char side, const char uplo, const char transa, const char diag, int m, - int n, const std::complex alpha, const std::complex* a, - int lda, - /* */ std::complex* b, int ldb) { + const char side, const char uplo, const char transa, const char diag, + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* a, KK_INT lda, + /* */ std::complex* b, KK_INT ldb) { F77_FUNC_ZTRMM(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const std::complex*)a, &lda, (std::complex*)b, &ldb); } template <> void HostBlas >::trsm( - const char side, const char uplo, const char transa, const char diag, int m, - int n, const std::complex alpha, const std::complex* a, - int lda, - /* */ std::complex* b, int ldb) { + const char side, const char uplo, const char transa, const char diag, + KK_INT m, KK_INT n, const std::complex alpha, + const std::complex* a, KK_INT lda, + /* */ std::complex* b, KK_INT ldb) { F77_FUNC_ZTRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const std::complex*)a, &lda, (std::complex*)b, &ldb); diff --git a/blas/tpls/KokkosBlas_Host_tpl.hpp b/blas/tpls/KokkosBlas_Host_tpl.hpp index 8e8781bfcf..5fb7c1f624 100644 --- a/blas/tpls/KokkosBlas_Host_tpl.hpp +++ b/blas/tpls/KokkosBlas_Host_tpl.hpp @@ -43,87 +43,88 @@ struct HostBlas { typedef Kokkos::ArithTraits ats; typedef typename ats::mag_type mag_type; - static void scal(int n, const T alpha, - /* */ T *x, int x_inc); + static void scal(KK_INT n, const T alpha, + /* */ T *x, KK_INT x_inc); - static int iamax(int n, const T *x, int x_inc); + static KK_INT iamax(KK_INT n, const T *x, KK_INT x_inc); - static mag_type nrm2(int n, const T *x, int x_inc); + static mag_type nrm2(KK_INT n, const T *x, KK_INT x_inc); - static mag_type asum(int n, const T *x, int x_inc); + static mag_type asum(KK_INT n, const T *x, KK_INT x_inc); - static T dot(int n, const T *x, int x_inc, const T *y, int y_inc); + static T dot(KK_INT n, const T *x, KK_INT x_inc, const T *y, KK_INT y_inc); - static void axpy(int n, const T alpha, const T *x, int x_inc, - /* */ T *y, int y_inc); + static void axpy(KK_INT n, const T alpha, const T *x, KK_INT x_inc, + /* */ T *y, KK_INT y_inc); - static void rot(int const N, T *X, int const incx, T *Y, int const incy, - mag_type *c, mag_type *s); + static void rot(KK_INT const N, T *X, KK_INT const incx, T *Y, + KK_INT const incy, mag_type *c, mag_type *s); static void rotg(T *a, T *b, mag_type *c, T *s); - static void rotm(const int n, T *X, const int incx, T *Y, const int incy, - T const *param); + static void rotm(const KK_INT n, T *X, const KK_INT incx, T *Y, + const KK_INT incy, T const *param); static void rotmg(T *d1, T *d2, T *x1, const T *y1, T *param); - static void swap(int const N, T *X, int const incx, T *Y, int const incy); + static void swap(KK_INT const N, T *X, KK_INT const incx, T *Y, + KK_INT const incy); - static void gemv(const char trans, int m, int n, const T alpha, const T *a, - int lda, const T *b, int ldb, const T beta, - /* */ T *c, int ldc); + static void gemv(const char trans, KK_INT m, KK_INT n, const T alpha, + const T *a, KK_INT lda, const T *b, KK_INT ldb, const T beta, + /* */ T *c, KK_INT ldc); - static void ger(int m, int n, const T alpha, const T *x, int incx, const T *y, - int incy, T *a, int lda); + static void ger(KK_INT m, KK_INT n, const T alpha, const T *x, KK_INT incx, + const T *y, KK_INT incy, T *a, KK_INT lda); - static void geru(int m, int n, const T alpha, const T *x, int incx, - const T *y, int incy, T *a, int lda); + static void geru(KK_INT m, KK_INT n, const T alpha, const T *x, KK_INT incx, + const T *y, KK_INT incy, T *a, KK_INT lda); - static void gerc(int m, int n, const T alpha, const T *x, int incx, - const T *y, int incy, T *a, int lda); + static void gerc(KK_INT m, KK_INT n, const T alpha, const T *x, KK_INT incx, + const T *y, KK_INT incy, T *a, KK_INT lda); - static void syr(const char uplo, int n, const T alpha, const T *x, int incx, - T *a, int lda); + static void syr(const char uplo, KK_INT n, const T alpha, const T *x, + KK_INT incx, T *a, KK_INT lda); - static void syr2(const char uplo, int n, const T alpha, const T *x, int incx, - const T *y, int incy, T *a, int lda); + static void syr2(const char uplo, KK_INT n, const T alpha, const T *x, + KK_INT incx, const T *y, KK_INT incy, T *a, KK_INT lda); template - static void cher(const char uplo, int n, const tAlpha alpha, const T *x, - int incx, T *a, int lda); + static void cher(const char uplo, KK_INT n, const tAlpha alpha, const T *x, + KK_INT incx, T *a, KK_INT lda); template - static void zher(const char uplo, int n, const tAlpha alpha, const T *x, - int incx, T *a, int lda); + static void zher(const char uplo, KK_INT n, const tAlpha alpha, const T *x, + KK_INT incx, T *a, KK_INT lda); - static void cher2(const char uplo, int n, const T alpha, const T *x, int incx, - const T *y, int incy, T *a, int lda); + static void cher2(const char uplo, KK_INT n, const T alpha, const T *x, + KK_INT incx, const T *y, KK_INT incy, T *a, KK_INT lda); - static void zher2(const char uplo, int n, const T alpha, const T *x, int incx, - const T *y, int incy, T *a, int lda); + static void zher2(const char uplo, KK_INT n, const T alpha, const T *x, + KK_INT incx, const T *y, KK_INT incy, T *a, KK_INT lda); - static void trsv(const char uplo, const char transa, const char diag, int m, - const T *a, int lda, - /* */ T *b, int ldb); + static void trsv(const char uplo, const char transa, const char diag, + KK_INT m, const T *a, KK_INT lda, + /* */ T *b, KK_INT ldb); static void gemm(const char transa, const char transb, KK_INT m, KK_INT n, KK_INT k, const T alpha, const T *a, KK_INT lda, const T *b, KK_INT ldb, const T beta, /* */ T *c, KK_INT ldc); - static void herk(const char transa, const char transb, int n, int k, - const T alpha, const T *a, int lda, const T beta, - /* */ T *c, int ldc); + static void herk(const char transa, const char transb, KK_INT n, KK_INT k, + const T alpha, const T *a, KK_INT lda, const T beta, + /* */ T *c, KK_INT ldc); static void trmm(const char side, const char uplo, const char transa, - const char diag, int m, int n, const T alpha, const T *a, - int lda, - /* */ T *b, int ldb); + const char diag, KK_INT m, KK_INT n, const T alpha, + const T *a, KK_INT lda, + /* */ T *b, KK_INT ldb); static void trsm(const char side, const char uplo, const char transa, - const char diag, int m, int n, const T alpha, const T *a, - int lda, - /* */ T *b, int ldb); + const char diag, KK_INT m, KK_INT n, const T alpha, + const T *a, KK_INT lda, + /* */ T *b, KK_INT ldb); }; } // namespace Impl } // namespace KokkosBlas