Skip to content

Commit

Permalink
LU: info check in nopiv
Browse files Browse the repository at this point in the history
  • Loading branch information
mgates3 committed Sep 19, 2023
1 parent daf32ed commit 59252f0
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 86 deletions.
46 changes: 7 additions & 39 deletions include/slate/simplified_api.hh
Original file line number Diff line number Diff line change
Expand Up @@ -250,26 +250,15 @@ int64_t lu_solve(
//-----------------------------------------
// lu_solve_nopiv()

// todo
// gbsv_nopiv
// template <typename scalar_t>
// void lu_solve_nopiv(
// BandMatrix<scalar_t>& A,
// Matrix<scalar_t>& B,
// Options const& opts = Options())
// {
// gbsv_nopiv(A, B, opts);
// }

// gesv_nopiv
// todo: deprecate, use lu_solve( ..., { MethodLU: NoPiv } )
template <typename scalar_t>
void lu_solve_nopiv(
[[deprecated( "Use lu_solve( A, { Option::MethodLU, MethodLU::NoPiv } ) instead. Will be removed 2024-09." )]]
int64_t lu_solve_nopiv(
Matrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts = Options())
{
gesv_nopiv(A, B, opts);
return gesv_nopiv( A, B, opts );
}

//-----------------------------------------
Expand All @@ -296,24 +285,14 @@ int64_t lu_factor(
//-----------------------------------------
// lu_factor_nopiv()

// todo
// gbtrf_nopiv
// template <typename scalar_t>
// void lu_factor_nopiv(
// BandMatrix<scalar_t>& A,
// Options const& opts = Options())
// {
// gbtrf_nopiv(A, opts);
// }

// getrf_nopiv
// todo: deprecate, use lu_factor( ..., { MethodLU: NoPiv } )
template <typename scalar_t>
void lu_factor_nopiv(
[[deprecated( "Use lu_factor( A, { Option::MethodLU, MethodLU::NoPiv } ) instead. Will be removed 2024-09." )]]
int64_t lu_factor_nopiv(
Matrix<scalar_t>& A,
Options const& opts = Options())
{
getrf_nopiv(A, opts);
return getrf_nopiv( A, opts );
}

//-----------------------------------------
Expand Down Expand Up @@ -342,20 +321,9 @@ void lu_solve_using_factor(
//-----------------------------------------
// lu_solve_using_factor_nopiv()

// todo
// gbtrs_nopiv
// template <typename scalar_t>
// void lu_solve_using_factor_nopiv(
// BandMatrix<scalar_t>& A,
// Matrix<scalar_t>& B,
// Options const& opts = Options())
// {
// gbtrs_nopiv(A, B, opts);
// }

// getrs_nopiv
// todo: deprecate, use lu_solve_using_factor( ..., { MethodLU: NoPiv } )
template <typename scalar_t>
[[deprecated( "Use lu_solve_using_factor( A, { Option::MethodLU, MethodLU::NoPiv } ) instead. Will be removed 2024-09." )]]
void lu_solve_using_factor_nopiv(
Matrix<scalar_t>& A,
Matrix<scalar_t>& B,
Expand Down
2 changes: 1 addition & 1 deletion include/slate/slate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ int64_t getrf(
//-----------------------------------------
// getrf_nopiv()
template <typename scalar_t>
void getrf_nopiv(
int64_t getrf_nopiv(
Matrix<scalar_t>& A,
Options const& opts = Options());

Expand Down
25 changes: 12 additions & 13 deletions src/gesv_nopiv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,14 @@ namespace slate {
/// - HostBatch: batched BLAS on CPU host.
/// - Devices: batched BLAS on GPU device.
///
/// TODO: return value
/// @retval 0 successful exit
/// @retval >0 for return value = $i$, the computed $U(i,i)$ is exactly zero.
/// The factorization has been completed, but the factor U is exactly
/// singular, so the solution could not be computed.
/// @retval i > 0: U(i,i) is exactly zero (1-based index). The factorization
/// will have NaN due to division by zero, and the solve step is
/// skipped, so B is not overwritten with a solution.
///
/// @ingroup gesv
///
template <typename scalar_t>
void gesv_nopiv(
int64_t gesv_nopiv(
Matrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts)
Expand All @@ -76,36 +74,37 @@ void gesv_nopiv(
slate_assert(B.mt() == A.mt());

// factorization
getrf_nopiv(A, opts);
int64_t info = getrf_nopiv( A, opts );

// solve
getrs_nopiv(A, B, opts);

// todo: return value for errors?
if (info == 0) {
getrs_nopiv( A, B, opts );
}
return info;
}

//------------------------------------------------------------------------------
// Explicit instantiations.
template
void gesv_nopiv<float>(
int64_t gesv_nopiv<float>(
Matrix<float>& A,
Matrix<float>& B,
Options const& opts);

template
void gesv_nopiv<double>(
int64_t gesv_nopiv<double>(
Matrix<double>& A,
Matrix<double>& B,
Options const& opts);

template
void gesv_nopiv< std::complex<float> >(
int64_t gesv_nopiv< std::complex<float> >(
Matrix< std::complex<float> >& A,
Matrix< std::complex<float> >& B,
Options const& opts);

template
void gesv_nopiv< std::complex<double> >(
int64_t gesv_nopiv< std::complex<double> >(
Matrix< std::complex<double> >& A,
Matrix< std::complex<double> >& B,
Options const& opts);
Expand Down
2 changes: 1 addition & 1 deletion src/getrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ int64_t getrf(
}
else if (method == MethodLU::NoPiv) {
// todo: fill in pivots vector?
getrf_nopiv( A, opts );
info = getrf_nopiv( A, opts );
}
else if (method == MethodLU::PartialPiv) {
Target target = get_option( opts, Option::Target, Target::HostTask );
Expand Down
42 changes: 25 additions & 17 deletions src/getrf_nopiv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace impl {
/// @ingroup gesv_impl
///
template <Target target, typename scalar_t>
void getrf_nopiv(
int64_t getrf_nopiv(
Matrix<scalar_t>& A,
Options const& opts )
{
Expand All @@ -48,6 +48,7 @@ void getrf_nopiv(
A.reserveDeviceWorkspace();
}

int64_t info = 0;
int64_t A_nt = A.nt();
int64_t A_mt = A.mt();
int64_t min_mt_nt = std::min(A.mt(), A.nt());
Expand All @@ -72,6 +73,7 @@ void getrf_nopiv(
#pragma omp parallel
#pragma omp master
{
int64_t kk = 0; // column index (not block-column)
for (int64_t k = 0; k < min_mt_nt; ++k) {

// panel, high priority
Expand All @@ -80,8 +82,12 @@ void getrf_nopiv(
priority(1)
{
// factor A(k, k)
int64_t iinfo;
internal::getrf_nopiv<Target::HostTask>(
A.sub(k, k, k, k), ib, priority_1 );
A.sub(k, k, k, k), ib, priority_1, &iinfo );
if (info == 0 && iinfo > 0) {
info = kk + iinfo;
}

// Update panel
int tag_k = k;
Expand Down Expand Up @@ -227,12 +233,16 @@ void getrf_nopiv(
}
}
}
kk += A.tileNb( k );
}

#pragma omp taskwait
A.tileUpdateAllOrigin();
}
A.clearWorkspace();

internal::reduce_info( &info, A.mpiComm() );
return info;
}

} // namespace impl
Expand Down Expand Up @@ -275,17 +285,15 @@ void getrf_nopiv(
/// - HostBatch: batched BLAS on CPU host.
/// - Devices: batched BLAS on GPU device.
///
/// TODO: return value
/// @retval 0 successful exit
/// @retval >0 for return value = $i$, $U(i,i)$ is exactly zero. The
/// factorization has been completed, but the factor $U$ is exactly
/// singular, and division by zero will occur if it is used
/// to solve a system of equations.
/// @retval i < 0: the i-th argument had an illegal value.
/// @retval i > 0: U(i,i) is exactly zero (1-based index). The factorization
/// will have NaN due to division by zero.
///
/// @ingroup gesv_computational
///
template <typename scalar_t>
void getrf_nopiv(
int64_t getrf_nopiv(
Matrix<scalar_t>& A,
Options const& opts )
{
Expand All @@ -294,43 +302,43 @@ void getrf_nopiv(
switch (target) {
case Target::Host:
case Target::HostTask:
impl::getrf_nopiv<Target::HostTask>( A, opts );
return impl::getrf_nopiv<Target::HostTask>( A, opts );
break;

case Target::HostNest:
impl::getrf_nopiv<Target::HostNest>( A, opts );
return impl::getrf_nopiv<Target::HostNest>( A, opts );
break;

case Target::HostBatch:
impl::getrf_nopiv<Target::HostBatch>( A, opts );
return impl::getrf_nopiv<Target::HostBatch>( A, opts );
break;

case Target::Devices:
impl::getrf_nopiv<Target::Devices>( A, opts );
return impl::getrf_nopiv<Target::Devices>( A, opts );
break;
}
// todo: return value for errors?
return -2; // shouldn't happen
}

//------------------------------------------------------------------------------
// Explicit instantiations.
template
void getrf_nopiv<float>(
int64_t getrf_nopiv<float>(
Matrix<float>& A,
Options const& opts);

template
void getrf_nopiv<double>(
int64_t getrf_nopiv<double>(
Matrix<double>& A,
Options const& opts);

template
void getrf_nopiv< std::complex<float> >(
int64_t getrf_nopiv< std::complex<float> >(
Matrix< std::complex<float> >& A,
Options const& opts);

template
void getrf_nopiv< std::complex<double> >(
int64_t getrf_nopiv< std::complex<double> >(
Matrix< std::complex<double> >& A,
Options const& opts);

Expand Down
21 changes: 18 additions & 3 deletions src/internal/Tile_getrf_nopiv.hh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#include <lapack.hh>

namespace slate {
namespace internal {
namespace tile {

//------------------------------------------------------------------------------
/// Compute the LU factorization of a tile without pivoting.
///
Expand All @@ -24,11 +25,19 @@ namespace internal {
/// @param[in,out] tile
/// tile to factor
///
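/// @param[in,out] info
/// Exit status. It is only updated when it is 0 on entry (the first
/// zero pivot found is the one reported), so the caller is expected
/// to initialize it to 0 before the call.
/// * 0: successful exit
/// * i > 0: U(i,i) is exactly zero (1-based index). The factorization
/// will have NaN due to division by zero.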
///
/// @ingroup gesv_tile
///
template <typename scalar_t>
void getrf_nopiv(Tile<scalar_t> tile, int64_t ib)
void getrf_nopiv(
Tile<scalar_t> tile, int64_t ib, int64_t* info )
{
const scalar_t zero = 0.0;
const scalar_t one = 1.0;
int64_t nb = tile.nb();
int64_t mb = tile.mb();
Expand All @@ -44,9 +53,15 @@ void getrf_nopiv(Tile<scalar_t> tile, int64_t ib)
// Loop over ib columns of a stripe.
for (int64_t j = k; j < k+kb; ++j) {

// Detect exact singularity.
scalar_t pivot = tile( j, j );
if (*info == 0 && pivot == zero)
*info = j + 1;

// todo: should this `if` condition be here? The last col of a panel isn't updated??
if (j+1 < mb) {
// Update column
blas::scal(mb-j-1, one/tile(j, j), &tile.at(j+1, j), 1);
blas::scal( mb-j-1, one/pivot, &tile.at( j+1, j ), 1 );
}

// trailing update within ib block
Expand Down
5 changes: 3 additions & 2 deletions src/internal/internal.hh
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,9 @@ void getrf_panel(
//-----------------------------------------
// getrf_nopiv()
template <Target target=Target::HostTask, typename scalar_t>
void getrf_nopiv(Matrix<scalar_t>&& A,
int64_t ib, int priority=0);
void getrf_nopiv(
Matrix<scalar_t>&& A,
int64_t ib, int priority, int64_t* info );

//-----------------------------------------
// getrf_tntpiv()
Expand Down
Loading

0 comments on commit 59252f0

Please sign in to comment.