Skip to content

Commit

Permalink
Merge pull request #338 from kokkos/getrs
Browse files Browse the repository at this point in the history
getrs implementation
  • Loading branch information
ndellingwood authored Dec 4, 2018
2 parents b26f446 + 806bb5d commit 486dda8
Show file tree
Hide file tree
Showing 34 changed files with 1,349 additions and 23 deletions.
12 changes: 5 additions & 7 deletions src/batched/KokkosBatched_InverseLU_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
/// \author Vinh Dang (vqdang@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsm_Decl.hpp"
#include "KokkosBatched_Trsm_Serial_Impl.hpp"
#include "KokkosBatched_SolveLU_Decl.hpp"
#include "KokkosBatched_SolveLU_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Experimental {
Expand Down Expand Up @@ -88,7 +88,7 @@ namespace KokkosBatched {
auto B = Kokkos::View<ScalarType**, Kokkos::LayoutLeft, typename WViewType::memory_space, Kokkos::MemoryTraits<Kokkos::Unmanaged> >(W.data(), A.extent(0), A.extent(1));

const ScalarType one(1.0);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
Expand All @@ -97,9 +97,8 @@ namespace KokkosBatched {
}

//First, compute L inverse by solving the system L*Linv = I for Linv
SerialTrsm<Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,Algo::Trsm::Unblocked>::invoke(one, A, B);
//Second, compute A inverse by solving the system U*Ainv = Linv for Ainv
SerialTrsm<Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,Algo::Trsm::Unblocked>::invoke(one, A, B);
SerialSolveLU<Algo::SolveLU::Unblocked,Trans::NoTranspose>::invoke(A,B);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand Down Expand Up @@ -140,9 +139,8 @@ namespace KokkosBatched {
}

//First, compute L inverse by solving the system L*Linv = I for Linv
SerialTrsm<Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,Algo::Trsm::Blocked>::invoke(one, A, B);
//Second, compute A inverse by solving the system U*Ainv = Linv for Ainv
SerialTrsm<Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,Algo::Trsm::Blocked>::invoke(one, A, B);
SerialSolveLU<Algo::SolveLU::Blocked,Trans::NoTranspose>::invoke(A,B);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand Down
10 changes: 4 additions & 6 deletions src/batched/KokkosBatched_InverseLU_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
/// \author Vinh Dang (vqdang@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsm_Decl.hpp"
#include "KokkosBatched_Trsm_Team_Impl.hpp"
#include "KokkosBatched_SolveLU_Decl.hpp"
#include "KokkosBatched_SolveLU_Team_Impl.hpp"

namespace KokkosBatched {
namespace Experimental {
Expand Down Expand Up @@ -43,9 +43,8 @@ namespace KokkosBatched {
});

//First, compute L inverse by solving the system L*Linv = I for Linv
TeamTrsm<MemberType,Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,Algo::Trsm::Unblocked>::invoke(member, one, A, B);
//Second, compute A inverse by solving the system U*Ainv = Linv for Ainv
TeamTrsm<MemberType,Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,Algo::Trsm::Unblocked>::invoke(member, one, A, B);
TeamSolveLU<MemberType,Algo::SolveLU::Unblocked,Trans::NoTranspose>::invoke(member, A, B);

Kokkos::parallel_for(Kokkos::TeamThreadRange(member,A.extent(0)*A.extent(1)),[&](const int &tid) {
int i = tid/A.extent(1);
Expand Down Expand Up @@ -82,9 +81,8 @@ namespace KokkosBatched {
});

//First, compute L inverse by solving the system L*Linv = I for Linv
TeamTrsm<MemberType,Side::Left,Uplo::Lower,Trans::NoTranspose,Diag::Unit,Algo::Trsm::Blocked>::invoke(member, one, A, B);
//Second, compute A inverse by solving the system U*Ainv = Linv for Ainv
TeamTrsm<MemberType,Side::Left,Uplo::Upper,Trans::NoTranspose,Diag::NonUnit,Algo::Trsm::Blocked>::invoke(member, one, A, B);
TeamSolveLU<MemberType,Algo::SolveLU::Blocked,Trans::NoTranspose>::invoke(member, A, B);

Kokkos::parallel_for(Kokkos::TeamThreadRange(member,A.extent(0)*A.extent(1)),[&](const int &tid) {
int i = tid/A.extent(1);
Expand Down
41 changes: 41 additions & 0 deletions src/batched/KokkosBatched_SolveLU_Decl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef __KOKKOSBATCHED_SOLVELU_DECL_HPP__
#define __KOKKOSBATCHED_SOLVELU_DECL_HPP__


/// \author Vinh Dang (vqdang@sandia.gov)

#include "KokkosBatched_Vector.hpp"

namespace KokkosBatched {
namespace Experimental {

/// Serial interface for solving a linear system from an LU factorization
/// without pivoting.
///
/// \tparam ArgAlgo   algorithm tag; the serial implementation header specializes
///                   this for Algo::SolveLU::Unblocked and Algo::SolveLU::Blocked.
/// \tparam TransType transpose tag; the serial implementation header specializes
///                   this for Trans::NoTranspose and Trans::Transpose.
///
/// Only the declaration lives here; invoke() is defined per specialization in
/// KokkosBatched_SolveLU_Serial_Impl.hpp.
template<typename ArgAlgo,
typename TransType>
struct SerialSolveLU {
// No-pivoting version.
// A: rank-2 view passed to the triangular solves as the factored matrix.
// B: rank-1 or rank-2 view passed to the triangular solves as the right-hand side.
// Returns an int status; the specializations visible in the serial
// implementation header return 0.
template<typename AViewType,
typename BViewType>
KOKKOS_INLINE_FUNCTION
static int
invoke(const AViewType &A,
const BViewType &B);
};

/// Team interface for solving a linear system from an LU factorization
/// without pivoting.
///
/// \tparam MemberType type of the team member handle forwarded to the team
///                    triangular solves.
/// \tparam ArgAlgo    algorithm tag; the team implementation header specializes
///                    this for Algo::SolveLU::Unblocked and Algo::SolveLU::Blocked.
/// \tparam TransType  transpose tag; the team implementation header specializes
///                    this for Trans::NoTranspose and Trans::Transpose.
///
/// This primary template only declares invoke(); the partial specializations in
/// KokkosBatched_SolveLU_Team_Impl.hpp provide the definitions.
template<typename MemberType,
typename ArgAlgo,
typename TransType>
struct TeamSolveLU {
// No-pivoting version.
// member: team member handle, forwarded unchanged to TeamTrsm.
// A: rank-2 view passed to the triangular solves as the factored matrix.
// B: rank-1 or rank-2 view passed to the triangular solves as the right-hand side.
// Returns an int status; the specializations visible in the team
// implementation header return 0.
template<typename AViewType,
typename BViewType>
KOKKOS_INLINE_FUNCTION
static int
invoke(const MemberType &member,
const AViewType &A,
const BViewType &B);
};

}
}

#endif
128 changes: 128 additions & 0 deletions src/batched/KokkosBatched_SolveLU_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#ifndef __KOKKOSBATCHED_SOLVELU_SERIAL_IMPL_HPP__
#define __KOKKOSBATCHED_SOLVELU_SERIAL_IMPL_HPP__


/// \author Vinh Dang (vqdang@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsm_Decl.hpp"
#include "KokkosBatched_Trsm_Serial_Impl.hpp"

namespace KokkosBatched {
namespace Experimental {
///
/// Serial Impl
/// =========

///
/// SolveLU no piv
///

/// Serial, unblocked solve of A*X = B using LU factors stored in A (no pivoting).
///
/// A is handed to two triangular solves: first as a unit-diagonal lower
/// triangle (L), then as a non-unit upper triangle (U). B is passed to both
/// solves, so it carries the right-hand side in and the result out.
///
/// \param A rank-2 square view holding the factors.
/// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
/// \return always 0.
template <>
template <typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int
SerialSolveLU<Algo::SolveLU::Unblocked, Trans::NoTranspose>::invoke(const AViewType &A,
                                                                    const BViewType &B) {
  // Compile-time checks on ranks and memory spaces.
  static_assert(AViewType::rank == 2, "A should have two dimensions");
  static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                "B should have either one dimension or two dimensions");
  static_assert(std::is_same<typename AViewType::memory_space,
                             typename BViewType::memory_space>::value,
                "A and B should be on the same memory space");

  // Run-time shape checks: A must be square and conformal with B.
  assert(A.extent(0) == A.extent(1));
  assert(A.extent(1) == B.extent(0));

  using value_type = typename AViewType::value_type;
  using LowerSolve =
      SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, Algo::Trsm::Unblocked>;
  using UpperSolve =
      SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit, Algo::Trsm::Unblocked>;

  const value_type alpha(1.0);

  // Step 1: L*Y = B  ->  Y (= U*X).
  LowerSolve::invoke(alpha, A, B);
  // Step 2: U*X = Y  ->  X.
  UpperSolve::invoke(alpha, A, B);

  return 0;
}

/// Serial, blocked solve of A*X = B using LU factors stored in A (no pivoting).
///
/// Same two-step scheme as the unblocked variant, but both triangular solves
/// are dispatched with Algo::Trsm::Blocked. B is passed to both solves, so it
/// carries the right-hand side in and the result out.
///
/// \param A rank-2 square view holding the factors.
/// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
/// \return always 0.
template <>
template <typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int
SerialSolveLU<Algo::SolveLU::Blocked, Trans::NoTranspose>::invoke(const AViewType &A,
                                                                  const BViewType &B) {
  // Compile-time checks on ranks and memory spaces.
  static_assert(AViewType::rank == 2, "A should have two dimensions");
  static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                "B should have either one dimension or two dimensions");
  static_assert(std::is_same<typename AViewType::memory_space,
                             typename BViewType::memory_space>::value,
                "A and B should be on the same memory space");

  // Run-time shape checks: A must be square and conformal with B.
  assert(A.extent(0) == A.extent(1));
  assert(A.extent(1) == B.extent(0));

  using value_type = typename AViewType::value_type;
  using LowerSolve =
      SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, Algo::Trsm::Blocked>;
  using UpperSolve =
      SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit, Algo::Trsm::Blocked>;

  const value_type alpha(1.0);

  // Step 1: L*Y = B  ->  Y (= U*X).
  LowerSolve::invoke(alpha, A, B);
  // Step 2: U*X = Y  ->  X.
  UpperSolve::invoke(alpha, A, B);

  return 0;
}

/// Serial, unblocked solve of the transposed system A'*X = B using LU factors
/// stored in A (no pivoting).
///
/// Intended scheme: solve U'*Y = B, then L'*X = Y. B is passed to both solves,
/// so it carries the right-hand side in and the result out.
///
/// NOTE(review): the Uplo/Diag pairing below is mirrored with respect to the
/// NoTranspose variant (Lower+NonUnit for U', Upper+Unit for L'). That is only
/// right if SerialTrsm reads Uplo as the shape of op(A). Under the BLAS
/// convention, where Uplo names the stored triangle of A, the calls would need
/// Upper+Transpose+NonUnit followed by Lower+Transpose+Unit. Confirm against
/// the SerialTrsm transpose specializations.
///
/// \param A rank-2 square view holding the factors.
/// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
/// \return always 0.
template <>
template <typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int
SerialSolveLU<Algo::SolveLU::Unblocked, Trans::Transpose>::invoke(const AViewType &A,
                                                                  const BViewType &B) {
  // Compile-time checks on ranks and memory spaces.
  static_assert(AViewType::rank == 2, "A should have two dimensions");
  static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                "B should have either one dimension or two dimensions");
  static_assert(std::is_same<typename AViewType::memory_space,
                             typename BViewType::memory_space>::value,
                "A and B should be on the same memory space");

  // Run-time shape checks: A must be square and conformal with B.
  assert(A.extent(0) == A.extent(1));
  assert(A.extent(1) == B.extent(0));

  using value_type = typename AViewType::value_type;
  using FirstSolve =
      SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::NonUnit, Algo::Trsm::Unblocked>;
  using SecondSolve =
      SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::Unit, Algo::Trsm::Unblocked>;

  const value_type alpha(1.0);

  // Step 1: U'*Y = B  ->  Y (= L'*X).
  FirstSolve::invoke(alpha, A, B);
  // Step 2: L'*X = Y  ->  X.
  SecondSolve::invoke(alpha, A, B);

  return 0;
}

/// Serial, blocked solve of the transposed system A'*X = B using LU factors
/// stored in A (no pivoting).
///
/// Intended scheme: solve U'*Y = B, then L'*X = Y, both dispatched with
/// Algo::Trsm::Blocked. B is passed to both solves, so it carries the
/// right-hand side in and the result out.
///
/// NOTE(review): the Uplo/Diag pairing below is mirrored with respect to the
/// NoTranspose variant (Lower+NonUnit for U', Upper+Unit for L'). That is only
/// right if SerialTrsm reads Uplo as the shape of op(A). Under the BLAS
/// convention, where Uplo names the stored triangle of A, the calls would need
/// Upper+Transpose+NonUnit followed by Lower+Transpose+Unit. Confirm against
/// the SerialTrsm transpose specializations.
///
/// \param A rank-2 square view holding the factors.
/// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
/// \return always 0.
template <>
template <typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int
SerialSolveLU<Algo::SolveLU::Blocked, Trans::Transpose>::invoke(const AViewType &A,
                                                                const BViewType &B) {
  // Compile-time checks on ranks and memory spaces.
  static_assert(AViewType::rank == 2, "A should have two dimensions");
  static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                "B should have either one dimension or two dimensions");
  static_assert(std::is_same<typename AViewType::memory_space,
                             typename BViewType::memory_space>::value,
                "A and B should be on the same memory space");

  // Run-time shape checks: A must be square and conformal with B.
  assert(A.extent(0) == A.extent(1));
  assert(A.extent(1) == B.extent(0));

  using value_type = typename AViewType::value_type;
  using FirstSolve =
      SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::NonUnit, Algo::Trsm::Blocked>;
  using SecondSolve =
      SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::Unit, Algo::Trsm::Blocked>;

  const value_type alpha(1.0);

  // Step 1: U'*Y = B  ->  Y (= L'*X).
  FirstSolve::invoke(alpha, A, B);
  // Step 2: L'*X = Y  ->  X.
  SecondSolve::invoke(alpha, A, B);

  return 0;
}

}
}

#endif
128 changes: 128 additions & 0 deletions src/batched/KokkosBatched_SolveLU_Team_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#ifndef __KOKKOSBATCHED_SOLVELU_TEAM_IMPL_HPP__
#define __KOKKOSBATCHED_SOLVELU_TEAM_IMPL_HPP__


/// \author Vinh Dang (vqdang@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsm_Decl.hpp"
#include "KokkosBatched_Trsm_Team_Impl.hpp"

namespace KokkosBatched {
namespace Experimental {
///
/// Team Impl
/// =========

///
/// SolveLU no piv
///

/// Team-level, unblocked solve of A*X = B using LU factors stored in A
/// (no pivoting).
template <typename MemberType>
struct TeamSolveLU<MemberType, Algo::SolveLU::Unblocked, Trans::NoTranspose> {
  /// Runs two team triangular solves on the same B: first with A treated as a
  /// unit-diagonal lower triangle (L), then as a non-unit upper triangle (U).
  /// B is passed to both solves, so it carries the right-hand side in and the
  /// result out.
  ///
  /// \param member team member handle forwarded to TeamTrsm.
  /// \param A rank-2 square view holding the factors.
  /// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
  /// \return always 0.
  template <typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int
  invoke(const MemberType &member, const AViewType &A, const BViewType &B) {
    // Compile-time checks on ranks and memory spaces.
    static_assert(AViewType::rank == 2, "A should have two dimensions");
    static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                  "B should have either one dimension or two dimensions");
    static_assert(std::is_same<typename AViewType::memory_space,
                               typename BViewType::memory_space>::value,
                  "A and B should be on the same memory space");

    // Run-time shape checks: A must be square and conformal with B.
    assert(A.extent(0) == A.extent(1));
    assert(A.extent(1) == B.extent(0));

    using value_type = typename AViewType::value_type;
    using LowerSolve = TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
                                Diag::Unit, Algo::Trsm::Unblocked>;
    using UpperSolve = TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
                                Diag::NonUnit, Algo::Trsm::Unblocked>;

    const value_type alpha(1.0);

    // Step 1: L*Y = B  ->  Y (= U*X).
    LowerSolve::invoke(member, alpha, A, B);
    // Step 2: U*X = Y  ->  X.
    UpperSolve::invoke(member, alpha, A, B);

    return 0;
  }
};

/// Team-level, blocked solve of A*X = B using LU factors stored in A
/// (no pivoting).
template <typename MemberType>
struct TeamSolveLU<MemberType, Algo::SolveLU::Blocked, Trans::NoTranspose> {
  /// Runs two team triangular solves (Algo::Trsm::Blocked) on the same B: first
  /// with A treated as a unit-diagonal lower triangle (L), then as a non-unit
  /// upper triangle (U). B is passed to both solves, so it carries the
  /// right-hand side in and the result out.
  ///
  /// \param member team member handle forwarded to TeamTrsm.
  /// \param A rank-2 square view holding the factors.
  /// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
  /// \return always 0.
  template <typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int
  invoke(const MemberType &member, const AViewType &A, const BViewType &B) {
    // Compile-time checks on ranks and memory spaces.
    static_assert(AViewType::rank == 2, "A should have two dimensions");
    static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                  "B should have either one dimension or two dimensions");
    static_assert(std::is_same<typename AViewType::memory_space,
                               typename BViewType::memory_space>::value,
                  "A and B should be on the same memory space");

    // Run-time shape checks: A must be square and conformal with B.
    assert(A.extent(0) == A.extent(1));
    assert(A.extent(1) == B.extent(0));

    using value_type = typename AViewType::value_type;
    using LowerSolve = TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
                                Diag::Unit, Algo::Trsm::Blocked>;
    using UpperSolve = TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
                                Diag::NonUnit, Algo::Trsm::Blocked>;

    const value_type alpha(1.0);

    // Step 1: L*Y = B  ->  Y (= U*X).
    LowerSolve::invoke(member, alpha, A, B);
    // Step 2: U*X = Y  ->  X.
    UpperSolve::invoke(member, alpha, A, B);

    return 0;
  }
};

/// Team-level, unblocked solve of the transposed system A'*X = B using LU
/// factors stored in A (no pivoting).
template <typename MemberType>
struct TeamSolveLU<MemberType, Algo::SolveLU::Unblocked, Trans::Transpose> {
  /// Intended scheme: solve U'*Y = B, then L'*X = Y. B is passed to both
  /// solves, so it carries the right-hand side in and the result out.
  ///
  /// NOTE(review): the Uplo/Diag pairing below is mirrored with respect to the
  /// NoTranspose variant (Lower+NonUnit for U', Upper+Unit for L'). That is
  /// only right if TeamTrsm reads Uplo as the shape of op(A). Under the BLAS
  /// convention, where Uplo names the stored triangle of A, the calls would
  /// need Upper+Transpose+NonUnit followed by Lower+Transpose+Unit. Confirm
  /// against the TeamTrsm transpose specializations.
  ///
  /// \param member team member handle forwarded to TeamTrsm.
  /// \param A rank-2 square view holding the factors.
  /// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
  /// \return always 0.
  template <typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int
  invoke(const MemberType &member, const AViewType &A, const BViewType &B) {
    // Compile-time checks on ranks and memory spaces.
    static_assert(AViewType::rank == 2, "A should have two dimensions");
    static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                  "B should have either one dimension or two dimensions");
    static_assert(std::is_same<typename AViewType::memory_space,
                               typename BViewType::memory_space>::value,
                  "A and B should be on the same memory space");

    // Run-time shape checks: A must be square and conformal with B.
    assert(A.extent(0) == A.extent(1));
    assert(A.extent(1) == B.extent(0));

    using value_type = typename AViewType::value_type;
    using FirstSolve = TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::Transpose,
                                Diag::NonUnit, Algo::Trsm::Unblocked>;
    using SecondSolve = TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::Transpose,
                                 Diag::Unit, Algo::Trsm::Unblocked>;

    const value_type alpha(1.0);

    // Step 1: U'*Y = B  ->  Y (= L'*X).
    FirstSolve::invoke(member, alpha, A, B);
    // Step 2: L'*X = Y  ->  X.
    SecondSolve::invoke(member, alpha, A, B);

    return 0;
  }
};

