Skip to content

Commit

Permalink
Separating cublas/cusolver macros from wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Feb 22, 2022
1 parent 44e7c79 commit 3c4c1a9
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 185 deletions.
116 changes: 116 additions & 0 deletions cpp/include/raft/linalg/cublas_macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/error.hpp>
#include <cublas_v2.h>

///@todo: enable this once we have logger enabled
//#include <cuml/common/logger.hpp>

#include <cstdint>

#define _CUBLAS_ERR_TO_STR(err) \
case err: return #err

namespace raft {

/**
* @brief Exception thrown when a cuBLAS error is encountered.
*/
struct cublas_error : public raft::exception {
explicit cublas_error(char const* const message) : raft::exception(message) {}
explicit cublas_error(std::string const& message) : raft::exception(message) {}
};

namespace linalg {
namespace detail {

inline const char* cublas_error_to_string(cublasStatus_t err)
{
switch (err) {
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_SUCCESS);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ALLOC_FAILED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INVALID_VALUE);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_MAPPING_ERROR);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_LICENSE_ERROR);
default: return "CUBLAS_STATUS_UNKNOWN";
};
}

} // namespace detail
} // namespace linalg
} // namespace raft

#undef _CUBLAS_ERR_TO_STR

/**
* @brief Error checking macro for cuBLAS runtime API functions.
*
* Invokes a cuBLAS runtime API function call, if the call does not return
* CUBLAS_STATUS_SUCCESS, throws an exception detailing the cuBLAS error that occurred
*/
#define RAFT_CUBLAS_TRY(call) \
do { \
cublasStatus_t const status = (call); \
if (CUBLAS_STATUS_SUCCESS != status) { \
std::string msg{}; \
SET_ERROR_MSG(msg, \
"cuBLAS error encountered at: ", \
"call='%s', Reason=%d:%s", \
#call, \
status, \
raft::linalg::detail::cublas_error_to_string(status)); \
throw raft::cublas_error(msg); \
} \
} while (0)

// FIXME: Remove after consumers rename
#ifndef CUBLAS_TRY
#define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call)
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUBLAS_TRY_NO_THROW(call) \
do { \
cublasStatus_t const status = call; \
if (CUBLAS_STATUS_SUCCESS != status) { \
printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \
#call, \
__FILE__, \
__LINE__, \
raft::linalg::detail::cublas_error_to_string(status)); \
} \
} while (0)

/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK
#define CUBLAS_CHECK(call) CUBLAS_TRY(call)
#endif

/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK_NO_THROW
#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call)
#endif
112 changes: 112 additions & 0 deletions cpp/include/raft/linalg/cusolver_macros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cusolverDn.h>
#include <cusolverSp.h>
///@todo: enable this once logging is enabled
//#include <cuml/common/logger.hpp>
#include <raft/cudart_utils.h>
#include <type_traits>

#define _CUSOLVER_ERR_TO_STR(err) \
case err: return #err;

namespace raft {

/**
* @brief Exception thrown when a cuSOLVER error is encountered.
*/
struct cusolver_error : public raft::exception {
explicit cusolver_error(char const* const message) : raft::exception(message) {}
explicit cusolver_error(std::string const& message) : raft::exception(message) {}
};

namespace linalg {

inline const char* cusolver_error_to_string(cusolverStatus_t err)
{
switch (err) {
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_SUCCESS);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_INITIALIZED);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ALLOC_FAILED);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INVALID_VALUE);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ARCH_MISMATCH);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_EXECUTION_FAILED);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INTERNAL_ERROR);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ZERO_PIVOT);
_CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_SUPPORTED);
default: return "CUSOLVER_STATUS_UNKNOWN";
};
}

} // namespace linalg
} // namespace raft

#undef _CUSOLVER_ERR_TO_STR

