diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD index a8ff0c4a3fc33d..a50ba568c4640e 100644 --- a/tensorflow/stream_executor/BUILD +++ b/tensorflow/stream_executor/BUILD @@ -119,7 +119,7 @@ cc_library( "//tensorflow/core:rocm", "@local_config_rocm//rocm:hiprand", "@local_config_rocm//rocm:rocfft", - "@local_config_rocm//rocm:hipblas", + "@local_config_rocm//rocm:rocblas", "@local_config_rocm//rocm:miopen", ]), alwayslink = 1, diff --git a/tensorflow/stream_executor/rocm/rocm_blas.cc b/tensorflow/stream_executor/rocm/rocm_blas.cc index bad5419c7b1dd4..c9330ce93bc304 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.cc +++ b/tensorflow/stream_executor/rocm/rocm_blas.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Include HIPBLAS headers early, and then set EIGEN_HAS_ROCM_FP16 +// Include rocBLAS headers early, and then set EIGEN_HAS_ROCM_FP16 // if we have new enough ROCM (which we will only know after including // rocm.h). This ensures that Eigen's Half.h does not attempt to make its own // __half typedef if ROCM has already defined one (and conversely, that we do // not include after Half.h has made its typedef). -#include "rocm/include/hipblas/hipblas.h" +#include "rocm/include/rocblas.h" #if ROCM_VERSION >= 7050 #define EIGEN_HAS_ROCM_FP16 @@ -57,284 +57,275 @@ limitations under the License. namespace stream_executor { namespace rocm { -PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kHipBlasPlugin); +PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocBlasPlugin); namespace wrap { -// XXX check if we need to port PERFTOOLS_GPU_TOOLS_HIPBLAS_WRAP from hipTensorFlow -#define PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(__name) \ +#define PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(__name) \ struct WrapperShim__##__name { \ static const char *kName; \ template \ - hipblasStatus_t operator()(ROCMExecutor *parent, Args... args) { \ + rocblas_status operator()(ROCMExecutor *parent, Args... args) { \ rocm::ScopedActivateExecutorContext sac{parent}; \ return ::__name(args...); \ } \ } __name; \ const char *WrapperShim__##__name::kName = #__name; -#define PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP(__name) \ - PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(__name) +#define PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP(__name) \ + PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(__name) #define HIPBLAS_BLAS_ROUTINE_EACH(__macro) \ -/* __macro(hipblasSnrm2) \ - __macro(hipblasDnrm2) \ - __macro(hipblasScnrm2) \ - __macro(hipblasDznrm2) */ \ - __macro(hipblasSdot) \ - __macro(hipblasDdot) \ -/* __macro(hipblasCdotu) \ - __macro(hipblasCdotc) \ - __macro(hipblasZdotu) \ - __macro(hipblasZdotc) */ \ - __macro(hipblasSscal) \ - __macro(hipblasDscal) \ -/* __macro(hipblasCscal) \ - __macro(hipblasCsscal) \ - __macro(hipblasZscal) \ - __macro(hipblasZdscal) */ \ - __macro(hipblasSaxpy) \ - __macro(hipblasDaxpy) \ -/* __macro(hipblasCaxpy) \ - __macro(hipblasZaxpy) */ \ - __macro(hipblasScopy) \ - __macro(hipblasDcopy) \ -/* __macro(hipblasCcopy) \ - __macro(hipblasZcopy) \ - __macro(hipblasSswap) \ - __macro(hipblasDswap) \ - __macro(hipblasCswap) \ - __macro(hipblasZswap) \ - __macro(hipblasIsamax) \ - __macro(hipblasIdamax) \ - __macro(hipblasIcamax) \ - __macro(hipblasIzamax) \ - __macro(hipblasIsamin) \ - __macro(hipblasIdamin) \ - __macro(hipblasIcamin) \ - __macro(hipblasIzamin) */ \ - __macro(hipblasSasum) \ - __macro(hipblasDasum) \ -/* __macro(hipblasScasum) \ - __macro(hipblasDzasum) \ - __macro(hipblasSrot) \ - __macro(hipblasDrot) \ - __macro(hipblasCrot) \ - __macro(hipblasCsrot) \ - __macro(hipblasZrot) \ - __macro(hipblasZdrot) \ - __macro(hipblasSrotg) \ - __macro(hipblasDrotg) \ - __macro(hipblasCrotg) \ - __macro(hipblasZrotg) \ - __macro(hipblasSrotm) \ - __macro(hipblasDrotm) \ - __macro(hipblasSrotmg) \ - __macro(hipblasDrotmg) */ \ - __macro(hipblasSgemv) \ - __macro(hipblasDgemv) \ -/* __macro(hipblasCgemv) \ - __macro(hipblasZgemv) \ - __macro(hipblasSgbmv) \ - __macro(hipblasDgbmv) \ - __macro(hipblasCgbmv) \ - __macro(hipblasZgbmv) \ - __macro(hipblasStrmv) \ - __macro(hipblasDtrmv) \ - __macro(hipblasCtrmv) \ - __macro(hipblasZtrmv) \ - __macro(hipblasStbmv) \ - __macro(hipblasDtbmv) \ - __macro(hipblasCtbmv) \ - __macro(hipblasZtbmv) \ - __macro(hipblasStpmv) \ - __macro(hipblasDtpmv) \ - __macro(hipblasCtpmv) \ - __macro(hipblasZtpmv) \ - __macro(hipblasStrsv) \ - __macro(hipblasDtrsv) \ - __macro(hipblasCtrsv) \ - __macro(hipblasZtrsv) \ - __macro(hipblasStpsv) \ - __macro(hipblasDtpsv) \ - __macro(hipblasCtpsv) \ - __macro(hipblasZtpsv) \ - __macro(hipblasStbsv) \ - __macro(hipblasDtbsv) \ - __macro(hipblasCtbsv) \ - __macro(hipblasZtbsv) \ - __macro(hipblasSsymv) \ - __macro(hipblasDsymv) \ - __macro(hipblasCsymv) \ - __macro(hipblasZsymv) \ - __macro(hipblasChemv) \ - __macro(hipblasZhemv) \ - __macro(hipblasSsbmv) \ - __macro(hipblasDsbmv) \ - __macro(hipblasChbmv) \ - __macro(hipblasZhbmv) \ - __macro(hipblasSspmv) \ - __macro(hipblasDspmv) \ - __macro(hipblasChpmv) \ - __macro(hipblasZhpmv) */ \ - __macro(hipblasSger) \ -/* __macro(hipblasDger) \ - __macro(hipblasCgeru) \ - __macro(hipblasCgerc) \ - __macro(hipblasZgeru) \ - __macro(hipblasZgerc) \ - __macro(hipblasSsyr) \ - __macro(hipblasDsyr) \ - __macro(hipblasCsyr) \ - __macro(hipblasZsyr) \ - __macro(hipblasCher) \ - __macro(hipblasZher) \ - __macro(hipblasSspr) \ - __macro(hipblasDspr) \ - __macro(hipblasChpr) \ - __macro(hipblasZhpr) \ - __macro(hipblasSsyr2) \ - __macro(hipblasDsyr2) \ - __macro(hipblasCsyr2) \ - __macro(hipblasZsyr2) \ - __macro(hipblasCher2) \ - __macro(hipblasZher2) \ - __macro(hipblasSspr2) \ - __macro(hipblasDspr2) \ - __macro(hipblasChpr2) \ - __macro(hipblasZhpr2) */ \ - __macro(hipblasSgemm) \ - __macro(hipblasDgemm) \ -/* __macro(hipblasCgemm) \ - __macro(hipblasZgemm) \ - __macro(hipblasSsyrk) \ - __macro(hipblasDsyrk) \ - __macro(hipblasCsyrk) \ - __macro(hipblasZsyrk) \ - __macro(hipblasCherk) \ - __macro(hipblasZherk) \ - __macro(hipblasSsyr2k) \ - __macro(hipblasDsyr2k) \ - __macro(hipblasCsyr2k) \ - __macro(hipblasZsyr2k) \ - __macro(hipblasCher2k) \ - __macro(hipblasZher2k) \ - __macro(hipblasSsyrkx) \ - __macro(hipblasDsyrkx) \ - __macro(hipblasCsyrkx) \ - __macro(hipblasZsyrkx) \ - __macro(hipblasCherkx) \ - __macro(hipblasZherkx) \ - __macro(hipblasSsymm) \ - __macro(hipblasDsymm) \ - __macro(hipblasCsymm) \ - __macro(hipblasZsymm) \ - __macro(hipblasChemm) \ - __macro(hipblasZhemm) \ - __macro(hipblasStrsm) \ - __macro(hipblasDtrsm) \ - __macro(hipblasCtrsm) \ - __macro(hipblasZtrsm) \ - __macro(hipblasStrmm) \ - __macro(hipblasDtrmm) \ - __macro(hipblasCtrmm) \ - __macro(hipblasZtrmm) \ - __macro(hipblasSgeam) \ - __macro(hipblasDgeam) \ - __macro(hipblasCgeam) \ - __macro(hipblasZgeam) \ - __macro(hipblasSdgmm) \ - __macro(hipblasDdgmm) \ - __macro(hipblasCdgmm) \ - __macro(hipblasZdgmm) */ - -PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP(hipblasCreate) -PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP(hipblasDestroy) -PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP(hipblasSetStream) -//PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP(hipblasSetPointerMode) -//PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP(hipblasGetPointerMode) -//PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasSgemmBatched) -PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasSgemmStridedBatched) -//PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasDgemmBatched) -PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasDgemmStridedBatched) -//PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasCgemmBatched) -//PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasZgemmBatched) -HIPBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_HIPBLAS_V2_WRAP) + __macro(rocblas_snrm2) \ + __macro(rocblas_dnrm2) \ +/* __macro(rocblas_scnrm2) \ + __macro(rocblas_dznrm2) */ \ + __macro(rocblas_sdot) \ + __macro(rocblas_ddot) \ +/* __macro(rocblas_cdotu) \ + __macro(rocblas_cdotc) \ + __macro(rocblas_zdotu) \ + __macro(rocblas_zdotc) */ \ + __macro(rocblas_sscal) \ + __macro(rocblas_dscal) \ +/* __macro(rocblas_cscal) \ + __macro(rocblas_csscal) \ + __macro(rocblas_zscal) \ + __macro(rocblas_zdscal) */ \ + __macro(rocblas_saxpy) \ + __macro(rocblas_daxpy) \ +/* __macro(rocblas_caxpy) \ + __macro(rocblas_zaxpy) */ \ + __macro(rocblas_scopy) \ + __macro(rocblas_dcopy) \ +/* __macro(rocblas_ccopy) \ + __macro(rocblas_zcopy) */ \ + __macro(rocblas_sswap) \ + __macro(rocblas_dswap) \ +/* __macro(rocblas_cswap) \ + __macro(rocblas_zswap) */ \ + __macro(rocblas_isamax) \ + __macro(rocblas_idamax) \ +/* __macro(rocblas_icamax) \ + __macro(rocblas_izamax) */ \ + __macro(rocblas_isamin) \ + __macro(rocblas_idamin) \ +/* __macro(rocblas_icamin) \ + __macro(rocblas_izamin) */ \ + __macro(rocblas_sasum) \ + __macro(rocblas_dasum) \ +/* __macro(rocblas_scasum) \ + __macro(rocblas_dzasum) \ + __macro(rocblas_srot) \ + __macro(rocblas_drot) \ + __macro(rocblas_crot) \ + __macro(rocblas_csrot) \ + __macro(rocblas_zrot) \ + __macro(rocblas_zdrot) \ + __macro(rocblas_srotg) \ + __macro(rocblas_drotg) \ + __macro(rocblas_Crotg) \ + __macro(rocblas_crotg) \ + __macro(rocblas_zrotm) \ + __macro(rocblas_drotm) \ + __macro(rocblas_srotmg) \ + __macro(rocblas_drotmg) */ \ + __macro(rocblas_sgemv) \ + __macro(rocblas_dgemv) \ +/* __macro(rocblas_cgemv) \ + __macro(rocblas_zgemv) \ + __macro(rocblas_sgbmv) \ + __macro(rocblas_dgbmv) \ + __macro(rocblas_cgbmv) \ + __macro(rocblas_zgbmv) \ + __macro(rocblas_strmv) \ + __macro(rocblas_dtrmv) \ + __macro(rocblas_ctrmv) \ + __macro(rocblas_ztrmv) \ + __macro(rocblas_stbmv) \ + __macro(rocblas_dtbmv) \ + __macro(rocblas_ctbmv) \ + __macro(rocblas_ztbmv) \ + __macro(rocblas_stpmv) \ + __macro(rocblas_dtpmv) \ + __macro(rocblas_ctpmv) \ + __macro(rocblas_ztpmv) \ + __macro(rocblas_strsv) \ + __macro(rocblas_dtrsv) \ + __macro(rocblas_ctrsv) \ + __macro(rocblas_ztrsv) \ + __macro(rocblas_stpsv) \ + __macro(rocblas_dtpsv) \ + __macro(rocblas_ctpsv) \ + __macro(rocblas_ztpsv) \ + __macro(rocblas_stbsv) \ + __macro(rocblas_dtbsv) \ + __macro(rocblas_ctbsv) \ + __macro(rocblas_ztbsv) \ + __macro(rocblas_ssymv) \ + __macro(rocblas_dsymv) \ + __macro(rocblas_csymv) \ + __macro(rocblas_zsymv) \ + __macro(rocblas_chemv) \ + __macro(rocblas_zhemv) \ + __macro(rocblas_ssbmv) \ + __macro(rocblas_dsbmv) \ + __macro(rocblas_chbmv) \ + __macro(rocblas_zhbmv) \ + __macro(rocblas_sspmv) \ + __macro(rocblas_dspmv) \ + __macro(rocblas_chpmv) \ + __macro(rocblas_zhpmv) */ \ + __macro(rocblas_sger) \ + __macro(rocblas_dger) \ +/* __macro(rocblas_cgeru) \ + __macro(rocblas_cgerc) \ + __macro(rocblas_zgeru) \ + __macro(rocblas_zgerc) */ \ + __macro(rocblas_ssyr) \ + __macro(rocblas_dsyr) \ +/* __macro(rocblas_csyr) \ + __macro(rocblas_zsyr) \ + __macro(rocblas_cher) \ + __macro(rocblas_zher) \ + __macro(rocblas_sspr) \ + __macro(rocblas_dspr) \ + __macro(rocblas_chpr) \ + __macro(rocblas_zhpr) \ + __macro(rocblas_ssyr2) \ + __macro(rocblas_dsyr2) \ + __macro(rocblas_csyr2) \ + __macro(rocblas_zsyr2) \ + __macro(rocblas_cher2) \ + __macro(rocblas_zher2) \ + __macro(rocblas_sspr2) \ + __macro(rocblas_dspr2) \ + __macro(rocblas_chpr2) \ + __macro(rocblas_zhpr2) */ \ + __macro(rocblas_sgemm) \ + __macro(rocblas_dgemm) \ +/* __macro(rocblas_cgemm) \ + __macro(rocblas_zgemm) \ + __macro(rocblas_ssyrk) \ + __macro(rocblas_dsyrk) \ + __macro(rocblas_csyrk) \ + __macro(rocblas_zsyrk) \ + __macro(rocblas_cherk) \ + __macro(rocblas_zherk) \ + __macro(rocblas_ssyr2k) \ + __macro(rocblas_dsyr2k) \ + __macro(rocblas_csyr2k) \ + __macro(rocblas_zsyr2k) \ + __macro(rocblas_cher2k) \ + __macro(rocblas_zher2k) \ + __macro(rocblas_ssyrkx) \ + __macro(rocblas_dsyrkx) \ + __macro(rocblas_csyrkx) \ + __macro(rocblas_zsyrkx) \ + __macro(rocblas_cherkx) \ + __macro(rocblas_zherkx) \ + __macro(rocblas_ssymm) \ + __macro(rocblas_dsymm) \ + __macro(rocblas_csymm) \ + __macro(rocblas_zsymm) \ + __macro(rocblas_chemm) \ + __macro(rocblas_zhemm) */ \ + __macro(rocblas_strsm) \ + __macro(rocblas_dtrsm) \ +/* __macro(rocblas_ctrsm) \ + __macro(rocblas_ztrsm) \ + __macro(rocblas_strmm) \ + __macro(rocblas_dtrmm) \ + __macro(rocblas_ctrmm) \ + __macro(rocblas_ztrmm) \ + __macro(rocblas_sgeam) \ + __macro(rocblas_dgeam) \ + __macro(rocblas_cgeam) \ + __macro(rocblas_zgeam) \ + __macro(rocblas_sdgmm) \ + __macro(rocblas_ddgmm) \ + __macro(rocblas_cdgmm) \ + __macro(rocblas_zdgmm) */ + +PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP(rocblas_create_handle) +PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP(rocblas_destroy_handle) +PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP(rocblas_set_stream) +//PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP(rocblas_set_pointer_mode) +//PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP(rocblas_get_pointer_mode) +//PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_sgemm_batched) +PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_sgemm_strided_batched) +//PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_dgemm_batched) +PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_dgemm_strided_batched) +//PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_cgemm_batched) +//PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_zgemm_batched) +HIPBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_ROCBLAS_V2_WRAP) #if ROCM_VERSION >= 7050 -//PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasSgemmEx) +//PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_sgemmEx) #endif #if ROCM_VERSION >= 8000 -//PERFTOOLS_GPUTOOLS_HIPBLAS_WRAP(hipblasGemmEx) +//PERFTOOLS_GPUTOOLS_ROCBLAS_WRAP(rocblas_dgemmEx) #endif } // namespace wrap -static string ToString(hipblasStatus_t status) { +static string ToString(rocblas_status status) { switch (status) { - case HIPBLAS_STATUS_SUCCESS: - return "HIPBLAS_STATUS_SUCCESS"; - case HIPBLAS_STATUS_NOT_INITIALIZED: - return "HIPBLAS_STATUS_NOT_INITIALIZED"; - case HIPBLAS_STATUS_ALLOC_FAILED: - return "HIPBLAS_STATUS_ALLOC_FAILED"; - case HIPBLAS_STATUS_INVALID_VALUE: - return "HIPBLAS_STATUS_INVALID_VALUE"; - //case HIPBLAS_STATUS_ARCH_MISMATCH: - // return "HIPBLAS_STATUS_ARCH_MISMATCH"; - case HIPBLAS_STATUS_MAPPING_ERROR: - return "HIPBLAS_STATUS_MAPPING_ERROR"; - case HIPBLAS_STATUS_EXECUTION_FAILED: - return "HIPBLAS_STATUS_EXECUTION_FAILED"; - case HIPBLAS_STATUS_INTERNAL_ERROR: - return "HIPBLAS_STATUS_INTERNAL_ERROR"; -#if ROCM_VERSION >= 8000 - //case HIPBLAS_STATUS_NOT_SUPPORTED: - // return "HIPBLAS_STATUS_NOT_SUPPORTED"; - //case HIPBLAS_STATUS_LICENSE_ERROR: - // return "HIPBLAS_STATUS_LICENSE_ERROR"; -#endif + case rocblas_status_success: + return "rocblas_status_success"; + case rocblas_status_invalid_handle: + return "rocblas_status_invalid_handle"; + case rocblas_status_not_implemented: + return "rocblas_status_not_implemented"; + case rocblas_status_invalid_pointer: + return "rocblas_status_invalid_pointer"; + case rocblas_status_invalid_size: + return "rocblas_status_invalid_size"; + case rocblas_status_memory_error: + return "rocblas_status_memory_error"; + case rocblas_status_internal_error: + return "rocblas_status_internal_error"; default: - return port::StrCat(""); + return port::StrCat(""); } } -// HIPBLAS has interfaces that permit pointers to be passed from either the host +// rocBLAS has interfaces that permit pointers to be passed from either the host // memory space or the device memory space; however, you must instruct it as to -// which address space those pointers are in with hipblasSetPointerMode. +// which address space those pointers are in with rocblas_SetPointerMode. // -// This helper sets the HIPBLAS pointer mode to a desired value for a HIPBLAS call +// This helper sets the rocBLAS pointer mode to a desired value for a rocBLAS call // you are about to perform in a given scope. // -// The prior HIPBLAS pointer mode is retained and restored when this object goes +// The prior rocBLAS pointer mode is retained and restored when this object goes // out of scope. -/*class ScopedHipblasPointerMode { +/*class ScopedRocBLASPointerMode { public: - // Note that, because the setting of the hipblas pointer mode is fallible, + // Note that, because the setting of the rocBLAS pointer mode is fallible, // construction of this scoped datatype must be paired with a call to // Init(). // // Parameters: - // handle: The hipblas library handle to act upon in setting the pointer mode. - explicit ScopedHipblasPointerMode(ROCMExecutor *parent, hipblasHandle_t handle) + // handle: The rocBLAS library handle to act upon in setting the pointer mode. + explicit ScopedRocBLASPointerMode(ROCMExecutor *parent, rocblas_handle handle) : parent_(parent), handle_(handle), ok_(false) {} // Attempts the switch to the requested scoped pointer mode, new_mode. // // Note that when false is returned, an appropriate error has already been // logged. - bool Init(hipblasPointerMode_t new_mode) { - hipblasStatus_t ret = - wrap::hipblasGetPointerMode(parent_, handle_, &old_mode_); - if (ret != HIPBLAS_STATUS_SUCCESS) { - LOG(ERROR) << "failed to get old hipblas pointer mode: " << ToString(ret); + bool Init(rocblas_pointer_mode new_mode) { + rocblas_status ret = + wrap::rocblas_get_pointer_mode(parent_, handle_, &old_mode_); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to get old rocBLAS pointer mode: " << ToString(ret); return ok_ = false; } - ret = wrap::hipblasSetPointerMode(parent_, handle_, new_mode); - if (ret != HIPBLAS_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set new hipblas pointer mode: " << ToString(ret); + ret = wrap::rocblas_set_pointer_mode(parent_, handle_, new_mode); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to set new rocBLAS pointer mode: " << ToString(ret); return ok_ = false; } @@ -343,12 +334,12 @@ static string ToString(hipblasStatus_t status) { // Switches back to the prior pointer mode, if the switch operation was // successful in the first place. - ~ScopedHipblasPointerMode() { + ~ScopedRocBLASPointerMode() { if (ok_) { - hipblasStatus_t ret = - wrap::hipblasSetPointerMode(parent_, handle_, old_mode_); - if (ret != HIPBLAS_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set former hipblas pointer mode: " + rocblas_status ret = + wrap::rocblas_set_pointer_mode(parent_, handle_, old_mode_); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to set former rocBLAS pointer mode: " << ToString(ret); } } @@ -356,15 +347,15 @@ static string ToString(hipblasStatus_t status) { private: ROCMExecutor *parent_; // Executor establishing this pointer mode for. - hipblasHandle_t handle_; // Handle to the HIPBLAS instance of interest. - hipblasPointerMode_t old_mode_; // Prior HIPBLAS pointer mode, to be restored. + rocblas_handle handle_; // Handle to the rocBLAS instance of interest. + rocblas_pointer_mode old_mode_; // Prior rocBLAS pointer mode, to be restored. bool ok_; // Whether the change was successful. };*/ bool ROCMBlas::Init() { - hipblasStatus_t ret = wrap::hipblasCreate(parent_, &blas_); - if (ret != HIPBLAS_STATUS_SUCCESS) { - LOG(ERROR) << "failed to create hipblas handle: " << ToString(ret); + rocblas_status ret = wrap::rocblas_create_handle(parent_, &blas_); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret); return false; } @@ -376,7 +367,7 @@ ROCMBlas::ROCMBlas(rocm::ROCMExecutor *parent) ROCMBlas::~ROCMBlas() { if (blas_ != nullptr) { - wrap::hipblasDestroy(parent_, blas_); + wrap::rocblas_destroy_handle(parent_, blas_); } } @@ -384,10 +375,10 @@ bool ROCMBlas::SetStream(Stream *stream) { CHECK(stream != nullptr); CHECK(AsROCMStreamValue(stream) != nullptr); CHECK(blas_ != nullptr); - hipblasStatus_t ret = - wrap::hipblasSetStream(parent_, blas_, AsROCMStreamValue(stream)); - if (ret != HIPBLAS_STATUS_SUCCESS) { - LOG(ERROR) << "failed to set stream for HIPBLAS calls: " << ToString(ret); + rocblas_status ret = + wrap::rocblas_set_stream(parent_, blas_, AsROCMStreamValue(stream)); + if (ret != rocblas_status_success) { + LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); return false; } @@ -396,53 +387,53 @@ bool ROCMBlas::SetStream(Stream *stream) { namespace { -// Helper functions transforming blas arguments into HIPBLAS arguments. +// Helper functions transforming blas arguments into rocBLAS arguments. -hipblasOperation_t ROCMBlasTranspose(blas::Transpose trans) { +rocblas_operation ROCMBlasTranspose(blas::Transpose trans) { switch (trans) { case blas::Transpose::kNoTranspose: - return HIPBLAS_OP_N; + return rocblas_operation_none; case blas::Transpose::kTranspose: - return HIPBLAS_OP_T; + return rocblas_operation_transpose; case blas::Transpose::kConjugateTranspose: - return HIPBLAS_OP_C; + return rocblas_operation_conjugate_transpose; default: LOG(FATAL) << "Invalid value of blas::Transpose."; } } -/*hipblasFillMode_t ROCMBlasUpperLower(blas::UpperLower uplo) { +rocblas_fill ROCMBlasUpperLower(blas::UpperLower uplo) { switch (uplo) { case blas::UpperLower::kUpper: - return HIPBLAS_FILL_MODE_UPPER; + return rocblas_fill_upper; case blas::UpperLower::kLower: - return HIPBLAS_FILL_MODE_LOWER; + return rocblas_fill_lower; default: LOG(FATAL) << "Invalid value of blas::UpperLower."; } -}*/ +} -/*hipblasDiagType_t ROCMBlasDiagonal(blas::Diagonal diag) { +rocblas_diagonal ROCMBlasDiagonal(blas::Diagonal diag) { switch (diag) { case blas::Diagonal::kUnit: - return HIPBLAS_DIAG_UNIT; + return rocblas_diagonal_unit; case blas::Diagonal::kNonUnit: - return HIPBLAS_DIAG_NON_UNIT; + return rocblas_diagonal_non_unit; default: LOG(FATAL) << "Invalid value of blas::Diagonal."; } -}*/ +} -/*hipblasSideMode_t ROCMBlasSide(blas::Side side) { +rocblas_side ROCMBlasSide(blas::Side side) { switch (side) { case blas::Side::kLeft: - return HIPBLAS_SIDE_LEFT; + return rocblas_side_left; case blas::Side::kRight: - return HIPBLAS_SIDE_RIGHT; + return rocblas_side_right; default: LOG(FATAL) << "Invalid value of blas::Side."; } -}*/ +} /* // ROCMDataType::type translates from a C++ type (e.g. float) to a @@ -450,7 +441,7 @@ hipblasOperation_t ROCMBlasTranspose(blas::Transpose trans) { // blas::ComputationType to a rocmDataType_t. // // These are used to build the argument type and computation type args to -// hipblasGemmEx. hipblasGemmEx and rocmDataType_t are available only on +// rocblasGemmEx. rocblasGemmEx and rocmDataType_t are available only on // ROCM >= 8.0. #if ROCM_VERSION >= 8000 template @@ -533,10 +524,11 @@ rocmDataType_t ROCMComputationType(blas::ComputationType ty) { } // namespace template -bool ROCMBlas::DoBlasInternalImpl(FuncT hipblas_func, Stream *stream, +bool ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, bool pointer_mode_host, bool err_on_failure, Args... args) { mutex_lock lock{mu_}; + // XXX (jmd) why no unlock? CHECK(blas_ != nullptr); if (!SetStream(stream)) { @@ -549,18 +541,18 @@ bool ROCMBlas::DoBlasInternalImpl(FuncT hipblas_func, Stream *stream, return false; }*/ - hipblasStatus_t ret = hipblas_func(parent_, blas_, args...); - if (err_on_failure && ret != HIPBLAS_STATUS_SUCCESS) { - LOG(ERROR) << "failed to run HIPBLAS routine " << hipblas_func.kName << ": " + rocblas_status ret = rocblas_func(parent_, blas_, args...); + if (err_on_failure && ret != rocblas_status_success) { + LOG(ERROR) << "failed to run ROCBLAS routine " << rocblas_func.kName << ": " << ToString(ret); } - return ret == HIPBLAS_STATUS_SUCCESS; + return ret == rocblas_status_success; } bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return DoBlasInternal(wrap::hipblasSasum, stream, + return DoBlasInternal(wrap::rocblas_sasum, stream, false /* = pointer_mode_host */, elem_count, ROCMMemory(x), incx, ROCMMemoryMutable(result)); } @@ -568,7 +560,7 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return DoBlasInternal(wrap::hipblasDasum, stream, + return DoBlasInternal(wrap::rocblas_dasum, stream, false /* = pointer_mode_host */, elem_count, ROCMMemory(x), incx, ROCMMemoryMutable(result)); } @@ -578,7 +570,7 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasScasum, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_scasum, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } @@ -587,14 +579,14 @@ bool ROCMBlas::DoBlasAsum(Stream *stream, uint64 elem_count, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasDzasum, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_dzasum, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { - return DoBlasInternal(wrap::hipblasSaxpy, stream, + return DoBlasInternal(wrap::rocblas_saxpy, stream, true /* = pointer_mode_host */, elem_count, &alpha, ROCMMemory(x), incx, ROCMMemoryMutable(y), incy); } @@ -602,7 +594,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { - return DoBlasInternal(wrap::hipblasDaxpy, stream, + return DoBlasInternal(wrap::rocblas_daxpy, stream, true /* = pointer_mode_host */, elem_count, &alpha, ROCMMemory(x), incx, ROCMMemoryMutable(y), incy); } @@ -612,7 +604,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasCaxpy, stream, + //return DoBlasInternal(wrap::rocblas_caxpy, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -623,7 +615,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasZaxpy, stream, + //return DoBlasInternal(wrap::rocblas_zaxpy, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -632,7 +624,7 @@ bool ROCMBlas::DoBlasAxpy(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { - return DoBlasInternal(wrap::hipblasScopy, stream, + return DoBlasInternal(wrap::rocblas_scopy, stream, true /* = pointer_mode_host */, elem_count, ROCMMemory(x), incx, ROCMMemoryMutable(y), incy); } @@ -640,7 +632,7 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *y, int incy) { - return DoBlasInternal(wrap::hipblasDcopy, stream, + return DoBlasInternal(wrap::rocblas_dcopy, stream, true /* = pointer_mode_host */, elem_count, ROCMMemory(x), incx, ROCMMemoryMutable(y), incy); } @@ -649,7 +641,7 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasCcopy, stream, + //return DoBlasInternal(wrap::rocblas_ccopy, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -659,7 +651,7 @@ bool ROCMBlas::DoBlasCopy(Stream *stream, uint64 elem_count, const DeviceMemory> &x, int incx, DeviceMemory> *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasZcopy, stream, + //return DoBlasInternal(wrap::rocblas_zcopy, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -670,7 +662,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, const DeviceMemory &y, int incy, DeviceMemory *result) { return DoBlasInternal( - wrap::hipblasSdot, stream, false /* = pointer_mode_host */, elem_count, + wrap::rocblas_sdot, stream, false /* = pointer_mode_host */, elem_count, ROCMMemory(x), incx, ROCMMemory(y), incy, ROCMMemoryMutable(result)); } @@ -679,7 +671,7 @@ bool ROCMBlas::DoBlasDot(Stream *stream, uint64 elem_count, const DeviceMemory &y, int incy, DeviceMemory *result) { return DoBlasInternal( - wrap::hipblasDdot, stream, false /* = pointer_mode_host */, elem_count, + wrap::rocblas_ddot, stream, false /* = pointer_mode_host */, elem_count, ROCMMemory(x), incx, ROCMMemory(y), incy, ROCMMemoryMutable(result)); } @@ -689,7 +681,7 @@ bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, DeviceMemory> *result) { return false; //return DoBlasInternal( - // wrap::hipblasCdotc, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_cdotc, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(result))); } @@ -700,7 +692,7 @@ bool ROCMBlas::DoBlasDotc(Stream *stream, uint64 elem_count, DeviceMemory> *result) { return false; //return DoBlasInternal( - // wrap::hipblasZdotc, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_zdotc, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(result))); } @@ -711,7 +703,7 @@ bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, DeviceMemory> *result) { return false; //return DoBlasInternal( - // wrap::hipblasCdotu, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_cdotu, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(result))); } @@ -722,7 +714,7 @@ bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, DeviceMemory> *result) { return false; //return DoBlasInternal( - // wrap::hipblasZdotu, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_zdotu, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(result))); } @@ -730,19 +722,17 @@ bool ROCMBlas::DoBlasDotu(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return false; - //return DoBlasInternal(wrap::hipblasSnrm2, stream, - // false /* = pointer_mode_host */, elem_count, - // ROCMMemory(x), incx, ROCMMemoryMutable(result)); + return DoBlasInternal(wrap::rocblas_snrm2, stream, + false /* = pointer_mode_host */, elem_count, + ROCMMemory(x), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return false; - //return DoBlasInternal(wrap::hipblasDnrm2, stream, - // false /* = pointer_mode_host */, elem_count, - // ROCMMemory(x), incx, ROCMMemoryMutable(result)); + return DoBlasInternal(wrap::rocblas_dnrm2, stream, + false /* = pointer_mode_host */, elem_count, + ROCMMemory(x), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, @@ -750,7 +740,7 @@ bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasScnrm2, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_scnrm2, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } @@ -759,7 +749,7 @@ bool ROCMBlas::DoBlasNrm2(Stream *stream, uint64 elem_count, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasDznrm2, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_dznrm2, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } @@ -768,7 +758,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory *y, int incy, float c, float s) { return false; //return DoBlasInternal( - // wrap::hipblasSrot, stream, true /* = pointer_mode_host */, elem_count, + // wrap::rocblas_srot, stream, true /* = pointer_mode_host */, elem_count, // ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy, &c, &s); } @@ -778,7 +768,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, double s) { return false; //return DoBlasInternal( - // wrap::hipblasDrot, stream, true /* = pointer_mode_host */, elem_count, + // wrap::rocblas_drot, stream, true /* = pointer_mode_host */, elem_count, // ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy, &c, &s); } @@ -787,7 +777,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory> *y, int incy, float c, float s) { return false; - //return DoBlasInternal(wrap::hipblasCsrot, stream, + //return DoBlasInternal(wrap::rocblas_csrot, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemoryMutable(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy, &c, &s); @@ -798,7 +788,7 @@ bool ROCMBlas::DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory> *y, int incy, double c, double s) { return false; - //return DoBlasInternal(wrap::hipblasZdrot, stream, + //return DoBlasInternal(wrap::rocblas_zdrot, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemoryMutable(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy, &c, &s); @@ -808,7 +798,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory *a, DeviceMemory *b, DeviceMemory *c, DeviceMemory *s) { return false; - //return DoBlasInternal(wrap::hipblasSrotg, stream, + //return DoBlasInternal(wrap::rocblas_srotg, stream, // false /* = pointer_mode_host */, ROCMMemoryMutable(a), // ROCMMemoryMutable(b), ROCMMemoryMutable(c), // ROCMMemoryMutable(s)); @@ -818,7 +808,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory *a, DeviceMemory *b, DeviceMemory *c, DeviceMemory *s) { return false; - //return DoBlasInternal(wrap::hipblasDrotg, stream, + //return DoBlasInternal(wrap::rocblas_drotg, stream, // false /* = pointer_mode_host */, // ROCMComplex(ROCMMemoryMutable(a)), ROCMMemoryMutable(b), // ROCMMemoryMutable(c), ROCMMemoryMutable(s)); @@ -830,7 +820,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory> *a, DeviceMemory> *s) { return false; //return DoBlasInternal( - // wrap::hipblasCrotg, stream, false /* = pointer_mode_host */, + // wrap::rocblas_crotg, stream, false /* = pointer_mode_host */, // ROCMComplex(ROCMMemoryMutable(a)), ROCMComplex(ROCMMemoryMutable(b)), // ROCMComplex(ROCMMemoryMutable(c)), ROCMComplex(ROCMMemoryMutable(s))); } @@ -841,7 +831,7 @@ bool ROCMBlas::DoBlasRotg(Stream *stream, DeviceMemory> *a, DeviceMemory> *s) { return false; //return DoBlasInternal( - // wrap::hipblasZrotg, stream, false /* = pointer_mode_host */, + // wrap::rocblas_zrotg, stream, false /* = pointer_mode_host */, // ROCMComplex(ROCMMemoryMutable(a)), ROCMComplex(ROCMMemoryMutable(b)), // ROCMComplex(ROCMMemoryMutable(c)), ROCMComplex(ROCMMemoryMutable(s))); } @@ -851,7 +841,7 @@ bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *y, int incy, const DeviceMemory ¶m) { return false; - //return DoBlasInternal(wrap::hipblasSrotm, stream, + //return DoBlasInternal(wrap::rocblas_srotm, stream, // false /* = pointer_mode_host */, elem_count, // ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy, // ROCMMemory(param)); @@ -862,7 +852,7 @@ bool ROCMBlas::DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory *y, int incy, const DeviceMemory ¶m) { return false; - //return DoBlasInternal(wrap::hipblasDrotm, stream, + //return DoBlasInternal(wrap::rocblas_drotm, stream, // false /* = pointer_mode_host */, elem_count, // ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy, // ROCMMemory(param)); @@ -873,7 +863,7 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory *d1, const DeviceMemory &y1, DeviceMemory *param) { return false; - //return DoBlasInternal(wrap::hipblasSrotmg, stream, + //return DoBlasInternal(wrap::rocblas_srotmg, stream, // false /* = pointer_mode_host */, ROCMMemoryMutable(d1), // ROCMMemoryMutable(d2), ROCMMemoryMutable(x1), // ROCMMemory(y1), ROCMMemoryMutable(param)); @@ -884,7 +874,7 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory *d1, const DeviceMemory &y1, DeviceMemory *param) { return false; - //return DoBlasInternal(wrap::hipblasDrotmg, stream, + //return DoBlasInternal(wrap::rocblas_drotmg, stream, // false /* = pointer_mode_host */, ROCMMemoryMutable(d1), // ROCMMemoryMutable(d2), ROCMMemoryMutable(x1), // ROCMMemory(y1), ROCMMemoryMutable(param)); @@ -892,14 +882,14 @@ bool ROCMBlas::DoBlasRotmg(Stream *stream, DeviceMemory *d1, bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, DeviceMemory *x, int incx) { - return DoBlasInternal(wrap::hipblasSscal, stream, + return DoBlasInternal(wrap::rocblas_sscal, stream, true /* = pointer_mode_host */, elem_count, &alpha, ROCMMemoryMutable(x), incx); } bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, DeviceMemory *x, int incx) { - return DoBlasInternal(wrap::hipblasDscal, stream, + return DoBlasInternal(wrap::rocblas_dscal, stream, true /* = pointer_mode_host */, elem_count, &alpha, ROCMMemoryMutable(x), incx); } @@ -908,7 +898,7 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha, DeviceMemory> *x, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasCsscal, stream, true /* = pointer_mode_host */, elem_count, + // wrap::rocblas_csscal, stream, true /* = pointer_mode_host */, elem_count, // ROCMComplex(&alpha), ROCMComplex(ROCMMemoryMutable(x)), incx); } @@ -916,7 +906,7 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, double alpha, DeviceMemory> *x, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasZdscal, stream, true /* = pointer_mode_host */, elem_count, + // wrap::rocblas_zdscal, stream, true /* = pointer_mode_host */, elem_count, // ROCMComplex(&alpha), ROCMComplex(ROCMMemoryMutable(x)), incx); } @@ -925,7 +915,7 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasCscal, stream, true /* = pointer_mode_host */, elem_count, + // wrap::rocblas_cscal, stream, true /* = pointer_mode_host */, elem_count, // ROCMComplex(&alpha), ROCMComplex(ROCMMemoryMutable(x)), incx); } @@ -934,33 +924,31 @@ bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasZscal, stream, true /* = pointer_mode_host */, elem_count, + // wrap::rocblas_zscal, stream, true /* = pointer_mode_host */, elem_count, // ROCMComplex(&alpha), ROCMComplex(ROCMMemoryMutable(x)), incx); } bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy) { - return false; - //return DoBlasInternal(wrap::hipblasSswap, stream, - // true /* = pointer_mode_host */, elem_count, - // ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy); + return DoBlasInternal(wrap::rocblas_sswap, stream, + true /* = pointer_mode_host */, elem_count, + ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy); } bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory *x, int incx, DeviceMemory *y, int incy) { - return false; - //return DoBlasInternal(wrap::hipblasDswap, stream, - // true /* = pointer_mode_host */, elem_count, - // ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy); + return DoBlasInternal(wrap::rocblas_dswap, stream, + true /* = pointer_mode_host */, elem_count, + ROCMMemoryMutable(x), incx, ROCMMemoryMutable(y), incy); } bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasCswap, stream, + //return DoBlasInternal(wrap::rocblas_cswap, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemoryMutable(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -970,7 +958,7 @@ bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory> *x, int incx, DeviceMemory> *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasZswap, stream, + //return DoBlasInternal(wrap::rocblas_zswap, stream, // true /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemoryMutable(x)), incx, // ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -979,19 +967,17 @@ bool ROCMBlas::DoBlasSwap(Stream *stream, uint64 elem_count, bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return false; - //return DoBlasInternal(wrap::hipblasIsamax, stream, - // false /* = pointer_mode_host */, elem_count, - // ROCMMemory(x), incx, ROCMMemoryMutable(result)); + return DoBlasInternal(wrap::rocblas_isamax, stream, + false /* = pointer_mode_host */, elem_count, + ROCMMemory(x), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return false; - //return DoBlasInternal(wrap::hipblasIdamax, stream, - // false /* = pointer_mode_host */, elem_count, - // ROCMMemory(x), incx, ROCMMemoryMutable(result)); + return DoBlasInternal(wrap::rocblas_idamax, stream, + false /* = pointer_mode_host */, elem_count, + ROCMMemory(x), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, @@ -999,7 +985,7 @@ bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasIcamax, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_icamax, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } @@ -1008,26 +994,24 @@ bool ROCMBlas::DoBlasIamax(Stream *stream, uint64 elem_count, int incx, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasIzamax, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_izamax, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return false; - //return DoBlasInternal( - // wrap::hipblasIsamin, stream, false /* = pointer_mode_host */, elem_count, - // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); + return DoBlasInternal( + wrap::rocblas_isamin, stream, false /* = pointer_mode_host */, elem_count, + ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, const DeviceMemory &x, int incx, DeviceMemory *result) { - return false; - //return DoBlasInternal( - // wrap::hipblasIdamin, stream, false /* = pointer_mode_host */, elem_count, - // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); + return DoBlasInternal( + wrap::rocblas_idamin, stream, false /* = pointer_mode_host */, elem_count, + ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, @@ -1035,7 +1019,7 @@ bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasIcamin, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_icamin, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } @@ -1044,7 +1028,7 @@ bool ROCMBlas::DoBlasIamin(Stream *stream, uint64 elem_count, int incx, DeviceMemory *result) { return false; //return DoBlasInternal( - // wrap::hipblasIzamin, stream, false /* = pointer_mode_host */, elem_count, + // wrap::rocblas_izamin, stream, false /* = pointer_mode_host */, elem_count, // ROCMComplex(ROCMMemory(x)), incx, ROCMMemoryMutable(result)); } @@ -1055,7 +1039,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, DeviceMemory *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasSgbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_sgbmv, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, ROCMMemory(a), lda, // ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); } @@ -1067,7 +1051,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, DeviceMemory *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasDgbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dgbmv, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(trans), m, n, kl, ku, &alpha, ROCMMemory(a), lda, // ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); } @@ -1081,7 +1065,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasCgbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_cgbmv, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(trans), m, n, kl, ku, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1096,7 +1080,7 @@ bool ROCMBlas::DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasZgbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zgbmv, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(trans), m, n, kl, ku, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1107,7 +1091,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, int lda, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { return DoBlasInternal( - wrap::hipblasSgemv, stream, true /* = pointer_mode_host */, + wrap::rocblas_sgemv, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(trans), m, n, &alpha, ROCMMemory(a), lda, ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); } @@ -1117,7 +1101,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, int lda, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { return DoBlasInternal( - wrap::hipblasDgemv, stream, true /* = pointer_mode_host */, + wrap::rocblas_dgemv, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(trans), m, n, &alpha, ROCMMemory(a), lda, ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); } @@ -1130,7 +1114,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasCgemv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_Cgemv, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(trans), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1144,7 +1128,7 @@ bool ROCMBlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasZgemv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zgemv, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(trans), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1155,7 +1139,7 @@ bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) { return DoBlasInternal( - wrap::hipblasSger, stream, true /* = pointer_mode_host */, m, n, &alpha, + wrap::rocblas_sger, stream, true /* = pointer_mode_host */, m, n, &alpha, ROCMMemory(x), incx, ROCMMemory(y), incy, ROCMMemoryMutable(a), lda); } @@ -1163,10 +1147,9 @@ bool ROCMBlas::DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, const DeviceMemory &x, int incx, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) { - return false; - //return DoBlasInternal( - // wrap::hipblasDger, stream, true /* = pointer_mode_host */, m, n, &alpha, - // ROCMMemory(x), incx, ROCMMemory(y), incy, ROCMMemoryMutable(a), lda); + return DoBlasInternal( + wrap::rocblas_dger, stream, true /* = pointer_mode_host */, m, n, &alpha, + ROCMMemory(x), incx, ROCMMemory(y), incy, ROCMMemoryMutable(a), lda); } bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, @@ -1176,7 +1159,7 @@ bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasCgerc, stream, true /* = pointer_mode_host */, m, n, + // wrap::rocblas_cgerc, stream, true /* = pointer_mode_host */, m, n, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemory(y)), incy, ROCMComplex(ROCMMemoryMutable(a)), lda); } @@ -1188,7 +1171,7 @@ bool ROCMBlas::DoBlasGerc(Stream *stream, uint64 m, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasZgerc, stream, true /* = pointer_mode_host */, m, n, + // wrap::rocblas_zgerc, stream, true /* = pointer_mode_host */, m, n, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemory(y)), incy, ROCMComplex(ROCMMemoryMutable(a)), lda); } @@ -1200,7 +1183,7 @@ bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasCgeru, stream, true /* = pointer_mode_host */, m, n, + // wrap::rocblas_cgeru, stream, true /* = pointer_mode_host */, m, n, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemory(y)), incy, ROCMComplex(ROCMMemoryMutable(a)), lda); } @@ -1212,7 +1195,7 @@ bool ROCMBlas::DoBlasGeru(Stream *stream, uint64 m, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasZgeru, stream, true /* = pointer_mode_host */, m, n, + // wrap::rocblas_zgeru, stream, true /* = pointer_mode_host */, m, n, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemory(y)), incy, ROCMComplex(ROCMMemoryMutable(a)), lda); } @@ -1225,7 +1208,7 @@ bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasChbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_chbmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, k, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1239,7 +1222,7 @@ bool ROCMBlas::DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasZhbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zhbmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, k, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1253,7 +1236,7 @@ bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasChemv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_chemv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1267,7 +1250,7 @@ bool ROCMBlas::DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasZhemv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zhemv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1279,7 +1262,7 @@ bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasCher, stream, true /* = pointer_mode_host */, + // wrap::rocblas_cher, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemoryMutable(a)), lda); } @@ -1290,7 +1273,7 @@ bool ROCMBlas::DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasZher, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zher, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(ROCMMemoryMutable(a)), lda); } @@ -1302,7 +1285,7 @@ bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasCher2, stream, true /* = pointer_mode_host */, + // wrap::rocblas_cher2, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(a)), lda); @@ -1315,7 +1298,7 @@ bool ROCMBlas::DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *a, int lda) { return false; //return DoBlasInternal( - // wrap::hipblasZher2, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zher2, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(a)), lda); @@ -1329,7 +1312,7 @@ bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasChpmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_chpmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(ap)), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1343,7 +1326,7 @@ bool ROCMBlas::DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasZhpmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zhpmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(ap)), ROCMComplex(ROCMMemory(x)), incx, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(y)), incy); @@ -1355,7 +1338,7 @@ bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *ap) { return false; //return DoBlasInternal( - // wrap::hipblasChpr, stream, true /* = pointer_mode_host */, + // wrap::rocblas_chpr, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemoryMutable(ap))); } @@ -1366,7 +1349,7 @@ bool ROCMBlas::DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *ap) { return false; //return DoBlasInternal( - // wrap::hipblasZhpr, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zhpr, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemoryMutable(ap))); } @@ -1378,7 +1361,7 @@ bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *ap) { return false; //return DoBlasInternal( - // wrap::hipblasChpr2, stream, true /* = pointer_mode_host */, + // wrap::rocblas_chpr2, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(ap))); @@ -1391,7 +1374,7 @@ bool ROCMBlas::DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, DeviceMemory> *ap) { return false; //return DoBlasInternal( - // wrap::hipblasZhpr2, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zhpr2, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(x)), incx, ROCMComplex(ROCMMemory(y)), incy, // ROCMComplex(ROCMMemoryMutable(ap))); @@ -1403,7 +1386,7 @@ bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, float beta, DeviceMemory *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasSsbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ssbmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, k, &alpha, ROCMMemory(a), lda, ROCMMemory(x), // incx, &beta, ROCMMemoryMutable(y), incy); } @@ -1414,7 +1397,7 @@ bool ROCMBlas::DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, double beta, DeviceMemory *y, int incy) { return false; //return DoBlasInternal( - // wrap::hipblasDsbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dsbmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, k, &alpha, ROCMMemory(a), lda, ROCMMemory(x), // incx, &beta, ROCMMemoryMutable(y), incy); } @@ -1424,7 +1407,7 @@ bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasSspmv, stream, + //return DoBlasInternal(wrap::rocblas_sspmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(ap), // ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); @@ -1435,7 +1418,7 @@ bool ROCMBlas::DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasDspmv, stream, + //return DoBlasInternal(wrap::rocblas_dspmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(ap), // ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); @@ -1445,7 +1428,7 @@ bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, DeviceMemory *ap) { return false; - //return DoBlasInternal(wrap::hipblasSspr, stream, + //return DoBlasInternal(wrap::rocblas_sspr, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), // incx, ROCMMemoryMutable(ap)); @@ -1455,7 +1438,7 @@ bool ROCMBlas::DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, DeviceMemory *ap) { return false; - //return DoBlasInternal(wrap::hipblasDspr, stream, + //return DoBlasInternal(wrap::rocblas_dspr, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), // incx, ROCMMemoryMutable(ap)); @@ -1466,7 +1449,7 @@ bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *ap) { return false; - //return DoBlasInternal(wrap::hipblasSspr2, stream, + //return DoBlasInternal(wrap::rocblas_sspr2, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), // incx, ROCMMemory(y), incy, ROCMMemoryMutable(ap)); @@ -1477,7 +1460,7 @@ bool ROCMBlas::DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *ap) { return false; - //return DoBlasInternal(wrap::hipblasDspr2, stream, + //return DoBlasInternal(wrap::rocblas_dspr2, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), // incx, ROCMMemory(y), incy, ROCMMemoryMutable(ap)); @@ -1488,7 +1471,7 @@ bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, float beta, DeviceMemory *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasSsymv, stream, + //return DoBlasInternal(wrap::rocblas_ssymv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(a), lda, // ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); @@ -1499,7 +1482,7 @@ bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &x, int incx, double beta, DeviceMemory *y, int incy) { return false; - //return DoBlasInternal(wrap::hipblasDsymv, stream, + //return DoBlasInternal(wrap::rocblas_dsymv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(a), lda, // ROCMMemory(x), incx, &beta, ROCMMemoryMutable(y), incy); @@ -1508,21 +1491,19 @@ bool ROCMBlas::DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, const DeviceMemory &x, int incx, DeviceMemory *a, int lda) { - return false; - //return DoBlasInternal(wrap::hipblasSsyr, stream, - // true /* = pointer_mode_host */, - // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), - // incx, ROCMMemoryMutable(a), lda); + return DoBlasInternal(wrap::rocblas_ssyr, stream, + true /* = pointer_mode_host */, + ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), + incx, ROCMMemoryMutable(a), lda); } bool ROCMBlas::DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, double alpha, const DeviceMemory &x, int incx, DeviceMemory *a, int lda) { - return false; - //return DoBlasInternal(wrap::hipblasDsyr, stream, - // true /* = pointer_mode_host */, - // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), - // incx, ROCMMemoryMutable(a), lda); + return DoBlasInternal(wrap::rocblas_dsyr, stream, + true /* = pointer_mode_host */, + ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), + incx, ROCMMemoryMutable(a), lda); } bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, @@ -1530,7 +1511,7 @@ bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) { return false; - //return DoBlasInternal(wrap::hipblasSsyr2, stream, + //return DoBlasInternal(wrap::rocblas_ssyr2, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), // incx, ROCMMemory(y), incy, ROCMMemoryMutable(a), lda); @@ -1541,7 +1522,7 @@ bool ROCMBlas::DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, const DeviceMemory &y, int incy, DeviceMemory *a, int lda) { return false; - //return DoBlasInternal(wrap::hipblasDsyr2, stream, + //return DoBlasInternal(wrap::rocblas_dsyr2, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), n, &alpha, ROCMMemory(x), // incx, ROCMMemory(y), incy, ROCMMemoryMutable(a), lda); @@ -1552,7 +1533,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasStbmv, stream, + //return DoBlasInternal(wrap::rocblas_stbmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMMemory(a), lda, @@ -1564,7 +1545,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasDtbmv, stream, + //return DoBlasInternal(wrap::rocblas_dtbmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMMemory(a), lda, @@ -1578,7 +1559,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasCtbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ctbmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMComplex(ROCMMemory(a)), lda, // ROCMComplex(ROCMMemoryMutable(x)), incx); @@ -1591,7 +1572,7 @@ bool ROCMBlas::DoBlasTbmv(Stream *stream, blas::UpperLower uplo, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasZtbmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ztbmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMComplex(ROCMMemory(a)), lda, // ROCMComplex(ROCMMemoryMutable(x)), incx); @@ -1602,7 +1583,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasStbsv, stream, + //return DoBlasInternal(wrap::rocblas_stbsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMMemory(a), lda, @@ -1614,7 +1595,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, uint64 k, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasDtbsv, stream, + //return DoBlasInternal(wrap::rocblas_dtbsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMMemory(a), lda, @@ -1628,7 +1609,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasCtbsv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ctbsv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMComplex(ROCMMemory(a)), lda, // ROCMComplex(ROCMMemoryMutable(x)), incx); @@ -1641,7 +1622,7 @@ bool ROCMBlas::DoBlasTbsv(Stream *stream, blas::UpperLower uplo, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasZtbsv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ztbsv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, k, ROCMComplex(ROCMMemory(a)), lda, // ROCMComplex(ROCMMemoryMutable(x)), incx); @@ -1653,7 +1634,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasStpmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_stpmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(ap), ROCMMemoryMutable(x), incx); } @@ -1664,7 +1645,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, DeviceMemory *x, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasDtpmv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dtpmv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(ap), ROCMMemoryMutable(x), incx); } @@ -1674,7 +1655,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasCtpmv, stream, + //return DoBlasInternal(wrap::rocblas_ctpmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(ap)), @@ -1686,7 +1667,7 @@ bool ROCMBlas::DoBlasTpmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasZtpmv, stream, + //return DoBlasInternal(wrap::rocblas_ztpmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(ap)), @@ -1699,7 +1680,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasStpsv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_stpsv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(ap), ROCMMemoryMutable(x), incx); } @@ -1710,7 +1691,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, DeviceMemory *x, int incx) { return false; //return DoBlasInternal( - // wrap::hipblasDtpsv, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dtpsv, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(ap), ROCMMemoryMutable(x), incx); } @@ -1720,7 +1701,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasCtpsv, stream, + //return DoBlasInternal(wrap::rocblas_ctpsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(ap)), @@ -1732,7 +1713,7 @@ bool ROCMBlas::DoBlasTpsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &ap, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasZtpsv, stream, + //return DoBlasInternal(wrap::rocblas_ztpsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(ap)), @@ -1744,7 +1725,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasStrmv, stream, + //return DoBlasInternal(wrap::rocblas_strmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(a), lda, @@ -1756,7 +1737,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasDtrmv, stream, + //return DoBlasInternal(wrap::rocblas_dtrmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(a), lda, @@ -1768,7 +1749,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasCtrmv, stream, + //return DoBlasInternal(wrap::rocblas_ctrmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(a)), @@ -1780,7 +1761,7 @@ bool ROCMBlas::DoBlasTrmv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasZtrmv, stream, + //return DoBlasInternal(wrap::rocblas_ztrmv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(a)), @@ -1792,7 +1773,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasStrsv, stream, + //return DoBlasInternal(wrap::rocblas_strsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(a), lda, @@ -1804,7 +1785,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory &a, int lda, DeviceMemory *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasDtrsv, stream, + //return DoBlasInternal(wrap::rocblas_dtrsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMMemory(a), lda, @@ -1816,7 +1797,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasCtrsv, stream, + //return DoBlasInternal(wrap::rocblas_ctrsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(a)), @@ -1828,7 +1809,7 @@ bool ROCMBlas::DoBlasTrsv(Stream *stream, blas::UpperLower uplo, const DeviceMemory> &a, int lda, DeviceMemory> *x, int incx) { return false; - //return DoBlasInternal(wrap::hipblasZtrsv, stream, + //return DoBlasInternal(wrap::rocblas_ztrsv, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), // ROCMBlasDiagonal(diag), n, ROCMComplex(ROCMMemory(a)), @@ -1843,7 +1824,7 @@ bool ROCMBlas::DoBlasGemm( DeviceMemory *c, int ldc) { #if ROCM_VERSION >= 7050 VLOG(1) << port::Printf( - "doing HIPBLAS SGEMM: at=%d bt=%d m=%llu n=%llu " + "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu " "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " "c=%p ldc=%d", static_cast(transa), static_cast(transb), m, n, k, alpha, @@ -1873,14 +1854,15 @@ bool ROCMBlas::DoBlasGemm( // TODO(sesse): Consider supporting the Hgemm interface, which uses half // calculations internally (faster on newer devices, such as Pascal and TX1, // but less precise). + // TODO (jmd): rocBLAS has a hgemm return false; //return DoBlasInternal( - // wrap::hipblasSgemmEx, stream, true /* = pointer_mode_host */, + // wrap::rocblas_SgemmEx, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, // ROCMMemory(a), SE_ROCM_DATA_HALF, lda, ROCMMemory(b), SE_ROCM_DATA_HALF, // ldb, &beta, ROCMMemoryMutable(c), SE_ROCM_DATA_HALF, ldc); #else - LOG(ERROR) << "fp16 sgemm is not implemented in this HIPBLAS version " + LOG(ERROR) << "fp16 sgemm is not implemented in this rocBLAS version " << "(need at least ROCM 7.5)"; return false; #endif @@ -1892,7 +1874,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, const DeviceMemory &b, int ldb, float beta, DeviceMemory *c, int ldc) { VLOG(1) << port::Printf( - "doing HIPBLAS SGEMM: at=%d bt=%d m=%llu n=%llu " + "doing rocBLAS SGEMM: at=%d bt=%d m=%llu n=%llu " "k=%llu alpha=%f a=%p lda=%d b=%p ldb=%d beta=%f " "c=%p ldc=%d", static_cast(transa), static_cast(transb), m, n, k, alpha, @@ -1920,7 +1902,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, } } return DoBlasInternal( - wrap::hipblasSgemm, stream, true /* = pointer_mode_host */, + wrap::rocblas_sgemm, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, ROCMMemory(a), lda, ROCMMemory(b), ldb, &beta, ROCMMemoryMutable(c), ldc); } @@ -1931,7 +1913,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, const DeviceMemory &b, int ldb, double beta, DeviceMemory *c, int ldc) { return DoBlasInternal( - wrap::hipblasDgemm, stream, true /* = pointer_mode_host */, + wrap::rocblas_dgemm, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, ROCMMemory(a), lda, ROCMMemory(b), ldb, &beta, ROCMMemoryMutable(c), ldc); } @@ -1945,7 +1927,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasCgemm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_cgemm, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, // ROCMComplex(ROCMMemory(b)), ldb, ROCMComplex(&beta), @@ -1961,7 +1943,7 @@ bool ROCMBlas::DoBlasGemm(Stream *stream, blas::Transpose transa, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasZgemm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zgemm, stream, true /* = pointer_mode_host */, // ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, // ROCMComplex(ROCMMemory(b)), ldb, ROCMComplex(&beta), @@ -2094,7 +2076,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithmImpl( DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { #if 0 -// ROCM < version 8 and GPUs < sm_50 don't support hipblasGemmEx. +// ROCM < version 8 and GPUs < sm_50 don't support rocblas_GemmEx. #if ROCM_VERSION < 8000 return false; #else @@ -2120,16 +2102,16 @@ bool ROCMBlas::DoBlasGemmWithAlgorithmImpl( } rocmDataType_t rocm_in_type = ROCMDataType::type; - // Since we are converting 'algorithm' to hipblasGemmAlgo_t by static_cast, + // Since we are converting 'algorithm' to rocblas_GemmAlgo_t by static_cast, // we do the following compile-time check on the default value: static_assert(blas::kDefaultGemmAlgo == HIPBLAS_GEMM_DFALT, ""); bool result = DoBlasInternalFailureOK( - wrap::hipblasGemmEx, stream, /* pointer_mode_host = */ true, + wrap::rocblas_GemmEx, stream, /* pointer_mode_host = */ true, ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, &alpha, ROCMMemory(a), rocm_in_type, lda, ROCMMemory(b), rocm_in_type, ldb, &beta, ROCMMemoryMutable(c), ROCMDataType::type, ldc, ROCMComputationType(computation_type), - static_cast(algorithm)); + static_cast(algorithm)); if (timer != nullptr && result) { // ROCMTimer will CHECK-fail if we Stop() it while the stream is in an error @@ -2150,10 +2132,10 @@ bool ROCMBlas::DoBlasGemmWithAlgorithmImpl( bool ROCMBlas::GetBlasGemmAlgorithms( std::vector *out_algorithms) { -// hipblasGemmAlgo_t (and the function that accepts this type, hipblasGemmEx) +// rocblas_GemmAlgo_t (and the function that accepts this type, rocblas_GemmEx) // were first introduced in ROCM 8. #if ROCM_VERSION >= 8000 - for (hipblasGemmAlgo_t algo : + for (rocblas_GemmAlgo_t algo : {HIPBLAS_GEMM_DFALT, HIPBLAS_GEMM_ALGO0, HIPBLAS_GEMM_ALGO1, HIPBLAS_GEMM_ALGO2, HIPBLAS_GEMM_ALGO3, HIPBLAS_GEMM_ALGO4, HIPBLAS_GEMM_ALGO5, HIPBLAS_GEMM_ALGO6, HIPBLAS_GEMM_ALGO7}) { @@ -2248,7 +2230,7 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm( template port::Status ROCMBlas::DoBlasGemmBatchedInternal( - FuncT hipblas_func, Stream *stream, blas::Transpose transa, + FuncT rocblas_func, Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, const port::ArraySlice *> &a_ptrs_to_wrappers, int lda, const port::ArraySlice *> &b_ptrs_to_wrappers, int ldb, @@ -2307,12 +2289,12 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal( assert(!(ldc < m || bsc < ldc * n)); - if (ROCMBlasTranspose(transa) == HIPBLAS_OP_N) + if (ROCMBlasTranspose(transa) == rocblas_operation_none) assert(!(lda < m || bsa < lda * k)); else assert(!(lda < k || bsa < lda * m)); - if (ROCMBlasTranspose(transb) == HIPBLAS_OP_N) + if (ROCMBlasTranspose(transb) == rocblas_operation_none) assert(!(ldb < k || bsb < ldb * n)); else assert(!(ldb < n || bsc < ldc * k)); @@ -2320,7 +2302,7 @@ port::Status ROCMBlas::DoBlasGemmBatchedInternal( if(bsa_is_constant && bsb_is_constant && bsc_is_constant) { bool ok = DoBlasInternal( - hipblas_func, stream, true /* = pointer_mode_host */, + rocblas_func, stream, true /* = pointer_mode_host */, ROCMBlasTranspose(transa), ROCMBlasTranspose(transb), m, n, k, ROCMComplex(&alpha), a_raw_ptrs[ 0 ], lda, bsa, b_raw_ptrs[ 0 ], ldb, bsb, ROCMComplex(&beta), @@ -2343,7 +2325,7 @@ bool ROCMBlas::DoBlasGemmBatched( const port::ArraySlice *> &c_array, int ldc, int batch_count, ScratchAllocator *scratch_allocator) { port::Status status = DoBlasGemmBatchedInternal( - wrap::hipblasSgemmStridedBatched, stream, transa, transb, m, n, k, alpha, + wrap::rocblas_sgemm_strided_batched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); if (!status.ok()) { @@ -2360,7 +2342,7 @@ bool ROCMBlas::DoBlasGemmBatched( double beta, const port::ArraySlice *> &c_array, int ldc, int batch_count, ScratchAllocator *scratch_allocator) { port::Status status = DoBlasGemmBatchedInternal( - wrap::hipblasDgemmStridedBatched, stream, transa, transb, m, n, k, alpha, + wrap::rocblas_dgemm_strided_batched, stream, transa, transb, m, n, k, alpha, a_array, lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); if (!status.ok()) { @@ -2380,7 +2362,7 @@ bool ROCMBlas::DoBlasGemmBatched( int ldc, int batch_count, ScratchAllocator *scratch_allocator) { return false; //port::Status status = DoBlasGemmBatchedInternal( - // wrap::hipblasCgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, + // wrap::rocblas_cgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, // lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); //if (!status.ok()) { // LOG(ERROR) << status; @@ -2399,7 +2381,7 @@ bool ROCMBlas::DoBlasGemmBatched( int ldc, int batch_count, ScratchAllocator *scratch_allocator) { return false; //port::Status status = DoBlasGemmBatchedInternal( - // wrap::hipblasZgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, + // wrap::rocblas_zgemmBatched, stream, transa, transb, m, n, k, alpha, a_array, // lda, b_array, ldb, beta, c_array, ldc, batch_count, scratch_allocator); //if (!status.ok()) { // LOG(ERROR) << status; @@ -2416,7 +2398,7 @@ bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasChemm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_chemm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(b)), ldb, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(c)), ldc); @@ -2431,7 +2413,7 @@ bool ROCMBlas::DoBlasHemm(Stream *stream, blas::Side side, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasZhemm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zhemm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(b)), ldb, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(c)), ldc); @@ -2444,7 +2426,7 @@ bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, float beta, DeviceMemory> *c, int ldc) { return false; - //return DoBlasInternal(wrap::hipblasCherk, stream, + //return DoBlasInternal(wrap::rocblas_cherk, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, // k, ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, @@ -2458,7 +2440,7 @@ bool ROCMBlas::DoBlasHerk(Stream *stream, blas::UpperLower uplo, double beta, DeviceMemory> *c, int ldc) { return false; - //return DoBlasInternal(wrap::hipblasZherk, stream, + //return DoBlasInternal(wrap::rocblas_zherk, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, // k, ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, @@ -2473,7 +2455,7 @@ bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, float beta, DeviceMemory> *c, int ldc) { return false; - //return DoBlasInternal(wrap::hipblasCher2k, stream, + //return DoBlasInternal(wrap::rocblas_cher2k, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, // k, ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, @@ -2489,7 +2471,7 @@ bool ROCMBlas::DoBlasHer2k(Stream *stream, blas::UpperLower uplo, double beta, DeviceMemory> *c, int ldc) { return false; - //return DoBlasInternal(wrap::hipblasZher2k, stream, + //return DoBlasInternal(wrap::rocblas_zher2k, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, // k, ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, @@ -2504,7 +2486,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, DeviceMemory *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasSsymm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ssymm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, ROCMMemory(a), // lda, ROCMMemory(b), ldb, &beta, ROCMMemoryMutable(c), ldc); } @@ -2516,7 +2498,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, DeviceMemory *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasDsymm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dsymm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, &alpha, ROCMMemory(a), // lda, ROCMMemory(b), ldb, &beta, ROCMMemoryMutable(c), ldc); } @@ -2530,7 +2512,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasCsymm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_csymm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(b)), ldb, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(c)), ldc); @@ -2545,7 +2527,7 @@ bool ROCMBlas::DoBlasSymm(Stream *stream, blas::Side side, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasZsymm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zsymm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemory(b)), ldb, // ROCMComplex(&beta), ROCMComplex(ROCMMemoryMutable(c)), ldc); @@ -2557,7 +2539,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, float beta, DeviceMemory *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasSsyrk, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ssyrk, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, // ROCMMemory(a), lda, &beta, ROCMMemoryMutable(c), ldc); } @@ -2568,7 +2550,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, double beta, DeviceMemory *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasDsyrk, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dsyrk, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, // ROCMMemory(a), lda, &beta, ROCMMemoryMutable(c), ldc); } @@ -2581,7 +2563,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasCsyrk, stream, true /* = pointer_mode_host */, + // wrap::rocblas_csyrk, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(&beta), // ROCMComplex(ROCMMemoryMutable(c)), ldc); @@ -2595,7 +2577,7 @@ bool ROCMBlas::DoBlasSyrk(Stream *stream, blas::UpperLower uplo, DeviceMemory> *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasZsyrk, stream, true /* = pointer_mode_host */, + // wrap::rocblas_zsyrk, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, // ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(&beta), // ROCMComplex(ROCMMemoryMutable(c)), ldc); @@ -2608,7 +2590,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, DeviceMemory *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasSsyr2k, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ssyr2k, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, // ROCMMemory(a), lda, ROCMMemory(b), ldb, &beta, ROCMMemoryMutable(c), ldc); } @@ -2620,7 +2602,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, DeviceMemory *c, int ldc) { return false; //return DoBlasInternal( - // wrap::hipblasDsyr2k, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dsyr2k, stream, true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, k, &alpha, // ROCMMemory(a), lda, ROCMMemory(b), ldb, &beta, ROCMMemoryMutable(c), ldc); } @@ -2633,7 +2615,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, std::complex beta, DeviceMemory> *c, int ldc) { return false; - //return DoBlasInternal(wrap::hipblasCsyr2k, stream, + //return DoBlasInternal(wrap::rocblas_csyr2k, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, // k, ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, @@ -2649,7 +2631,7 @@ bool ROCMBlas::DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, std::complex beta, DeviceMemory> *c, int ldc) { return false; - //return DoBlasInternal(wrap::hipblasZsyr2k, stream, + //return DoBlasInternal(wrap::rocblas_zsyr2k, stream, // true /* = pointer_mode_host */, // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(trans), n, // k, ROCMComplex(&alpha), ROCMComplex(ROCMMemory(a)), lda, @@ -2664,7 +2646,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, DeviceMemory *b, int ldb) { return false; //return DoBlasInternal( - // wrap::hipblasStrmm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_strmm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, &alpha, ROCMMemory(a), lda, // ROCMMemoryMutable(b), ldb, ROCMMemoryMutable(b), ldb); @@ -2677,7 +2659,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, DeviceMemory *b, int ldb) { return false; //return DoBlasInternal( - // wrap::hipblasDtrmm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_dtrmm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, &alpha, ROCMMemory(a), lda, // ROCMMemoryMutable(b), ldb, ROCMMemoryMutable(b), ldb); @@ -2691,7 +2673,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, DeviceMemory> *b, int ldb) { return false; //return DoBlasInternal( - // wrap::hipblasCtrmm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ctrmm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemoryMutable(b)), ldb, @@ -2706,7 +2688,7 @@ bool ROCMBlas::DoBlasTrmm(Stream *stream, blas::Side side, DeviceMemory> *b, int ldb) { return false; //return DoBlasInternal( - // wrap::hipblasZtrmm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ztrmm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemoryMutable(b)), ldb, @@ -2719,7 +2701,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { return false; - //return DoBlasInternal(wrap::hipblasStrsm, stream, + //return DoBlasInternal(wrap::rocblas_strsm, stream, // true /* = pointer_mode_host */, ROCMBlasSide(side), // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, &alpha, ROCMMemory(a), @@ -2732,7 +2714,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { return false; - //return DoBlasInternal(wrap::hipblasDtrsm, stream, + //return DoBlasInternal(wrap::rocblas_dtrsm, stream, // true /* = pointer_mode_host */, ROCMBlasSide(side), // ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, &alpha, ROCMMemory(a), @@ -2747,7 +2729,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, DeviceMemory> *b, int ldb) { return false; //return DoBlasInternal( - // wrap::hipblasCtrsm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ctrsm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemoryMutable(b)), ldb); @@ -2761,7 +2743,7 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, DeviceMemory> *b, int ldb) { return false; //return DoBlasInternal( - // wrap::hipblasZtrsm, stream, true /* = pointer_mode_host */, + // wrap::rocblas_ztrsm, stream, true /* = pointer_mode_host */, // ROCMBlasSide(side), ROCMBlasUpperLower(uplo), ROCMBlasTranspose(transa), // ROCMBlasDiagonal(diag), m, n, ROCMComplex(&alpha), // ROCMComplex(ROCMMemory(a)), lda, ROCMComplex(ROCMMemoryMutable(b)), ldb); @@ -2771,18 +2753,18 @@ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, namespace gpu = ::stream_executor; -void initialize_hipblas() { +void initialize_rocblas() { gpu::port::Status status = gpu::PluginRegistry::Instance() ->RegisterFactory( - gpu::rocm::kROCmPlatformId, gpu::rocm::kHipBlasPlugin, "HIPBLAS", + gpu::rocm::kROCmPlatformId, gpu::rocm::kRocBlasPlugin, "rocBLAS", [](gpu::internal::StreamExecutorInterface *parent) -> gpu::blas::BlasSupport * { gpu::rocm::ROCMExecutor *rocm_executor = dynamic_cast(parent); if (rocm_executor == nullptr) { LOG(ERROR) - << "Attempting to initialize an instance of the HIPBLAS " + << "Attempting to initialize an instance of the rocBLAS " << "support library with a non-ROCM StreamExecutor"; return nullptr; } @@ -2798,16 +2780,16 @@ void initialize_hipblas() { }); if (!status.ok()) { - LOG(ERROR) << "Unable to register HIPBLAS factory: " + LOG(ERROR) << "Unable to register rocBLAS factory: " << status.error_message(); } gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::rocm::kROCmPlatformId, gpu::PluginKind::kBlas, - gpu::rocm::kHipBlasPlugin); + gpu::rocm::kRocBlasPlugin); } } // namespace stream_executor -REGISTER_MODULE_INITIALIZER(register_hipblas, - { stream_executor::initialize_hipblas(); }); +REGISTER_MODULE_INITIALIZER(register_rocblas, + { stream_executor::initialize_rocblas(); }); diff --git a/tensorflow/stream_executor/rocm/rocm_blas.h b/tensorflow/stream_executor/rocm/rocm_blas.h index 9bce5a2b65ba6f..ad88c135e0950e 100644 --- a/tensorflow/stream_executor/rocm/rocm_blas.h +++ b/tensorflow/stream_executor/rocm/rocm_blas.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// ROCM-specific support for BLAS functionality -- this wraps the HIPBLAS library +// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS library // capabilities, and is only included into ROCM implementation code -- it will // not introduce rocm headers into other code. @@ -33,18 +33,18 @@ class Stream; namespace rocm { -// Opaque and unique identifier for the HIPBLAS plugin. -extern const PluginId kHipBlasPlugin; +// Opaque and unique identifier for the rocBLAS plugin. +extern const PluginId kRocBlasPlugin; class ROCMExecutor; -// BLAS plugin for ROCM platform via HIPBLAS library. +// BLAS plugin for ROCM platform via rocBLAS library. // // This satisfies the platform-agnostic BlasSupport interface. // -// Note that the HIPBLAS handle that this encapsulates is implicitly tied to the +// Note that the rocBLAS handle that this encapsulates is implicitly tied to the // context (and, as a result, the device) that the parent ROCMExecutor is tied -// to. This simply happens as an artifact of creating the HIPBLAS handle when a +// to. This simply happens as an artifact of creating the rocBLAS handle when a // ROCM context is active. // // Thread-safe post-initialization. @@ -52,49 +52,49 @@ class ROCMBlas : public blas::BlasSupport { public: explicit ROCMBlas(ROCMExecutor *parent); - // Allocates a HIPBLAS handle. + // Allocates a rocBLAS handle. bool Init(); - // Releases the HIPBLAS handle, if present. + // Releases the rocBLAS handle, if present. ~ROCMBlas() override; TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES private: - // Tells HIPBLAS to enqueue the BLAS operation onto a particular Stream. + // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream. // - // HIPBLAS is stateful, and only be associated with one stream (in order to + // rocBLAS is stateful, and only be associated with one stream (in order to // enqueue dispatch) at a given time. As a result, this generally must be - // invoked before calling into HIPBLAS. + // invoked before calling into rocBLAS. bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_); - // A helper function that calls the real HIPBLAS function together with error + // A helper function that calls the real rocBLAS function together with error // handling. // - // hipblas_func: HIPBLAS function pointer. - // hipblas_name: HIPBLAS function name. + // rocblas_func: rocBLAS function pointer. + // rocblas_name: rocBLAS function name. // stream: Stream to enqueue the BLAS operation onto. // pointer_mode_host: Indicate if the pointer to a scalar value is from host // (true) or device (false). - // err_on_failure: Whether to print an error if the hipblas function fails. - // args: Arguments of HIPBLAS function. + // err_on_failure: Whether to print an error if the rocBLAS function fails. + // args: Arguments of rocBLAS function. template - bool DoBlasInternalImpl(FuncT hipblas_func, Stream *stream, + bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, bool pointer_mode_host, bool err_on_failure, Args... args); // Convenience functions that call DoBlasInternalImpl with different values // for err_on_failure. template - bool DoBlasInternal(FuncT hipblas_func, Stream *stream, bool pointer_mode_host, + bool DoBlasInternal(FuncT rocblas_func, Stream *stream, bool pointer_mode_host, Args... args) { - return DoBlasInternalImpl(hipblas_func, stream, pointer_mode_host, + return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, /*err_on_failure=*/true, args...); } template - bool DoBlasInternalFailureOK(FuncT hipblas_func, Stream *stream, + bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream, bool pointer_mode_host, Args... args) { - return DoBlasInternalImpl(hipblas_func, stream, pointer_mode_host, + return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host, /*err_on_failure=*/false, args...); } @@ -102,7 +102,7 @@ class ROCMBlas : public blas::BlasSupport { // types. template port::Status DoBlasGemmBatchedInternal( - FuncT hipblas_func, Stream *stream, blas::Transpose transa, + FuncT rocblas_func, Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha, const port::ArraySlice *> &a_array, int lda, const port::ArraySlice *> &b_array, int ldb, T beta, @@ -113,7 +113,7 @@ class ROCMBlas : public blas::BlasSupport { // // We take alpha and beta by const reference because T might be Eigen::half, // and we want to avoid pulling in a dependency on Eigen. When we pass the - // references to hipblas, we essentially reinterpret_cast to __half, which is + // references to rocBLAS, we essentially reinterpret_cast to __half, which is // safe because Eigen::half inherits from __half. template bool DoBlasGemmWithAlgorithmImpl( @@ -141,15 +141,15 @@ class ROCMBlas : public blas::BlasSupport { const T &beta, DeviceMemory *y, int incy, blas::ProfileResult *output_profile_result); - // mutex that guards the HIPBLAS handle for this device. + // mutex that guards the rocBLAS handle for this device. mutex mu_; // ROCMExecutor which instantiated this ROCMBlas. // Immutable post-initialization. ROCMExecutor *parent_; - // HIPBLAS library handle on the device. - hipblasHandle_t blas_ GUARDED_BY(mu_); + // rocBLAS library handle on the device. + rocblas_handle blas_ GUARDED_BY(mu_); SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas); }; diff --git a/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl b/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl index cb3844a4f9869a..2a3a7a56fe2d3d 100644 --- a/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl +++ b/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl @@ -132,7 +132,7 @@ toolchain { cxx_builtin_include_directory: "/opt/rocm/hsa/include" cxx_builtin_include_directory: "/opt/rocm/include/hip" cxx_builtin_include_directory: "/opt/rocm/include/hip/hcc_detail" - cxx_builtin_include_directory: "/opt/rocm/hipblas/include" + cxx_builtin_include_directory: "/opt/rocm/rocblas/include" cxx_builtin_include_directory: "/opt/rocm/rocfft/include" cxx_builtin_include_directory: "/opt/rocm/hiprand/include" cxx_builtin_include_directory: "/opt/rocm/hcc/include" @@ -241,7 +241,7 @@ toolchain { cxx_builtin_include_directory: "/opt/rocm/hsa/include" cxx_builtin_include_directory: "/opt/rocm/include/hip" cxx_builtin_include_directory: "/opt/rocm/include/hip/hcc_detail" - cxx_builtin_include_directory: "/opt/rocm/hipblas/include" + cxx_builtin_include_directory: "/opt/rocm/rocblas/include" cxx_builtin_include_directory: "/opt/rocm/rocfft/include" cxx_builtin_include_directory: "/opt/rocm/hiprand/include" cxx_builtin_include_directory: "/opt/rocm/hcc/include" diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index 0bb26e32f49dfc..c2be72a1822544 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -64,9 +64,9 @@ cc_library( ) cc_library( - name = "hipblas", - srcs = ["rocm/lib/%{hipblas_lib}"], - data = ["rocm/lib/%{hipblas_lib}"], + name = "rocblas", + srcs = ["rocm/lib/%{rocblas_lib}"], + data = ["rocm/lib/%{rocblas_lib}"], includes = [ ".", "rocm/include", @@ -118,7 +118,7 @@ cc_library( deps = [ ":rocm_headers", ":rocmrt", - ":hipblas", + ":rocblas", ":rocfft", ":hiprand", ":miopen", diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 3f479e6265a06f..7c23418e9b621a 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -128,8 +128,8 @@ def _host_compiler_includes(repository_ctx, cc): # Add rocfft headers inc_dirs.append("/opt/rocm/rocfft/include") - # Add hipblas headers - inc_dirs.append("/opt/rocm/hipblas/include") + # Add rocBLAS headers + inc_dirs.append("/opt/rocm/rocblas/include") # Add MIOpen headers inc_dirs.append("/opt/rocm/miopen/include") @@ -303,8 +303,8 @@ def _find_libs(repository_ctx, rocm_config): return { "hip": _find_rocm_lib( "hip_hcc", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path), - "hipblas": _find_rocm_lib( - "hipblas", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path), + "rocblas": _find_rocm_lib( + "rocblas", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path), "rocfft": _find_rocm_lib( "rocfft", repository_ctx, cpu_value, rocm_config.rocm_toolkit_path + "/rocfft"), "hiprand": _find_rocm_lib( @@ -530,7 +530,7 @@ def _create_local_rocm_repository(repository_ctx): genrules.append(_symlink_genrule_for_dir(repository_ctx, rocm_toolkit_path + "/rocfft/include", "rocm/include/rocfft", "rocfft-include")) genrules.append(_symlink_genrule_for_dir(repository_ctx, - rocm_toolkit_path + "/hipblas/include", "rocm/include/hipblas", "hipblas-include")) + rocm_toolkit_path + "/rocblas/include", "rocm/include/rocblas", "rocblas-include")) genrules.append(_symlink_genrule_for_dir(repository_ctx, rocm_toolkit_path + "/miopen/include", "rocm/include/miopen", "miopen-include")) @@ -559,14 +559,14 @@ def _create_local_rocm_repository(repository_ctx): "%{rocmrt_static_lib}": rocm_libs["hip"].file_name, "%{rocmrt_static_linkopt}": '', "%{rocmrt_lib}": rocm_libs["hip"].file_name, - "%{hipblas_lib}": rocm_libs["hipblas"].file_name, + "%{rocblas_lib}": rocm_libs["rocblas"].file_name, "%{rocfft_lib}": rocm_libs["rocfft"].file_name, "%{hiprand_lib}": rocm_libs["hiprand"].file_name, "%{miopen_lib}": rocm_libs["miopen"].file_name, "%{rocm_include_genrules}": "\n".join(genrules), "%{rocm_headers}": ('":rocm-include",\n' + '":rocfft-include",\n' + - '":hipblas-include",\n' + + '":rocblas-include",\n' + '":miopen-include",'), }) # Set up crosstool/