/// Team-level, blocked solve of the transposed system A'*X = B using LU
/// factors stored in A (no pivoting).
template <typename MemberType>
struct TeamSolveLU<MemberType, Algo::SolveLU::Blocked, Trans::Transpose> {
  /// Intended scheme: solve U'*Y = B, then L'*X = Y, both dispatched with
  /// Algo::Trsm::Blocked. B is passed to both solves, so it carries the
  /// right-hand side in and the result out.
  ///
  /// NOTE(review): the Uplo/Diag pairing below is mirrored with respect to the
  /// NoTranspose variant (Lower+NonUnit for U', Upper+Unit for L'). That is
  /// only right if TeamTrsm reads Uplo as the shape of op(A). Under the BLAS
  /// convention, where Uplo names the stored triangle of A, the calls would
  /// need Upper+Transpose+NonUnit followed by Lower+Transpose+Unit. Confirm
  /// against the TeamTrsm transpose specializations.
  ///
  /// \param member team member handle forwarded to TeamTrsm.
  /// \param A rank-2 square view holding the factors.
  /// \param B rank-1 or rank-2 view whose leading extent equals A.extent(1).
  /// \return always 0.
  template <typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int
  invoke(const MemberType &member, const AViewType &A, const BViewType &B) {
    // Compile-time checks on ranks and memory spaces.
    static_assert(AViewType::rank == 2, "A should have two dimensions");
    static_assert((BViewType::rank == 1) || (BViewType::rank == 2),
                  "B should have either one dimension or two dimensions");
    static_assert(std::is_same<typename AViewType::memory_space,
                               typename BViewType::memory_space>::value,
                  "A and B should be on the same memory space");

    // Run-time shape checks: A must be square and conformal with B.
    assert(A.extent(0) == A.extent(1));
    assert(A.extent(1) == B.extent(0));

    using value_type = typename AViewType::value_type;
    using FirstSolve = TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::Transpose,
                                Diag::NonUnit, Algo::Trsm::Blocked>;
    using SecondSolve = TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::Transpose,
                                 Diag::Unit, Algo::Trsm::Blocked>;

    const value_type alpha(1.0);

    // Step 1: U'*Y = B  ->  Y (= L'*X).
    FirstSolve::invoke(member, alpha, A, B);
    // Step 2: L'*X = Y  ->  X.
    SecondSolve::invoke(member, alpha, A, B);

    return 0;
  }
};

}
}

#endif
Loading

0 comments on commit 486dda8

Please sign in to comment.