/**
* @brief Error checking macro for cuSOLVER runtime API functions.
*
* Invokes a cuSOLVER runtime API function call, if the call does not return
* CUSolver_STATUS_SUCCESS, throws an exception detailing the cuSOLVER error that occurred
*/
#define RAFT_CUSOLVER_TRY(call) \
do { \
cusolverStatus_t const status = (call); \
if (CUSOLVER_STATUS_SUCCESS != status) { \
std::string msg{}; \
SET_ERROR_MSG(msg, \
"cuSOLVER error encountered at: ", \
"call='%s', Reason=%d:%s", \
#call, \
status, \
raft::linalg::detail::cusolver_error_to_string(status)); \
throw raft::cusolver_error(msg); \
} \
} while (0)

// FIXME: remove after consumer rename
#ifndef CUSOLVER_TRY
#define CUSOLVER_TRY(call) RAFT_CUSOLVER_TRY(call)
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUSOLVER_TRY_NO_THROW(call) \
do { \
cusolverStatus_t const status = call; \
if (CUSOLVER_STATUS_SUCCESS != status) { \
printf("CUSOLVER call='%s' at file=%s line=%d failed with %s\n", \
#call, \
__FILE__, \
__LINE__, \
raft::linalg::detail::cusolver_error_to_string(status)); \
} \
} while (0)

// FIXME: remove after cuml rename
#ifndef CUSOLVER_CHECK
#define CUSOLVER_CHECK(call) CUSOLVER_TRY(call)
#endif

#ifndef CUSOLVER_CHECK_NO_THROW
#define CUSOLVER_CHECK_NO_THROW(call) CUSOLVER_TRY_NO_THROW(call)
#endif
96 changes: 1 addition & 95 deletions cpp/include/raft/linalg/detail/cublas_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,103 +17,9 @@
#pragma once

#include <raft/error.hpp>

#include <raft/linalg/cublas_macros.h>
#include <cublas_v2.h>
///@todo: enable this once we have logger enabled
//#include <cuml/common/logger.hpp>

#include <cstdint>

#define _CUBLAS_ERR_TO_STR(err) \
case err: return #err

namespace raft {

/**
* @brief Exception thrown when a cuBLAS error is encountered.
*/
struct cublas_error : public raft::exception {
explicit cublas_error(char const* const message) : raft::exception(message) {}
explicit cublas_error(std::string const& message) : raft::exception(message) {}
};

namespace linalg {
namespace detail {

inline const char* cublas_error_to_string(cublasStatus_t err)
{
switch (err) {
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_SUCCESS);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ALLOC_FAILED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INVALID_VALUE);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_MAPPING_ERROR);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED);
_CUBLAS_ERR_TO_STR(CUBLAS_STATUS_LICENSE_ERROR);
default: return "CUBLAS_STATUS_UNKNOWN";
};
}

} // namespace detail
} // namespace linalg
} // namespace raft

#undef _CUBLAS_ERR_TO_STR

/**
* @brief Error checking macro for cuBLAS runtime API functions.
*
* Invokes a cuBLAS runtime API function call, if the call does not return
* CUBLAS_STATUS_SUCCESS, throws an exception detailing the cuBLAS error that occurred
*/
#define RAFT_CUBLAS_TRY(call) \
do { \
cublasStatus_t const status = (call); \
if (CUBLAS_STATUS_SUCCESS != status) { \
std::string msg{}; \
SET_ERROR_MSG(msg, \
"cuBLAS error encountered at: ", \
"call='%s', Reason=%d:%s", \
#call, \
status, \
raft::linalg::detail::cublas_error_to_string(status)); \
throw raft::cublas_error(msg); \
} \
} while (0)

// FIXME: Remove after consumers rename
#ifndef CUBLAS_TRY
#define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call)
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUBLAS_TRY_NO_THROW(call) \
do { \
cublasStatus_t const status = call; \
if (CUBLAS_STATUS_SUCCESS != status) { \
printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \
#call, \
__FILE__, \
__LINE__, \
raft::linalg::detail::cublas_error_to_string(status)); \
} \
} while (0)

/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK
#define CUBLAS_CHECK(call) CUBLAS_TRY(call)
#endif

/** FIXME: remove after cuml rename */
#ifndef CUBLAS_CHECK_NO_THROW
#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call)
#endif

namespace raft {
namespace linalg {
Expand Down
Loading

0 comments on commit 3c4c1a9

Please sign in to comment.