Skip to content

Commit

Permalink
fix issue kokkos#1212
Browse files Browse the repository at this point in the history
  • Loading branch information
kliegeois committed Dec 7, 2021
1 parent 133b7fc commit 5b2b64f
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/batched/sparse/KokkosBatched_CG.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct CG {
const MemberType &member, const OperatorType &A, const VectorViewType &B,
const VectorViewType &X,
const KrylovHandle<typename VectorViewType::non_const_value_type>
handle) {
&handle) {
int status = 0;
if (std::is_same<ArgMode, Mode::Team>::value) {
status =
Expand Down
2 changes: 1 addition & 1 deletion src/batched/sparse/KokkosBatched_GMRES.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct GMRES {
const MemberType &member, const OperatorType &A, const VectorViewType &B,
const VectorViewType &X,
const KrylovHandle<typename VectorViewType::non_const_value_type>
handle) {
&handle) {
int status = 0;
if (std::is_same<ArgMode, Mode::Team>::value) {
status =
Expand Down
7 changes: 4 additions & 3 deletions src/batched/sparse/impl/KokkosBatched_CG_TeamVector_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@ struct TeamVectorCG {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X,
KrylovHandle<typename VectorViewType::non_const_value_type>* handle) {
const KrylovHandle<typename VectorViewType::non_const_value_type>&
handle) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;

const size_t maximum_iteration = handle->get_max_iteration();
const MagnitudeType tolerance = handle->get_tolerance();
const size_t maximum_iteration = handle.get_max_iteration();
const MagnitudeType tolerance = handle.get_tolerance();

using ScratchPadNormViewType = Kokkos::View<
MagnitudeType*,
Expand Down
7 changes: 4 additions & 3 deletions src/batched/sparse/impl/KokkosBatched_CG_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,14 @@ struct TeamCG {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X,
KrylovHandle<typename VectorViewType::non_const_value_type>* handle) {
const KrylovHandle<typename VectorViewType::non_const_value_type>&
handle) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;

size_t maximum_iteration = handle->get_max_iteration();
const MagnitudeType tolerance = handle->get_tolerance();
size_t maximum_iteration = handle.get_max_iteration();
const MagnitudeType tolerance = handle.get_tolerance();

using ScratchPadNormViewType = Kokkos::View<
MagnitudeType*,
Expand Down
13 changes: 7 additions & 6 deletions src/batched/sparse/impl/KokkosBatched_GMRES_TeamVector_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ struct TeamVectorGMRES {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X,
KrylovHandle<typename VectorViewType::non_const_value_type>* handle) {
const KrylovHandle<typename VectorViewType::non_const_value_type>&
handle) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;
Expand All @@ -90,10 +91,10 @@ struct TeamVectorGMRES {
const OrdinalType numMatrices = _X.extent(0);
const OrdinalType numRows = _X.extent(1);

size_t maximum_iteration = handle->get_max_iteration() < numRows
? handle->get_max_iteration()
size_t maximum_iteration = handle.get_max_iteration() < numRows
? handle.get_max_iteration()
: numRows;
const MagnitudeType tolerance = handle->get_tolerance();
const MagnitudeType tolerance = handle.get_tolerance();
const MagnitudeType max_tolerance = 0.;

ScratchPadMultiVectorViewType V(member.team_scratch(1), numMatrices,
Expand Down Expand Up @@ -210,8 +211,8 @@ struct TeamVectorGMRES {
// Compute the new Givens rotation:
Kokkos::pair<typename VectorViewType::non_const_value_type,
typename VectorViewType::non_const_value_type>
G_new;
typename VectorViewType::non_const_value_type alpha;
G_new(1, 0);
typename VectorViewType::non_const_value_type alpha = 0;
SerialGivensInternal::invoke(H_j(j), H_j(j + 1), &G_new, &alpha);

Givens(l, j, 0) = G_new.first;
Expand Down
13 changes: 7 additions & 6 deletions src/batched/sparse/impl/KokkosBatched_GMRES_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ struct TeamGMRES {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X,
KrylovHandle<typename VectorViewType::non_const_value_type>* handle) {
const KrylovHandle<typename VectorViewType::non_const_value_type>&
handle) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;
Expand All @@ -89,10 +90,10 @@ struct TeamGMRES {
const OrdinalType numMatrices = _X.extent(0);
const OrdinalType numRows = _X.extent(1);

size_t maximum_iteration = handle->get_max_iteration() < numRows
? handle->get_max_iteration()
size_t maximum_iteration = handle.get_max_iteration() < numRows
? handle.get_max_iteration()
: numRows;
const MagnitudeType tolerance = handle->get_tolerance();
const MagnitudeType tolerance = handle.get_tolerance();
const MagnitudeType max_tolerance = 0.;

ScratchPadMultiVectorViewType V(member.team_scratch(1), numMatrices,
Expand Down Expand Up @@ -208,8 +209,8 @@ struct TeamGMRES {
// Compute the new Givens rotation:
Kokkos::pair<typename VectorViewType::non_const_value_type,
typename VectorViewType::non_const_value_type>
G_new;
typename VectorViewType::non_const_value_type alpha;
G_new(1, 0);
typename VectorViewType::non_const_value_type alpha = 0;
SerialGivensInternal::invoke(H_j(j), H_j(j + 1), &G_new, &alpha);

Givens(l, j, 0) = G_new.first;
Expand Down
6 changes: 2 additions & 4 deletions unit_test/batched/sparse/Test_Batched_TeamCG.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ struct Functor_TestBatchedTeamCG {
const VectorViewType _X;
const VectorViewType _B;
const int _N_team;
KrylovHandle<typename ValuesViewType::value_type> *handle;
KrylovHandle<typename ValuesViewType::value_type> handle;

KOKKOS_INLINE_FUNCTION
Functor_TestBatchedTeamCG(const ValuesViewType &D, const IntView &r,
const IntView &c, const VectorViewType &X,
const VectorViewType &B, const int N_team)
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {
handle = new KrylovHandle<typename ValuesViewType::value_type>;
}
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {}

template <typename MemberType>
KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const {
Expand Down
10 changes: 4 additions & 6 deletions unit_test/batched/sparse/Test_Batched_TeamGMRES.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ struct Functor_TestBatchedTeamGMRES {
const VectorViewType _X;
const VectorViewType _B;
const int _N_team;
KrylovHandle<typename ValuesViewType::value_type> *handle;
KrylovHandle<typename ValuesViewType::value_type> handle;

KOKKOS_INLINE_FUNCTION
Functor_TestBatchedTeamGMRES(const ValuesViewType &D, const IntView &r,
const IntView &c, const VectorViewType &X,
const VectorViewType &B, const int N_team)
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {
handle = new KrylovHandle<typename ValuesViewType::value_type>;
}
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {}

template <typename MemberType>
KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const {
Expand Down Expand Up @@ -69,9 +67,9 @@ struct Functor_TestBatchedTeamGMRES {
size_t bytes_0 = ValuesViewType::shmem_size(_N_team, _D.extent(1));
size_t bytes_1 = ValuesViewType::shmem_size(_N_team, 1);

handle->set_max_iteration(10);
handle.set_max_iteration(10);

int maximum_iteration = handle->get_max_iteration();
int maximum_iteration = handle.get_max_iteration();

policy.set_scratch_size(0, Kokkos::PerTeam(5 * bytes_0 + 5 * bytes_1));
policy.set_scratch_size(
Expand Down
6 changes: 2 additions & 4 deletions unit_test/batched/sparse/Test_Batched_TeamVectorCG.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ struct Functor_TestBatchedTeamVectorCG {
const VectorViewType _X;
const VectorViewType _B;
const int _N_team;
KrylovHandle<typename ValuesViewType::value_type> *handle;
KrylovHandle<typename ValuesViewType::value_type> handle;

KOKKOS_INLINE_FUNCTION
Functor_TestBatchedTeamVectorCG(const ValuesViewType &D, const IntView &r,
const IntView &c, const VectorViewType &X,
const VectorViewType &B, const int N_team)
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {
handle = new KrylovHandle<typename ValuesViewType::value_type>();
}
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {}

template <typename MemberType>
KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const {
Expand Down
10 changes: 4 additions & 6 deletions unit_test/batched/sparse/Test_Batched_TeamVectorGMRES.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ struct Functor_TestBatchedTeamVectorGMRES {
const VectorViewType _X;
const VectorViewType _B;
const int _N_team;
KrylovHandle<typename ValuesViewType::value_type> *handle;
KrylovHandle<typename ValuesViewType::value_type> handle;

KOKKOS_INLINE_FUNCTION
Functor_TestBatchedTeamVectorGMRES(const ValuesViewType &D, const IntView &r,
const IntView &c, const VectorViewType &X,
const VectorViewType &B, const int N_team)
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {
handle = new KrylovHandle<typename ValuesViewType::value_type>;
}
: _D(D), _r(r), _c(c), _X(X), _B(B), _N_team(N_team) {}

template <typename MemberType>
KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const {
Expand Down Expand Up @@ -69,9 +67,9 @@ struct Functor_TestBatchedTeamVectorGMRES {
size_t bytes_0 = ValuesViewType::shmem_size(_N_team, _D.extent(1));
size_t bytes_1 = ValuesViewType::shmem_size(_N_team, 1);

handle->set_max_iteration(10);
handle.set_max_iteration(10);

int maximum_iteration = handle->get_max_iteration();
int maximum_iteration = handle.get_max_iteration();

policy.set_scratch_size(0, Kokkos::PerTeam(5 * bytes_0 + 5 * bytes_1));
policy.set_scratch_size(
Expand Down

0 comments on commit 5b2b64f

Please sign in to comment.