Skip to content

Commit

Permalink
LU: info check in CALU
Browse files Browse the repository at this point in the history
  • Loading branch information
mgates3 committed Sep 19, 2023
1 parent 59252f0 commit 1e36d9d
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 78 deletions.
2 changes: 1 addition & 1 deletion include/slate/slate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ int64_t getrf_nopiv(
//-----------------------------------------
// getrf_tntpiv()
template <typename scalar_t>
void getrf_tntpiv(
int64_t getrf_tntpiv(
Matrix<scalar_t>& A, Pivots& pivots,
Options const& opts = Options());

Expand Down
17 changes: 8 additions & 9 deletions src/getrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ int64_t getrf(
/// Note: the pivots vector is currently ignored for NoPiv.
///
/// @retval 0 successful exit
/// @retval i > 0, $U(i,i)$ is exactly zero, where i is a 1-based index.
/// @retval i > 0: $U(i,i)$ is exactly zero, where i is a 1-based index.
/// The factorization has been completed, but the factor $U$ is exactly
/// singular, and division by zero will occur if it is used
/// to solve a system of equations.
Expand All @@ -328,42 +328,41 @@ int64_t getrf(
Options const& opts )
{
Method method = get_option( opts, Option::MethodLU, MethodLU::PartialPiv );
int64_t info = 0;

// todo: info for tntpiv, nopiv
if (method == MethodLU::CALU) {
getrf_tntpiv( A, pivots, opts );
return getrf_tntpiv( A, pivots, opts );
}
else if (method == MethodLU::NoPiv) {
// todo: fill in pivots vector?
info = getrf_nopiv( A, opts );
return getrf_nopiv( A, opts );
}
else if (method == MethodLU::PartialPiv) {
Target target = get_option( opts, Option::Target, Target::HostTask );

switch (target) {
case Target::Host:
case Target::HostTask:
info = impl::getrf<Target::HostTask>( A, pivots, opts );
return impl::getrf<Target::HostTask>( A, pivots, opts );
break;

case Target::HostNest:
info = impl::getrf<Target::HostNest>( A, pivots, opts );
return impl::getrf<Target::HostNest>( A, pivots, opts );
break;

case Target::HostBatch:
info = impl::getrf<Target::HostBatch>( A, pivots, opts );
return impl::getrf<Target::HostBatch>( A, pivots, opts );
break;

case Target::Devices:
info = impl::getrf<Target::Devices>( A, pivots, opts );
return impl::getrf<Target::Devices>( A, pivots, opts );
break;
}
}
else {
throw Exception( "unknown value for MethodLU" );
}
return info;
return -2; // shouldn't happen
}

//------------------------------------------------------------------------------
Expand Down
39 changes: 25 additions & 14 deletions src/getrf_tntpiv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace impl {
/// @ingroup gesv_impl
///
template <Target target, typename scalar_t>
void getrf_tntpiv(
int64_t getrf_tntpiv(
Matrix<scalar_t>& A, Pivots& pivots,
Options const& opts)
{
Expand Down Expand Up @@ -57,6 +57,7 @@ void getrf_tntpiv(
if (target == Target::Devices)
target_layout = Layout::RowMajor;

int64_t info = 0;
int64_t A_nt = A.nt();
int64_t A_mt = A.mt();
int64_t min_mt_nt = std::min(A.mt(), A.nt());
Expand Down Expand Up @@ -149,6 +150,7 @@ void getrf_tntpiv(
#pragma omp parallel
#pragma omp master
{
int64_t kk = 0;
for (int64_t k = 0; k < min_mt_nt; ++k) {

int64_t diag_len = std::min(A.tileMb(k), A.tileNb(k));
Expand All @@ -163,10 +165,14 @@ void getrf_tntpiv(

// Factor A(k:mt-1, k) using tournament pivoting to get the
// pivots and the diagonal tile, Akk, in workspace Apanel.
int64_t iinfo;
internal::getrf_tntpiv_panel<target>(
A.sub(k, A_mt-1, k, k), std::move(Apanel),
dwork_array, dwork_bytes, diag_len, ib,
pivots.at(k), max_panel_threads, priority_1 );
pivots.at(k), max_panel_threads, priority_1, &iinfo );
if (info == 0 && iinfo > 0) {
info = kk + iinfo;
}

// Root broadcasts the pivot to all ranks.
// todo: Panel ranks send the pivots to the right.
Expand Down Expand Up @@ -340,6 +346,7 @@ void getrf_tntpiv(
}
}
}
kk += A.tileNb( k );
}
#pragma omp taskwait

Expand All @@ -353,6 +360,9 @@ void getrf_tntpiv(
dwork_array[dev] = nullptr;
}
}

internal::reduce_info( &info, A.mpiComm() );
return info;
}

} // namespace impl
Expand Down Expand Up @@ -403,15 +413,16 @@ void getrf_tntpiv(
///
/// TODO: return value
/// @retval 0 successful exit
/// @retval >0 for return value = $i$, $U(i,i)$ is exactly zero. The
/// factorization has been completed, but the factor $U$ is exactly
/// @retval i < 0: the i-th argument had an illegal value.
/// @retval i > 0: $U(i,i)$ is exactly zero, where i is a 1-based index.
/// The factorization has been completed, but the factor $U$ is exactly
/// singular, and division by zero will occur if it is used
/// to solve a system of equations.
///
/// @ingroup gesv_computational
///
template <typename scalar_t>
void getrf_tntpiv(
int64_t getrf_tntpiv(
Matrix<scalar_t>& A, Pivots& pivots,
Options const& opts)
{
Expand All @@ -420,43 +431,43 @@ void getrf_tntpiv(
switch (target) {
case Target::Host:
case Target::HostTask:
impl::getrf_tntpiv<Target::HostTask>( A, pivots, opts );
return impl::getrf_tntpiv<Target::HostTask>( A, pivots, opts );
break;

case Target::HostNest:
impl::getrf_tntpiv<Target::HostNest>( A, pivots, opts );
return impl::getrf_tntpiv<Target::HostNest>( A, pivots, opts );
break;

case Target::HostBatch:
impl::getrf_tntpiv<Target::HostBatch>( A, pivots, opts );
return impl::getrf_tntpiv<Target::HostBatch>( A, pivots, opts );
break;

case Target::Devices:
impl::getrf_tntpiv<Target::Devices>( A, pivots, opts );
return impl::getrf_tntpiv<Target::Devices>( A, pivots, opts );
break;
}
// todo: return value for errors?
return -2; // shouldn't happen
}

//------------------------------------------------------------------------------
// Explicit instantiations.
template
void getrf_tntpiv<float>(
int64_t getrf_tntpiv<float>(
Matrix<float>& A, Pivots& pivots,
Options const& opts);

template
void getrf_tntpiv<double>(
int64_t getrf_tntpiv<double>(
Matrix<double>& A, Pivots& pivots,
Options const& opts);

template
void getrf_tntpiv< std::complex<float> >(
int64_t getrf_tntpiv< std::complex<float> >(
Matrix< std::complex<float> >& A, Pivots& pivots,
Options const& opts);

template
void getrf_tntpiv< std::complex<double> >(
int64_t getrf_tntpiv< std::complex<double> >(
Matrix< std::complex<double> >& A, Pivots& pivots,
Options const& opts);

Expand Down
12 changes: 6 additions & 6 deletions src/internal/Tile_getrf_tntpiv.hh
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void getrf_tntpiv_local(
std::vector< scalar_t >& max_value,
std::vector< int64_t >& max_index,
std::vector< int64_t >& max_offset,
std::vector< scalar_t >& top_block)
std::vector< scalar_t >& top_block,
int64_t* info )
{
trace::Block trace_block( "lapack::getrf_tntpiv" );

Expand Down Expand Up @@ -235,11 +236,10 @@ void getrf_tntpiv_local(
tile.at( i, j ) /= aux_pivot[ 0 ][ j ].value();
}
}
else {
// aux_pivot[ 0 ][ j ].value() == 0:
// The factorization has been completed
// but the factor U is exactly singular
// todo: how to handle a zero pivot
else if (*info == 0 && idx == 0) {
// U(j,j) = 0; save info on thread with diagonal tile,
// using 1-based index.
*info = j + 1;
}

// trailing update
Expand Down
2 changes: 1 addition & 1 deletion src/internal/internal.hh
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ void getrf_tntpiv_panel(
std::vector< char* > dwork_array, size_t dwork_bytes,
int64_t diag_len, int64_t ib,
std::vector<Pivot>& pivot,
int max_panel_threads, int priority=0);
int max_panel_threads, int priority, int64_t* info );

//-----------------------------------------
// geqrf()
Expand Down
Loading

0 comments on commit 1e36d9d

Please sign in to comment.