Skip to content

Commit

Permalink
move TeamScale and TeamVectorScale to KokkosBlas
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Jun 23, 2022
1 parent ad4da89 commit 3046f2a
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 78 deletions.
16 changes: 9 additions & 7 deletions src/batched/dense/KokkosBatched_Scale_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Vector.hpp"

namespace KokkosBatched {

///
Expand All @@ -30,7 +27,10 @@ struct TeamScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A);
const AViewType &A) {
assert(false && "Deprecated: use KokkosBlas::TeamScale");
return 0;
}
};

///
Expand All @@ -42,11 +42,13 @@ struct TeamVectorScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A);
const AViewType &A) {
// static_assert(false);
assert(false && "Deprecated: use KokkosBlas::TeamVectorScale");
return 0;
}
};

} // namespace KokkosBatched

#include "KokkosBatched_Scale_Impl.hpp"

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>::invoke(
if (beta == zero)
TeamVectorSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C,
cs0, cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down Expand Up @@ -80,7 +81,8 @@ TeamVectorGemmInternal<Algo::Gemm::Unblocked, true>::invoke(
if (beta == zero)
TeamVectorSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C,
cs0, cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down
8 changes: 5 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "KokkosKernels_ExecSpaceUtils.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBatched_Scale_Internal.hpp"
#include "KokkosBlas1_team_scal.hpp"

#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

Expand Down Expand Up @@ -43,7 +43,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Unblocked>::invoke(
if (beta == zero)
TeamSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0,
cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down Expand Up @@ -84,7 +85,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Blocked>::invoke(
if (beta == zero)
TeamSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0,
cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBatched_Scale_Internal.hpp"
#include "KokkosBlas1_team_scal_impl.hpp"

#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"

Expand Down Expand Up @@ -60,7 +60,7 @@ TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
if (beta == zero)
TeamVectorSetInternal ::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, beta, y, ys0);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down
4 changes: 2 additions & 2 deletions src/batched/dense/impl/KokkosBatched_Gemv_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
if (beta == zero)
TeamSetInternal ::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamScaleInternal::invoke(member, m, beta, y, ys0);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down Expand Up @@ -88,7 +88,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
if (beta == zero)
TeamSetInternal ::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamScaleInternal::invoke(member, m, beta, y, ys0);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down
38 changes: 0 additions & 38 deletions src/batched/dense/impl/KokkosBatched_Scale_Impl.hpp

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ TeamVectorTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
TeamVectorSetInternal ::invoke(member, m, n, zero, B, bs0, bs1);
else {
if (alpha != one)
TeamVectorScaleInternal::invoke(member, m, n, alpha, B, bs0, bs1);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, alpha, B,
bs0, bs1);
if (m <= 0 || n <= 0) return 0;

for (int p = 0; p < m; ++p) {
Expand Down Expand Up @@ -98,7 +99,8 @@ TeamVectorTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
TeamVectorSetInternal ::invoke(member, m, n, zero, B, bs0, bs1);
else {
if (alpha != one)
TeamVectorScaleInternal::invoke(member, m, n, alpha, B, bs0, bs1);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, alpha, B,
bs0, bs1);
if (m <= 0 || n <= 0) return 0;

ValueType *KOKKOS_RESTRICT B0 = B;
Expand Down
12 changes: 8 additions & 4 deletions src/batched/dense/impl/KokkosBatched_Trsm_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ TeamTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
TeamSetInternal ::invoke(member, m, n, zero, B, bs0, bs1);
else {
if (alpha != one)
TeamScaleInternal::invoke(member, m, n, alpha, B, bs0, bs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, alpha, B, bs0,
bs1);
if (m <= 0 || n <= 0) return 0;

for (int p = 0; p < m; ++p) {
Expand Down Expand Up @@ -92,7 +93,8 @@ TeamTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
TeamSetInternal ::invoke(member, m, n, zero, B, bs0, bs1);
else {
if (alpha != one)
TeamScaleInternal::invoke(member, m, n, alpha, B, bs0, bs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, alpha, B, bs0,
bs1);
if (m <= 0 || n <= 0) return 0;

///
Expand Down Expand Up @@ -175,7 +177,8 @@ TeamTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
TeamSetInternal ::invoke(member, m, n, zero, B, bs0, bs1);
else {
if (alpha != one)
TeamScaleInternal::invoke(member, m, n, alpha, B, bs0, bs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, alpha, B, bs0,
bs1);
if (m <= 0 || n <= 0) return 0;

ValueType *KOKKOS_RESTRICT B0 = B;
Expand Down Expand Up @@ -231,7 +234,8 @@ TeamTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
TeamSetInternal ::invoke(member, m, n, zero, B, bs0, bs1);
else {
if (alpha != one)
TeamScaleInternal::invoke(member, m, n, alpha, B, bs0, bs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, alpha, B, bs0,
bs1);
if (m <= 0 || n <= 0) return 0;

InnerTrsmLeftUpperUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ TeamVectorTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
if (alpha == zero)
TeamVectorSetInternal::invoke(member, m, zero, b, bs0);
else {
if (alpha != one) TeamVectorScaleInternal::invoke(member, m, alpha, b, bs0);
if (alpha != one)
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, alpha, b,
bs0);
if (m <= 0) return 0;

for (int p = 0; p < m; ++p) {
Expand Down Expand Up @@ -106,7 +108,9 @@ TeamVectorTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
if (alpha == zero)
TeamVectorSetInternal::invoke(member, m, zero, b, bs0);
else {
if (alpha != one) TeamVectorScaleInternal::invoke(member, m, alpha, b, bs0);
if (alpha != one)
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, alpha, b,
bs0);
if (m <= 0) return 0;

ValueType *KOKKOS_RESTRICT b0 = b;
Expand Down
12 changes: 8 additions & 4 deletions src/batched/dense/impl/KokkosBatched_Trsv_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ KOKKOS_INLINE_FUNCTION int TeamTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
if (alpha == zero)
TeamSetInternal::invoke(member, m, zero, b, bs0);
else {
if (alpha != one) TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (alpha != one)
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (m <= 0) return 0;

for (int p = 0; p < m; ++p) {
Expand Down Expand Up @@ -91,7 +92,8 @@ KOKKOS_INLINE_FUNCTION int TeamTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
if (alpha == zero)
TeamSetInternal::invoke(member, m, zero, b, bs0);
else {
if (alpha != one) TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (alpha != one)
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (m <= 0) return 0;

/// case GPU: team size is large and blocksize (mb,nb) is small
Expand Down Expand Up @@ -155,7 +157,8 @@ KOKKOS_INLINE_FUNCTION int TeamTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
if (alpha == zero)
TeamSetInternal::invoke(member, m, zero, b, bs0);
else {
if (alpha != one) TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (alpha != one)
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (m <= 0) return 0;

ValueType *KOKKOS_RESTRICT b0 = b;
Expand Down Expand Up @@ -198,7 +201,8 @@ KOKKOS_INLINE_FUNCTION int TeamTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
if (alpha == zero)
TeamSetInternal::invoke(member, m, zero, b, bs0);
else {
if (alpha != one) TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (alpha != one)
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, alpha, b, bs0);
if (m <= 0) return 0;

InnerTrsmLeftUpperUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, 0);
Expand Down
37 changes: 37 additions & 0 deletions src/blas/KokkosBlas1_team_scal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,46 @@
#ifndef KOKKOSBLAS1_TEAM_SCAL_HPP_
#define KOKKOSBLAS1_TEAM_SCAL_HPP_

#include <KokkosBlas1_team_scal_impl.hpp>

// TODO: deprecate/remove ?
#include <KokkosBlas1_team_scal_spec.hpp>

namespace KokkosBlas {

///
/// Team Scale
///

template <typename MemberType>
struct TeamScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& member,
const ScalarType alpha,
const AViewType& A) {
return Impl::TeamScaleInternal::invoke(member, A.extent(0), A.extent(1),
alpha, A.data(), A.stride_0(),
A.stride_1());
}
};

///
/// TeamVector Scale
///

template <typename MemberType>
struct TeamVectorScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType& member,
const ScalarType alpha,
const AViewType& A) {
return Impl::TeamVectorScaleInternal::invoke(member, A.extent(0),
A.extent(1), alpha, A.data(),
A.stride_0(), A.stride_1());
}
};

// TODO: deprecate/remove ?
namespace Experimental {

template <class TeamType, class RVector, class XVector>
Expand Down
Loading

0 comments on commit 3046f2a

Please sign in to comment.