From 3b29fe2774ee093701c41964b8fe1cc2622ad25e Mon Sep 17 00:00:00 2001 From: Mark Gates Date: Fri, 24 May 2024 20:53:01 -0400 Subject: [PATCH 1/5] update sub-repos --- blaspp | 2 +- lapackpp | 2 +- testsweeper | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/blaspp b/blaspp index 92ad3b4f1..b852daac3 160000 --- a/blaspp +++ b/blaspp @@ -1 +1 @@ -Subproject commit 92ad3b4f147b7e48a062d12b040207a8e5505c2b +Subproject commit b852daac3675d03fbe1c4415e0e17ed175e40194 diff --git a/lapackpp b/lapackpp index 7bbdb5877..b19a2bb63 160000 --- a/lapackpp +++ b/lapackpp @@ -1 +1 @@ -Subproject commit 7bbdb5877c505ba261f2ffd58fd9c38b41a328cf +Subproject commit b19a2bb63ea0c93849f5840b6d361c2fc00ac35c diff --git a/testsweeper b/testsweeper index 023592af2..edf78e95a 160000 --- a/testsweeper +++ b/testsweeper @@ -1 +1 @@ -Subproject commit 023592af26481673f64ef315ace7e288f95844ba +Subproject commit edf78e95a342b5d541693f5343cd56c5bf623a97 From 66eabbdf2fe30b4e0ab7c043ba1c40cd8ed52be6 Mon Sep 17 00:00:00 2001 From: Mark Gates Date: Tue, 21 May 2024 01:19:07 -0400 Subject: [PATCH 2/5] change methods to enums, for consistency with other options --- GNUmakefile | 1 + include/slate/c_api/types.h | 62 +++- include/slate/enums.hh | 425 ++++++++++++++++++++++- include/slate/method.hh | 319 ----------------- include/slate/slate.hh | 2 - include/slate/types.hh | 40 ++- src/cholqr.cc | 18 +- src/core/enums.cc | 43 +++ src/gels.cc | 16 +- src/gemm.cc | 25 +- src/getrf.cc | 2 +- src/getrs.cc | 2 +- src/hemm.cc | 25 +- src/trsm.cc | 24 +- test/test.cc | 39 +-- test/test.hh | 34 +- test/test_gels.cc | 20 +- test/test_gemm.cc | 6 +- test/test_geqrf.cc | 4 +- test/test_gesv.cc | 8 +- test/test_hemm.cc | 8 +- test/test_her2k.cc | 2 +- test/test_herk.cc | 2 +- test/test_posv.cc | 8 +- test/test_trsm.cc | 6 +- tools/fortran/generate_fortran_module.py | 10 +- 26 files changed, 697 insertions(+), 454 deletions(-) delete mode 100644 include/slate/method.hh create mode 100644 src/core/enums.cc diff --git a/GNUmakefile b/GNUmakefile index edeb2461d..1f92015bc 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -445,6 +445,7 @@ libslate_src += \ src/auxiliary/Debug.cc \ src/auxiliary/Trace.cc \ src/core/Memory.cc \ + src/core/enums.cc \ src/core/types.cc \ src/version.cc \ # End. Add alphabetically. diff --git a/include/slate/c_api/types.h b/include/slate/c_api/types.h index 7be6d3833..28e9e9f03 100644 --- a/include/slate/c_api/types.h +++ b/include/slate/c_api/types.h @@ -32,11 +32,62 @@ const slate_Target slate_Target_HostBatch = 'B'; ///< slate::Target::HostBatch const slate_Target slate_Target_Devices = 'D'; ///< slate::Target::Devices // end slate_Target -typedef char slate_MethodEig; /* enum */ ///< slate::MethodEig -const slate_MethodEig slate_MethodEig_QR = 'Q'; ///< slate::MethodEig::QR -const slate_MethodEig slate_MethodEig_DC = 'D'; ///< slate::MethodEig::DC +typedef char slate_MethodTrsm; /* enum */ ///< slate::MethodTrsm +const slate_MethodTrsm slate_MethodTrsm_Auto = '*'; ///< slate::MethodTrsm::Auto +const slate_MethodTrsm slate_MethodTrsm_A = 'A'; ///< slate::MethodTrsm::A +const slate_MethodTrsm slate_MethodTrsm_B = 'B'; ///< slate::MethodTrsm::B +// end slate_MethodTrsm + +typedef char slate_MethodGemm; /* enum */ ///< slate::MethodGemm +const slate_MethodGemm slate_MethodGemm_Auto = '*'; ///< slate::MethodGemm::Auto +const slate_MethodGemm slate_MethodGemm_A = 'A'; ///< slate::MethodGemm::A +const slate_MethodGemm slate_MethodGemm_C = 'C'; ///< slate::MethodGemm::C +// end slate_MethodGemm + +typedef char slate_MethodHemm; /* enum */ ///< slate::MethodHemm +const slate_MethodHemm slate_MethodHemm_Auto = '*'; ///< slate::MethodHemm::Auto +const slate_MethodHemm slate_MethodHemm_A = 'A'; ///< slate::MethodHemm::A +const slate_MethodHemm slate_MethodHemm_C = 'C'; ///< slate::MethodHemm::C +// end slate_MethodHemm + +typedef char slate_MethodCholQR; /* enum */ ///< slate::MethodCholQR +const slate_MethodCholQR slate_MethodCholQR_Auto = '*'; ///< slate::MethodCholQR::Auto +const slate_MethodCholQR slate_MethodCholQR_GemmA = 'A'; ///< slate::MethodCholQR::GemmA +const slate_MethodCholQR slate_MethodCholQR_GemmC = 'C'; ///< slate::MethodCholQR::GemmC +const slate_MethodCholQR slate_MethodCholQR_HerkA = 'R'; ///< slate::MethodCholQR::HerkA +const slate_MethodCholQR slate_MethodCholQR_HerkC = 'K'; ///< slate::MethodCholQR::HerkC +// end slate_MethodCholQR + +typedef char slate_MethodGels; /* enum */ ///< slate::MethodGels +const slate_MethodGels slate_MethodGels_Auto = '*'; ///< slate::MethodGels::Auto +const slate_MethodGels slate_MethodGels_QR = 'Q'; ///< slate::MethodGels::QR +const slate_MethodGels slate_MethodGels_CholQR = 'C'; ///< slate::MethodGels::CholQR +// end slate_MethodGels + +typedef char slate_MethodLU; /* enum */ ///< slate::MethodLU +const slate_MethodLU slate_MethodLU_Auto = '*'; ///< slate::MethodLU::Auto +const slate_MethodLU slate_MethodLU_PartialPiv = 'P'; ///< slate::MethodLU::PartialPiv +const slate_MethodLU slate_MethodLU_CALU = 'C'; ///< slate::MethodLU::CALU +const slate_MethodLU slate_MethodLU_NoPiv = 'N'; ///< slate::MethodLU::NoPiv +const slate_MethodLU slate_MethodLU_RBT = 'R'; ///< slate::MethodLU::RBT +const slate_MethodLU slate_MethodLU_BEAM = 'B'; ///< slate::MethodLU::BEAM +// end slate_MethodLU + +typedef char slate_MethodEig; /* enum */ ///< slate::MethodEig +const slate_MethodEig slate_MethodEig_Auto = '*'; ///< slate::MethodEig::Auto +const slate_MethodEig slate_MethodEig_QR = 'Q'; ///< slate::MethodEig::QR +const slate_MethodEig slate_MethodEig_DC = 'D'; ///< slate::MethodEig::DC +const slate_MethodEig slate_MethodEig_Bisection = 'B'; ///< slate::MethodEig::Bisection +const slate_MethodEig slate_MethodEig_MRRR = 'M'; ///< slate::MethodEig::MRRR // end slate_MethodEig +typedef char slate_MethodSVD; /* enum */ ///< slate::MethodSVD +const slate_MethodSVD slate_MethodSVD_Auto = '*'; ///< slate::MethodSVD::Auto +const slate_MethodSVD slate_MethodSVD_QR = 'Q'; ///< slate::MethodSVD::QR +const slate_MethodSVD slate_MethodSVD_DC = 'D'; ///< slate::MethodSVD::DC +const slate_MethodSVD slate_MethodSVD_Bisection = 'B'; ///< slate::MethodSVD::Bisection +// end slate_MethodSVD + // todo: auto sync with include/slate/enums.hh typedef char slate_Option; /* enum */ ///< slate::Option const slate_Option slate_Option_ChunkSize = 0; ///< slate::Option::ChunkSize @@ -66,11 +117,6 @@ const slate_Option slate_Option_MethodTrsm = 66; ///< slate::Option::M typedef short slate_MOSI_State; -//------------------------------------------------------------------------------ -// slate/include/slate/types.hh - -typedef int slate_Method; - //------------------------------------------------------------------------------ // blaspp/include/blas_util.hh diff --git a/include/slate/enums.hh b/include/slate/enums.hh index 09f37ced2..ad8ea816f 100644 --- a/include/slate/enums.hh +++ b/include/slate/enums.hh @@ -9,23 +9,27 @@ #ifndef SLATE_ENUMS_HH #define SLATE_ENUMS_HH +#include "slate/Exception.hh" + #include #include +#include + namespace slate { -typedef blas::Op Op; -typedef blas::Uplo Uplo; -typedef blas::Diag Diag; -typedef blas::Side Side; -typedef blas::Layout Layout; +using blas::Op; +using blas::Uplo; +using blas::Diag; +using blas::Side; +using blas::Layout; using lapack::Equed; using lapack::RowCol; -typedef lapack::Norm Norm; -typedef lapack::Direction Direction; +using lapack::Norm; +using lapack::Direction; -typedef lapack::Job Job; +using lapack::Job; //------------------------------------------------------------------------------ /// Location and method of computation. @@ -48,14 +52,408 @@ template class TargetType {}; } // namespace internal //------------------------------------------------------------------------------ -/// Eigenvalue algorithm to used in heev routine. -/// @ingroup enum +// Methods + +//------------------------------------------------------------------------------ +/// Algorithm to use for triangular solve (trsm). +/// @ingroup method +/// +enum class MethodTrsm : char { + Auto = '*', ///< Let SLATE decide + A = 'A', ///< Matrix A is stationary, B is sent; use when B is small + B = 'B', ///< Matrix B is stationary, A is sent; use when B is large + TrsmA [[deprecated("Use A. To be removed 2025-05.")]] = 'A', + TrsmB [[deprecated("Use B. To be removed 2025-05.")]] = 'B', +}; + +extern const char* MethodTrsm_help; + +//----------------------------------- +inline const char* to_c_string( MethodTrsm value ) +{ + switch (value) { + case MethodTrsm::Auto: return "auto"; + case MethodTrsm::A: return "A"; + case MethodTrsm::B: return "B"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodTrsm value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodTrsm* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodTrsm::Auto; + else if (str_ == "a" || str_ == "trsma") + *val = MethodTrsm::A; + else if (str_ == "b" || str_ == "trsmb") + *val = MethodTrsm::B; + else + throw Exception( "unknown trsm method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for general matrix multiply (gemm). +/// @ingroup method +/// +enum class MethodGemm : char { + Auto = '*', ///< Let SLATE decide + A = 'A', ///< Matrix A is stationary, C is sent; use when C is small + C = 'C', ///< Matrix C is stationary, A is sent; use when C is large + GemmA [[deprecated("Use A. To be removed 2025-05.")]] = 'A', + GemmC [[deprecated("Use C. To be removed 2025-05.")]] = 'C', +}; + +extern const char* MethodGemm_help; + +//----------------------------------- +inline const char* to_c_string( MethodGemm value ) +{ + switch (value) { + case MethodGemm::Auto: return "auto"; + case MethodGemm::A: return "A"; + case MethodGemm::C: return "C"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodGemm value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodGemm* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodGemm::Auto; + else if (str_ == "a" || str_ == "gemma") + *val = MethodGemm::A; + else if (str_ == "c" || str_ == "gemmc") + *val = MethodGemm::C; + else + throw Exception( "unknown gemm method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for Hermitian matrix multiply (hemm). +/// @ingroup method +/// +enum class MethodHemm : char { + Auto = '*', ///< Let SLATE decide + A = 'A', ///< Matrix A is stationary, C is sent; use when C is small + C = 'C', ///< Matrix C is stationary, A is sent; use when C is large + HemmA [[deprecated("Use A. To be removed 2025-05.")]] = 'A', + HemmC [[deprecated("Use C. To be removed 2025-05.")]] = 'C', +}; + +extern const char* MethodHemm_help; + +//----------------------------------- +inline const char* to_c_string( MethodHemm value ) +{ + switch (value) { + case MethodHemm::Auto: return "auto"; + case MethodHemm::A: return "A"; + case MethodHemm::C: return "C"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodHemm value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodHemm* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodHemm::Auto; + else if (str_ == "a" || str_ == "hemma") + *val = MethodHemm::A; + else if (str_ == "c" || str_ == "hemmc") + *val = MethodHemm::C; + else + throw Exception( "unknown hemm method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for Cholesky QR. +/// @ingroup method +/// +enum class MethodCholQR : char { + Auto = '*', ///< Let SLATE decide + GemmA = 'A', ///< Use gemm-A algorithm to compute A^H A + GemmC = 'C', ///< Use gemm-C algorithm to compute A^H A + HerkA = 'R', ///< Use herk-A algorithm to compute A^H A; not yet implemented + HerkC = 'K', ///< Use herk-C algorithm to compute A^H A +}; + +extern const char* MethodCholQR_help; + +//----------------------------------- +inline const char* to_c_string( MethodCholQR value ) +{ + switch (value) { + case MethodCholQR::Auto: return "auto"; + case MethodCholQR::GemmA: return "gemmA"; + case MethodCholQR::GemmC: return "gemmC"; + case MethodCholQR::HerkA: return "herkA"; + case MethodCholQR::HerkC: return "herkC"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodCholQR value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodCholQR* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodCholQR::Auto; + else if (str_ == "gemma") + *val = MethodCholQR::GemmA; + else if (str_ == "gemmc") + *val = MethodCholQR::GemmC; + else if (str_ == "herka") + *val = MethodCholQR::HerkA; + else if (str_ == "herkc") + *val = MethodCholQR::HerkC; + else + throw Exception( "unknown Cholesky QR method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for least squares (gels). +/// @ingroup method +/// +enum class MethodGels : char { + Auto = '*', ///< Let SLATE decide + QR = 'Q', ///< Use Householder QR factorization + CholQR = 'C', ///< Use Cholesky QR factorization; use when A is well-conditioned + Geqrf [[deprecated("Use QR. To be removed 2025-05.")]] = 'Q', + Cholqr [[deprecated("Use CholQR. To be removed 2025-05.")]] = 'C', +}; + +extern const char* MethodGels_help; + +//----------------------------------- +inline const char* to_c_string( MethodGels value ) +{ + switch (value) { + case MethodGels::Auto: return "auto"; + case MethodGels::QR: return "QR"; + case MethodGels::CholQR: return "CholQR"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodGels value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodGels* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodGels::Auto; + else if (str_ == "qr" || str_ == "geqrf") + *val = MethodGels::QR; + else if (str_ == "cholqr") + *val = MethodGels::CholQR; + else + throw Exception( "unknown least squares (gels) method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for LU factorization and solve. +/// @ingroup method +/// +enum class MethodLU : char { + Auto = '*', ///< Let SLATE decide + PartialPiv = 'P', ///< Use classical partial pivoting + CALU = 'C', ///< Use Communication Avoiding LU (CALU) + NoPiv = 'N', ///< Use no-pivoting LU + RBT = 'R', ///< Use Random Butterfly Transform (RBT) + BEAM = 'B', ///< Use BEAM LU factorization +}; + +extern const char* MethodLU_help; + +//----------------------------------- +inline const char* to_c_string( MethodLU value ) +{ + switch (value) { + case MethodLU::Auto: return "auto"; + case MethodLU::PartialPiv: return "PPLU"; + case MethodLU::CALU: return "CALU"; + case MethodLU::NoPiv: return "NoPiv"; + case MethodLU::RBT: return "RBT"; + case MethodLU::BEAM: return "BEAM"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodLU value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodLU* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodLU::Auto; + else if (str_ == "pplu" || str_ == "partialpiv") + *val = MethodLU::PartialPiv; + else if (str_ == "calu") + *val = MethodLU::CALU; + else if (str_ == "nopiv") + *val = MethodLU::NoPiv; + else if (str_ == "rbt") + *val = MethodLU::RBT; + else if (str_ == "beam") + *val = MethodLU::BEAM; + else + throw Exception( "unknown LU method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for eigenvalues (eig). +/// @ingroup method /// enum class MethodEig : char { - QR = 'Q', ///< QR iteration for finding eigenvalues - DC = 'D', ///< Divide and conquer algorithm for finding eigenvalues + Auto = '*', ///< Let SLATE decide + QR = 'Q', ///< QR iteration + DC = 'D', ///< Divide and conquer + Bisection = 'B', ///< Bisection; not yet implemented + MRRR = 'M', ///< Multiple Relatively Robust Representations (MRRR); not yet implemented +}; + +extern const char* MethodEig_help; + +//----------------------------------- +inline const char* to_c_string( MethodEig value ) +{ + switch (value) { + case MethodEig::Auto: return "auto"; + case MethodEig::QR: return "QR"; + case MethodEig::DC: return "DC"; + case MethodEig::Bisection: return "Bisection"; + case MethodEig::MRRR: return "MRRR"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodEig value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodEig* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodEig::Auto; + else if (str_ == "qr") + *val = MethodEig::QR; + else if (str_ == "dc") + *val = MethodEig::DC; + else if (str_ == "bisection") + *val = MethodEig::Bisection; + else if (str_ == "mrrr") + *val = MethodEig::MRRR; + else + throw Exception( "unknown eig method: " + str ); +} + +//------------------------------------------------------------------------------ +/// Algorithm to use for singular value decomposition (SVD). +/// @ingroup method +/// +enum class MethodSVD : char { + Auto = '*', ///< Let SLATE decide + QR = 'Q', ///< QR iteration + DC = 'D', ///< Divide and conquer; not yet implemented + Bisection = 'B', ///< Bisection; not yet implemented }; +extern const char* MethodSVD_help; + +//----------------------------------- +inline const char* to_c_string( MethodSVD value ) +{ + switch (value) { + case MethodSVD::Auto: return "auto"; + case MethodSVD::QR: return "QR"; + case MethodSVD::DC: return "DC"; + case MethodSVD::Bisection: return "Bisection"; + } + return "?"; +} + +//----------------------------------- +inline std::string to_string( MethodSVD value ) +{ + return to_c_string( value ); +} + +//----------------------------------- +inline void from_string( std::string const& str, MethodSVD* val ) +{ + std::string str_ = str; + std::transform( str_.begin(), str_.end(), str_.begin(), ::tolower ); + + if (str_ == "auto") + *val = MethodSVD::Auto; + else if (str_ == "qr") + *val = MethodSVD::QR; + else if (str_ == "dc") + *val = MethodSVD::DC; + else if (str_ == "bisection") + *val = MethodSVD::Bisection; + else + throw Exception( "unknown SVD method: " + str ); +} + //------------------------------------------------------------------------------ /// Keys for options to pass to SLATE routines. /// @ingroup enum @@ -89,13 +487,14 @@ enum class Option : char { ///< For correct printing, PrintWidth = PrintPrecision + 6. // Methods, listed alphabetically. - MethodCholQR = 60, ///< Select the algorithm to compute A^H * A + MethodCholQR = 60, ///< Select the algorithm to compute A^H A MethodEig, ///< Select the algorithm to compute eigenpairs of tridiagonal matrix MethodGels, ///< Select the gels algorithm MethodGemm, ///< Select the gemm algorithm MethodHemm, ///< Select the hemm algorithm MethodLU, ///< Select the LU (getrf) algorithm MethodTrsm, ///< Select the trsm algorithm + MethodSVD, ///< Select the algorithm to compute singular values of bidiagonal matrix }; //------------------------------------------------------------------------------ diff --git a/include/slate/method.hh b/include/slate/method.hh deleted file mode 100644 index 9917be7ab..000000000 --- a/include/slate/method.hh +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright (c) 2017-2023, University of Tennessee. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause -// This program is free software: you can redistribute it and/or modify it under -// the terms of the BSD 3-Clause license. See the accompanying LICENSE file. - -//------------------------------------------------------------------------------ -/// @file -/// -#ifndef SLATE_METHOD_HH -#define SLATE_METHOD_HH - -#include "slate/types.hh" - -namespace slate { - -// This defines default values that MUST be considered in the inner namespaces -constexpr char baseMethodError_str[] = "error"; -constexpr char baseMethodAuto_str[] = "auto"; - -const Method baseMethodError = -1; -const Method baseMethodAuto = 0; - -//------------------------------------------------------------------------------ -/// Select the right algorithm to perform the trsm -namespace MethodTrsm { - - constexpr char TrsmA_str[] = "A"; - constexpr char TrsmB_str[] = "B"; - const Method Error = baseMethodError; - const Method Auto = baseMethodAuto; - const Method TrsmA = 1; ///< Select trsmA algorithm - const Method TrsmB = 2; ///< Select trsmB algorithm - - template - inline Method select_algo(TA& A, TB& B, Options const& opts) { - Target target = get_option( opts, Option::Target, Target::HostTask ); - int n_devices = A.num_devices(); - - Method method = (B.nt() < 2 ? TrsmA : TrsmB); - - if (method == TrsmA && target == Target::Devices && n_devices > 1) - method = TrsmB; - - return method; - } - - inline Method str2methodTrsm(const char* method) - { - std::string method_ = method; - std::transform( - method_.begin(), method_.end(), method_.begin(), ::tolower ); - - if (method_ == "auto") - return Auto; - else if (method_ == "a" || method_ == "trsma") - return TrsmA; - else if (method_ == "b" || method_ == "trsmb") - return TrsmB; - else - throw slate::Exception("unknown trsm method"); - } - - inline const char* methodTrsm2str(Method method) - { - switch (method) { - case Auto: return baseMethodAuto_str; - case TrsmA: return TrsmA_str; - case TrsmB: return TrsmB_str; - default: return baseMethodError_str; - } - } - -} // namespace MethodTrsm - -//------------------------------------------------------------------------------ -/// Select the right algorithm to perform the gemm -namespace MethodGemm { - - constexpr char GemmA_str[] = "A"; - constexpr char GemmC_str[] = "C"; - const Method Error = baseMethodError; - const Method Auto = baseMethodAuto; - const Method GemmA = 1; ///< Select gemmA algorithm - const Method GemmC = 2; ///< Select gemmC algorithm - - template - inline Method select_algo(TA& A, TB& B, Options& opts) { - // TODO replace the default value by a unique value located elsewhere - Target target = get_option( opts, Option::Target, Target::HostTask ); - int n_devices = A.num_devices(); - - Method method = (B.nt() < 2 ? GemmA : GemmC); - - if (method == GemmA && target == Target::Devices && n_devices > 1) - method = GemmC; - - return method; - } - - inline Method str2methodGemm(const char* method) - { - std::string method_ = method; - std::transform( - method_.begin(), method_.end(), method_.begin(), ::tolower ); - - if (method_ == "auto") - return Auto; - else if (method_ == "a" || method_ == "gemma") - return GemmA; - else if (method_ == "c" || method_ == "gemmc") - return GemmC; - else - throw slate::Exception("unknown gemm method"); - } - - inline const char* methodGemm2str(Method method) - { - switch (method) { - case Auto: return baseMethodAuto_str; - case GemmA: return GemmA_str; - case GemmC: return GemmC_str; - default: return baseMethodError_str; - } - } - -} // namespace MethodGemm - -//------------------------------------------------------------------------------ -/// Select the right algorithm to perform the hemm -namespace MethodHemm { - - constexpr char HemmA_str[] = "A"; - constexpr char HemmC_str[] = "C"; - const Method Error = baseMethodError; - const Method Auto = baseMethodAuto; - const Method HemmA = 1; ///< Select hemmA algorithm - const Method HemmC = 2; ///< Select hemmC algorithm - - template - inline Method select_algo(TA& A, TB& B, Options const& opts) { - // TODO replace the default value by a unique value located elsewhere - Target target = get_option( opts, Option::Target, Target::HostTask ); - - Method method = (B.nt() < 2 ? HemmA : HemmC); - - // XXX For now, when target == device, we fallback to HemmC on device - if (target == Target::Devices && method == HemmA) - method = HemmC; - - return method; - } - - inline Method str2methodHemm(const char* method) - { - std::string method_ = method; - std::transform( - method_.begin(), method_.end(), method_.begin(), ::tolower ); - - if (method_ == "auto") - return Auto; - else if (method_ == "a" || method_ == "hemma") - return HemmA; - else if (method_ == "c" || method_ == "hemmc") - return HemmC; - else - throw slate::Exception("unknown hemm method"); - } - - inline const char* methodHemm2str(Method method) - { - switch (method) { - case Auto: return baseMethodAuto_str; - case HemmA: return HemmA_str; - case HemmC: return HemmC_str; - default: return baseMethodError_str; - } - } - -} // namespace MethodHemm - -//------------------------------------------------------------------------------ -/// Select the right algorithm to perform the AH * A in CholQR -namespace MethodCholQR { - static constexpr char HerkC_str[] = "herkC"; - static constexpr char GemmA_str[] = "gemmA"; - static constexpr char GemmC_str[] = "gemmC"; - static const Method Error = baseMethodError; ///< Error flag - static const Method Auto = baseMethodAuto; ///< Let the algorithm decide - static const Method HerkC = 1; ///< Select herkC algorithm - static const Method GemmA = 2; ///< Select gemmA algorithm - static const Method GemmC = 3; ///< Select gemmC algorithm - - template - inline Method select_algo(TA& A, TB& B, Options const& opts) { - - Target target = get_option( opts, Option::Target, Target::HostTask ); - - Method method = (target == Target::Devices ? HerkC : GemmA); - - return method; - } - - inline Method str2methodCholQR(const char* method) - { - std::string method_ = method; - std::transform( - method_.begin(), method_.end(), method_.begin(), ::tolower ); - - if (method_ == "auto") - return Auto; - else if (method_ == "herkc") - return HerkC; - else if (method_ == "gemmc") - return GemmC; - else if (method_ == "gemma") - return GemmA; - else - throw slate::Exception("unknown cholQR method"); - } - - inline const char* methodCholQR2str(Method method) - { - switch (method) { - case Auto: return baseMethodAuto_str; - case HerkC: return HerkC_str; - case GemmA: return GemmA_str; - case GemmC: return GemmC_str; - default: return baseMethodError_str; - } - } - -} // namespace MethodCholQR - -//------------------------------------------------------------------------------ -/// Select the right algorithm to solve least squares problems -namespace MethodGels { - static constexpr char Cholqr_str[] = "cholqr"; - static constexpr char Geqrf_str[] = "qr"; - static const Method Error = baseMethodError; ///< Error flag - static const Method Auto = baseMethodAuto; ///< Let the algorithm decide - static const Method Cholqr = 1; ///< Select cholqr algorithm - static const Method Geqrf = 2; ///< Select geqrf algorithm - - template - inline Method select_algo(TA& A, TB& B, Options const& opts) { - return Geqrf; - } - - inline Method str2methodGels(const char* method) - { - std::string method_ = method; - std::transform( - method_.begin(), method_.end(), method_.begin(), ::tolower ); - - if (method_ == "auto") - return Auto; - else if (method_ == "qr") - return Geqrf; - else if (method_ == "cholqr") - return Cholqr; - else - throw slate::Exception("unknown gels method"); - } - - inline const char* methodGels2str(Method method) - { - switch (method) { - case Auto: return baseMethodAuto_str; - case Geqrf: return Geqrf_str; - case Cholqr: return Cholqr_str; - default: return baseMethodError_str; - } - } - -} // namespace MethodGels - -//------------------------------------------------------------------------------ -/// Select the LU factorization algorithm. -namespace MethodLU { - - static constexpr char PartialPiv_str[] = "PPLU"; - static constexpr char CALU_str[] = "CALU"; - static constexpr char NoPiv_str[] = "NoPiv"; - static const Method Error = baseMethodError; ///< Error flag - static const Method PartialPiv = 1; ///< Select partial pivoting LU - static const Method CALU = 2; ///< Select communication avoiding LU - static const Method NoPiv = 3; ///< Select no pivoting LU - - inline Method str2methodLU( const char* method ) - { - std::string method_ = method; - std::transform( - method_.begin(), method_.end(), method_.begin(), ::tolower ); - - if (method_ == "pplu" || method_ == "partialpiv") - return PartialPiv; - else if (method_ == "calu") - return CALU; - else if (method_ == "nopiv") - return NoPiv; - else - throw slate::Exception("unknown LU method"); - } - - inline const char* methodLU2str( Method method ) - { - switch (method) { - case PartialPiv: return PartialPiv_str; - case CALU: return CALU_str; - case NoPiv: return NoPiv_str; - default: return baseMethodError_str; - } - } - -} // namespace MethodLU - -} // namespace slate - -#endif // SLATE_METHOD_HH diff --git a/include/slate/slate.hh b/include/slate/slate.hh index 0dfc5b966..c4e307daf 100644 --- a/include/slate/slate.hh +++ b/include/slate/slate.hh @@ -15,8 +15,6 @@ #include "slate/TriangularBandMatrix.hh" #include "slate/HermitianBandMatrix.hh" -#include "slate/method.hh" - #include "slate/func.hh" #include "slate/types.hh" #include "slate/print.hh" diff --git a/include/slate/types.hh b/include/slate/types.hh index 27dfbc33d..2ce756b1d 100644 --- a/include/slate/types.hh +++ b/include/slate/types.hh @@ -46,7 +46,29 @@ public: OptionValue(Target t) : i_(int(t)) {} - OptionValue(MethodEig m) : i_(int(m)) + //----- Methods, alphabetical + OptionValue( MethodCholQR m ) : i_( int( m ) ) + {} + + OptionValue( MethodEig m ) : i_( int( m ) ) + {} + + OptionValue( MethodGels m ) : i_( int( m ) ) + {} + + OptionValue( MethodGemm m ) : i_( int( m ) ) + {} + + OptionValue( MethodHemm m ) : i_( int( m ) ) + {} + + OptionValue( MethodLU m ) : i_( int( m ) ) + {} + + OptionValue( MethodTrsm m ) : i_( int( m ) ) + {} + + OptionValue( MethodSVD m ) : i_( int( m ) ) {} union { @@ -58,9 +80,6 @@ public: using Options = std::map; using Value = OptionValue; ///< @deprecated -//------------------------------------------------------------------------------ -typedef int Method; - //------------------------------------------------------------------------------ class Pivot { public: @@ -236,13 +255,14 @@ template<> struct OptValueType { using T = int; }; template<> struct OptValueType { using T = int; }; template<> struct OptValueType { using T = int; }; template<> struct OptValueType { using T = int; }; -template<> struct OptValueType { using T = Method; }; +template<> struct OptValueType { using T = MethodCholQR; }; template<> struct OptValueType { using T = MethodEig; }; -template<> struct OptValueType { using T = Method; }; -template<> struct OptValueType { using T = Method; }; -template<> struct OptValueType { using T = Method; }; -template<> struct OptValueType { using T = Method; }; -template<> struct OptValueType { using T = Method; }; +template<> struct OptValueType { using T = MethodGels; }; +template<> struct OptValueType { using T = MethodGemm; }; +template<> struct OptValueType { using T = MethodHemm; }; +template<> struct OptValueType { using T = MethodLU; }; +template<> struct OptValueType { using T = MethodTrsm; }; +template<> struct OptValueType { using T = MethodSVD; }; template auto get_option( Options opts, typename OptValueType