Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IMP] move core CUDA RT macros to cuda_rt_essentials.hpp #1584

Merged
merged 2 commits into from
Jun 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions cpp/include/raft/util/cuda_rt_essentials.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,38 @@ struct cuda_error : public raft::exception {
throw raft::cuda_error(msg); \
} \
} while (0)

/**
* @brief Debug macro to check for CUDA errors
*
* In a non-release build, this macro will synchronize the specified stream
* before error checking. In both release and non-release builds, this macro
* checks for any pending CUDA errors from previous calls. If an error is
* reported, an exception is thrown detailing the CUDA error that occurred.
*
* The intent of this macro is to provide a mechanism for synchronous and
* deterministic execution for debugging asynchronous CUDA execution. It should
* be used after any asynchronous CUDA call, e.g., cudaMemcpyAsync, or an
* asynchronous kernel launch.
*/
#ifndef NDEBUG
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
#else
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaPeekAtLastError());
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUDA_TRY_NO_THROW(call) \
do { \
cudaError_t const status = call; \
if (cudaSuccess != status) { \
printf("CUDA call='%s' at file=%s line=%d failed with %s\n", \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One change request related to the header being self-standing. Can you add an include of cstdio to the top of cuda_rt_essentials.hpp?

This should have been done before, but probably escaped review.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, nothing was directly using cstdio in this file and cstdio was only included transitively through core/error.hpp for things like SET_ERROR_MSG.

So all was actually good before (IMO), and I fixed the include here now since we now use cstdio directly in this file.

Thanks for pointing it out !

#call, \
__FILE__, \
__LINE__, \
cudaGetErrorString(status)); \
} \
} while (0)
35 changes: 0 additions & 35 deletions cpp/include/raft/util/cudart_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,6 @@
#include <mutex>
#include <string>

/**
* @brief Debug macro to check for CUDA errors
*
* In a non-release build, this macro will synchronize the specified stream
* before error checking. In both release and non-release builds, this macro
* checks for any pending CUDA errors from previous calls. If an error is
* reported, an exception is thrown detailing the CUDA error that occurred.
*
* The intent of this macro is to provide a mechanism for synchronous and
* deterministic execution for debugging asynchronous CUDA execution. It should
* be used after any asynchronous CUDA call, e.g., cudaMemcpyAsync, or an
* asynchronous kernel launch.
*/
#ifndef NDEBUG
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
#else
#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaPeekAtLastError());
#endif

// /**
// * @brief check for cuda runtime API errors but log error instead of raising
// * exception.
// */
#define RAFT_CUDA_TRY_NO_THROW(call) \
do { \
cudaError_t const status = call; \
if (cudaSuccess != status) { \
printf("CUDA call='%s' at file=%s line=%d failed with %s\n", \
#call, \
__FILE__, \
__LINE__, \
cudaGetErrorString(status)); \
} \
} while (0)

namespace raft {

/** Helper method to get to know warp size in device code */
Expand Down