Skip to content

Commit

Permalink
Review update: update documentation and formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Sep 18, 2019
1 parent 01eadd0 commit 3517023
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 20 deletions.
14 changes: 13 additions & 1 deletion cuda/base/device_guard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_CUDA_BASE_DEVICE_GUARD_HPP_


#include <exception>


#include <cuda_runtime.h>


Expand All @@ -43,6 +46,14 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
namespace gko {


/**
* This class defines a device guard for the cuda functions and the cuda module.
* The guard is used to make sure that the device code is run on the correct
* cuda device, when run with multiple devices. The class records the current
* device id and uses `cudaSetDevice` to set the device id to the one being
* passed in. After the scope has been exited, the destructor sets the device_id
* back to the one before entering the scope.
*/
class device_guard {
public:
device_guard(int device_id)
Expand Down Expand Up @@ -76,4 +87,5 @@ class device_guard {

} // namespace gko

#endif

#endif // GKO_CUDA_BASE_DEVICE_GUARD_HPP_
22 changes: 21 additions & 1 deletion cuda/base/pointer_mode_guard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define GKO_CUDA_BASE_POINTER_MODE_GUARD_HPP_


#include <exception>


#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cusparse.h>
Expand All @@ -48,6 +51,14 @@ namespace cuda {
namespace cublas {


/**
* This class defines a pointer mode guard for the cuda functions and the cuda
* module. The guard is used to make sure that the correct pointer mode has been
* set when using scalars for the cublas functions. The class records the
* current handle and sets the pointer mode to host for the current scope. After
* the scope has been exited, the destructor sets the pointer mode back to
* device.
*/
class pointer_mode_guard {
public:
pointer_mode_guard(cublasHandle_t &handle)
Expand Down Expand Up @@ -87,6 +98,14 @@ class pointer_mode_guard {
namespace cusparse {


/**
* This class defines a pointer mode guard for the cuda functions and the cuda
* module. The guard is used to make sure that the correct pointer mode has been
* set when using scalars for the cusparse functions. The class records the
* current handle and sets the pointer mode to host for the current scope. After
* the scope has been exited, the destructor sets the pointer mode back to
* device.
*/
class pointer_mode_guard {
public:
pointer_mode_guard(cusparseHandle_t &handle)
Expand Down Expand Up @@ -125,4 +144,5 @@ class pointer_mode_guard {
} // namespace kernels
} // namespace gko

#endif

#endif // GKO_CUDA_BASE_POINTER_MODE_GUARD_HPP_
4 changes: 2 additions & 2 deletions cuda/solver/common_trs_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ void should_perform_transpose_kernel(std::shared_ptr<const CudaExecutor> exec,
void init_struct_kernel(std::shared_ptr<const CudaExecutor> exec,
std::shared_ptr<solver::SolveStruct> &solve_struct)
{
solve_struct =
std::shared_ptr<solver::SolveStruct>(new solver::SolveStruct());
solve_struct = std::make_shared<solver::SolveStruct>();
}


Expand Down Expand Up @@ -192,6 +191,7 @@ void solve_kernel(std::shared_ptr<const CudaExecutor> exec,
solve_struct->policy, solve_struct->factor_work_vec);
}


#elif (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))


Expand Down
8 changes: 0 additions & 8 deletions cuda/test/solver/lower_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,9 @@ TEST_F(LowerTrs, CudaLowerTrsFlagCheckIsCorrect)
{
bool trans_flag = true;
bool expected_flag = false;


#if (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))


expected_flag = true;


#endif


gko::kernels::cuda::lower_trs::should_perform_transpose(cuda, trans_flag);

ASSERT_EQ(expected_flag, trans_flag);
Expand Down
8 changes: 0 additions & 8 deletions cuda/test/solver/upper_trs_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,9 @@ TEST_F(UpperTrs, CudaUpperTrsFlagCheckIsCorrect)
bool trans_flag = true;
bool expected_flag = false;


#if (defined(CUDA_VERSION) && (CUDA_VERSION < 9020))


expected_flag = true;


#endif


gko::kernels::cuda::upper_trs::should_perform_transpose(cuda, trans_flag);

ASSERT_EQ(expected_flag, trans_flag);
Expand All @@ -162,7 +155,6 @@ TEST_F(UpperTrs, CudaSingleRhsApplyIsEquivalentToRef)
TEST_F(UpperTrs, CudaMultipleRhsApplyIsEquivalentToRef)
{
initialize_data(50, 3);

auto upper_trs_factory =
gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(ref);
auto d_upper_trs_factory =
Expand Down

0 comments on commit 3517023

Please sign in to comment.