Skip to content

Commit

Permalink
Debug commit, circ dependency.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 29, 2019
1 parent 949118a commit 7992508
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 21 deletions.
7 changes: 3 additions & 4 deletions cuda/base/cusparse_bindings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,7 @@ GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented);
size_type n, const ValueType *one, const cusparseMatDescr_t descr, \
const ValueType *csrVal, const int32 *csrRowPtr, \
const int32 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \
const ValueType *rhs, int32 rhs_stride, ValueType *sol, \
int32 sol_stride) \
ValueType *rhs, int32 rhs_stride, ValueType *sol, int32 sol_stride) \
{ \
GKO_ASSERT_NO_CUSPARSE_ERRORS( \
CusparseName(handle, trans, m, n, as_culibs_type(one), descr, \
Expand All @@ -711,8 +710,8 @@ GKO_BIND_CUSPARSE64_CSRSM_ANALYSIS(ValueType, detail::not_implemented);
size_type n, const ValueType *one, const cusparseMatDescr_t descr, \
const ValueType *csrVal, const int64 *csrRowPtr, \
const int64 *csrColInd, cusparseSolveAnalysisInfo_t factor_info, \
const ValueType *rhs, int64 rhs_stride, ValueType *sol, \
int64 sol_stride) GKO_NOT_IMPLEMENTED; \
ValueType *rhs, int64 rhs_stride, ValueType *sol, int64 sol_stride) \
GKO_NOT_IMPLEMENTED; \
static_assert(true, \
"This assert is used to counter the false positive extra " \
"semi-colon warnings")
Expand Down
34 changes: 17 additions & 17 deletions cuda/solver/lower_trs_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -204,29 +204,29 @@ void solve(std::shared_ptr<const CudaExecutor> exec,
GKO_ASSERT_NO_CUSPARSE_ERRORS(
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_HOST));
if (b->get_stride() == 1) {
auto temp_b = const_cast<ValueType *>(b->get_const_values());
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0],
b->get_stride(), &one, cusp_csrsm_data.factor_descr,
matrix->get_const_values(), matrix->get_const_row_ptrs(),
matrix->get_const_col_idxs(), cusp_csrsm_data.solve_info,
b->get_const_values(), b->get_size()[0], x->get_values(),
x->get_size()[0]);
temp_b, b->get_size()[0], x->get_values(), x->get_size()[0]);
} else {
auto t_b = vec::create(exec);
t_b->copy_from(static_cast<const matrix::Dense<ValueType> *>(
b->transpose().get()));
auto t_x = vec::create(exec);
t_x->copy_from(
static_cast<matrix::Dense<ValueType> *>(x->transpose().get()));
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE, matrix->get_size()[0],
t_b->get_size()[0], &one, cusp_csrsm_data.factor_descr,
matrix->get_const_values(), matrix->get_const_row_ptrs(),
matrix->get_const_col_idxs(), cusp_csrsm_data.solve_info,
t_b->get_const_values(), t_b->get_size()[1], t_x->get_values(),
t_x->get_size()[1]);
x->copy_from(static_cast<matrix::Dense<ValueType> *>(
t_x->transpose().get()));
gko::size_type shift = 0;
auto temp_b = const_cast<ValueType *>(b->get_const_values());
auto temp_x = x->get_values();
for (gko::size_type nrhs = 0; nrhs < b->get_stride(); ++nrhs) {
auto temp_b2 = temp_b + shift;
auto temp_x2 = temp_x + shift;
cusparse::csrsm_solve(
handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
matrix->get_size()[0], b->get_stride(), &one,
cusp_csrsm_data.factor_descr, matrix->get_const_values(),
matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(),
cusp_csrsm_data.solve_info, temp_b2, b->get_size()[0],
temp_x2, x->get_size()[0]);
shift++;
}
}
GKO_ASSERT_NO_CUSPARSE_ERRORS(
cusparseSetPointerMode(handle, CUSPARSE_POINTER_MODE_DEVICE));
Expand Down

0 comments on commit 7992508

Please sign in to comment.