Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 72 additions & 51 deletions src/linear-algebra-solvers/concrete/tlr/HicmaImplementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,21 @@
* @file HicmaImplementation.cpp
* @brief Sets up the HiCMA descriptors needed for the tile low rank computations in ExaGeoStat.
* @version 1.1.0
* @author Omar Marzouk
* @author Mahmoud ElKarargy
* @author Sameh Abdulah
* @date 2024-02-04
**/
**/

#include <mkl_service.h>

#ifdef USE_MPI

#include <mpi.h>

#endif

#include <linear-algebra-solvers/concrete/hicma/tlr/HicmaImplementation.hpp>

using namespace std;

using namespace exageostat::linearAlgebra::tileLowRank;
using namespace exageostat::common;
using namespace exageostat::runtime;
Expand All @@ -34,11 +32,12 @@ using namespace exageostat::results;

int store_only_diagonal_tiles = 1;
int use_scratch = 1;
int global_check = 0; //used to create dense matrix for accuracy check
int global_check = 0; // used to create dense matrix for accuracy check

template<typename T>
void HicmaImplementation<T>::SetModelingDescriptors(std::unique_ptr<ExaGeoStatData<T>> &aData,
configurations::Configurations &aConfigurations, const int &aP) {
configurations::Configurations &aConfigurations,
const int &aP) {

int full_problem_size = aConfigurations.GetProblemSize() * aP;
int lts = aConfigurations.GetLowTileSize();
Expand All @@ -61,67 +60,84 @@ void HicmaImplementation<T>::SetModelingDescriptors(std::unique_ptr<ExaGeoStatDa
int NBD = lts;
int MD = full_problem_size;
int ND = MBD;

#ifdef USE_MPI
// Due to a bug in HiCMA with MPI, the first created descriptor is not seen.
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr, float_point,
lts, lts, lts * lts, full_problem_size, 1, 0, 0, full_problem_size, 1,
p_grid, q_grid);
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr,
float_point, lts, lts, lts * lts, full_problem_size, 1,
0, 0, full_problem_size, 1, p_grid, q_grid);
#endif
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr, float_point,
MBD, NBD, MBD * NBD, MD, ND, 0, 0, MD, ND, p_grid, q_grid);

aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr,
float_point, MBD, NBD, MBD * NBD, MD, ND,
0, 0, MD, ND, p_grid, q_grid);

int MBUV = lts;
int NBUV = 2 * max_rank;
int MUV;
int N_over_lts_times_lts = full_problem_size / lts * lts;

if (N_over_lts_times_lts < full_problem_size) {
MUV = N_over_lts_times_lts + lts;
} else if (N_over_lts_times_lts == full_problem_size) {
MUV = N_over_lts_times_lts;
} else {
throw runtime_error("This case can't happens, N need to be >= lts*lts");
}

int expr = MUV / lts;
int NUV = 2 * expr * max_rank;
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CUV, is_OOC, nullptr, float_point,
MBUV, NBUV, MBUV * NBUV, MUV, NUV, 0, 0, MUV, NUV, p_grid, q_grid);

aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CUV, is_OOC, nullptr,
float_point, MBUV, NBUV, MBUV * NBUV, MUV, NUV,
0, 0, MUV, NUV, p_grid, q_grid);

auto *HICMA_descCUV = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
DescriptorName::DESCRIPTOR_CUV).hicma_desc;

int MBrk = 1;
int NBrk = 1;
int Mrk = HICMA_descCUV->mt;
int Nrk = HICMA_descCUV->mt;
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CRK, is_OOC, nullptr, float_point,
MBrk, NBrk, MBrk * NBrk, Mrk, Nrk, 0, 0, Mrk, Nrk, p_grid, q_grid);

aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CRK, is_OOC, nullptr,
float_point, MBrk, NBrk, MBrk * NBrk, Mrk, Nrk,
0, 0, Mrk, Nrk, p_grid, q_grid);

aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z, is_OOC, nullptr,
float_point, lts, lts, lts * lts, full_problem_size, 1,
0, 0, full_problem_size, 1, p_grid, q_grid);

aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z, is_OOC, nullptr, float_point,
lts, lts, lts * lts, full_problem_size, 1, 0, 0, full_problem_size, 1,
p_grid, q_grid);
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z_COPY, is_OOC, nullptr,
float_point, lts, lts, lts * lts, full_problem_size, 1,
0, 0, full_problem_size, 1, p_grid, q_grid);

aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z_COPY, is_OOC, nullptr, float_point,
lts, lts, lts * lts, full_problem_size, 1, 0, 0, full_problem_size, 1,
p_grid, q_grid);
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_DETERMINANT, is_OOC, nullptr,
float_point, lts, lts, lts * lts, 1, 1, 0, 0, 1, 1, p_grid, q_grid);
float_point, lts, lts, lts * lts, 1, 1,
0, 0, 1, 1, p_grid, q_grid);

if (aConfigurations.GetIsNonGaussian()) {
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_PRODUCT, is_OOC, nullptr,
float_point,
lts, lts, lts * lts, 1, 1, 0, 0, 1, 1, p_grid, q_grid);
float_point, lts, lts, lts * lts, 1, 1,
0, 0, 1, 1, p_grid, q_grid);
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_SUM, is_OOC, nullptr,
float_point,
lts, lts, lts * lts, 1, 1, 0, 0, 1, 1, p_grid, q_grid);
float_point, lts, lts, lts * lts, 1, 1,
0, 0, 1, 1, p_grid, q_grid);
}
}

template<typename T>
T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &aData,
configurations::Configurations &aConfigurations, const double *theta,
T *apMeasurementsMatrix, const Kernel<T> &aKernel) {
configurations::Configurations &aConfigurations,
const double *theta,
T *apMeasurementsMatrix,
const Kernel<T> &aKernel) {

if (!aData->GetDescriptorData()->GetIsDescriptorInitiated()) {
this->InitiateDescriptors(aConfigurations, *aData->GetDescriptorData(), aKernel.GetVariablesNumber(),
apMeasurementsMatrix);
}

// Create a Hicma sequence, if not initialized before through the same descriptors
RUNTIME_request_t request_array[2] = {HICMA_REQUEST_INITIALIZER, HICMA_REQUEST_INITIALIZER};
if (!aData->GetDescriptorData()->GetSequence()) {
Expand All @@ -132,8 +148,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
}
auto pSequence = (HICMA_sequence_t *) aData->GetDescriptorData()->GetSequence();

//Initialization
T loglik, logdet, test_time, variance, variance1 = 1, variance2 = 1, variance3, dot_product, dot_product1, dot_product2, dot_product3, dzcpy_time, time_facto, time_solve, logdet_calculate, matrix_gen_time;
// Initialization
T loglik, logdet, test_time, variance, variance1 = 1, variance2 = 1, variance3;
T dot_product, dot_product1, dot_product2, dot_product3;
T dzcpy_time, time_facto, time_solve, logdet_calculate, matrix_gen_time;
double accumulated_executed_time, accumulated_flops;

int NRHS, i;
Expand All @@ -150,6 +168,8 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
if (iter_count == 0) {
this->SetModelingDescriptors(aData, aConfigurations, aKernel.GetVariablesNumber());
}

// Get descriptor pointers
auto *HICMA_descCUV = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
DescriptorName::DESCRIPTOR_CUV).hicma_desc;
auto *HICMA_descC = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
Expand All @@ -172,6 +192,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
DescriptorName::DESCRIPTOR_PRODUCT).chameleon_desc;
auto *HICMA_desc_sum = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
DescriptorName::DESCRIPTOR_SUM).chameleon_desc;

N = HICMA_descCUV->m;
NRHS = HICMA_descZ->n;
lts = HICMA_descZ->mb;
Expand All @@ -194,7 +215,8 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
CHAMELEON_dlacpy_Tile(ChamUpperLower, CHAM_descZ, CHAM_descZcpy);
}
}
//Matrix generation part.

// Matrix generation part.
VERBOSE("LR:Generate New Covariance Matrix...")
START_TIMING(matrix_gen_time);

Expand All @@ -203,9 +225,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
hicma_problem.noise = 1e-4;
hicma_problem.ndim = 2;

hicma_problem.kernel_type =
aConfigurations.GetDistanceMetric() == common::GREAT_CIRCLE_DISTANCE ? STARSH_SPATIAL_MATERN2_GCD
: STARSH_SPATIAL_MATERN2_SIMD;
hicma_problem.kernel_type = aConfigurations.GetDistanceMetric() == common::GREAT_CIRCLE_DISTANCE
? STARSH_SPATIAL_MATERN2_GCD
: STARSH_SPATIAL_MATERN2_SIMD;

int hicma_data_type;
if (aConfigurations.GetIsNonGaussian()) {
hicma_data_type = HICMA_STARSH_PROB_GEOSTAT_NON_GAUSSIAN;
Expand All @@ -216,31 +239,31 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
HICMA_zgenerate_problem(hicma_data_type, 'S', 0, N, lts, HICMA_descCUV->mt, HICMA_descCUV->nt,
&hicma_problem);
int compress_diag = 0;
HICMA_zgytlr_Tile(EXAGEOSTAT_LOWER, HICMA_descCUV, HICMA_descCD, HICMA_descCrk, 0, max_rank, pow(10, -1.0 * acc),
compress_diag, HICMA_descC);
HICMA_dgytlr_Tile(EXAGEOSTAT_LOWER, HICMA_descCUV, HICMA_descCD, HICMA_descCrk, 0, max_rank,
pow(10, -1.0 * acc), compress_diag, HICMA_descC);

STOP_TIMING(matrix_gen_time);
VERBOSE("Done.")
//******************************

// ******************************
VERBOSE("LR: re-Copy z...")
START_TIMING(test_time);
//re-store old Z
// Restore the original Z from its saved copy
this->ExaGeoStatLapackCopyTile(EXAGEOSTAT_UPPER_LOWER, HICMA_descZcpy, HICMA_descZ);
STOP_TIMING(test_time);
VERBOSE("Done.")

//Calculate Cholesky Factorization (C=LL-1)
// Calculate Cholesky Factorization (C = L * L^T)
VERBOSE("LR: Cholesky factorization of Sigma...")
START_TIMING(time_facto);

this->ExaGeoStatPotrfTile(EXAGEOSTAT_LOWER, HICMA_descCUV, 0, HICMA_descCD, HICMA_descCrk, max_rank,
pow(10, -1.0 * acc));
this->ExaGeoStatPotrfTile(EXAGEOSTAT_LOWER, HICMA_descCUV, 0, HICMA_descCD, HICMA_descCrk, max_rank, acc);

STOP_TIMING(time_facto);
flops = flops + flops_dpotrf(N);
VERBOSE("Done.")

//Calculate log(|C|) --> log(square(|L|))
// Calculate log(|C|) --> log(square(|L|))
VERBOSE("LR:Calculating the log determinant ...")
START_TIMING(logdet_calculate);
RuntimeFunctions<T>::ExaGeoStatMeasureDetTileAsync(aConfigurations.GetComputation(), HICMA_descCD, pSequence,
Expand Down Expand Up @@ -276,10 +299,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
VERBOSE(" Done.")
}

//Solving Linear System (L*X=Z)--->inv(L)*Z
// Solving Linear System (L*X=Z)--->inv(L)*Z
VERBOSE("LR:Solving the linear system ...")
START_TIMING(time_solve);
//Compute triangular solve LC*X = Z
// Compute triangular solve LC*X = Z
this->ExaGeoStatTrsmTile(EXAGEOSTAT_LEFT, EXAGEOSTAT_LOWER, EXAGEOSTAT_NO_TRANS, EXAGEOSTAT_NON_UNIT, 1,
HICMA_descCUV, HICMA_descCD, HICMA_descCrk, HICMA_descZ, max_rank);
STOP_TIMING(time_solve);
Expand All @@ -293,6 +316,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
CHAMELEON_dgemm_Tile(ChamTrans, ChamNoTrans, 1, CHAM_descZ, CHAM_descZ, 0, CHAM_desc_product);
dot_product = *product;
loglik = -0.5 * dot_product - 0.5 * logdet;

if (aConfigurations.GetIsNonGaussian()) {
loglik = loglik - *sum - N * log(theta[3]) - (double) (N / 2.0) * log(2.0 * PI);
} else {
Expand All @@ -304,6 +328,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
if (aConfigurations.GetLogger()) {
fprintf(aConfigurations.GetFileLogPath(), "\t %d- Model Parameters (", iter_count + 1);
}

if ((aConfigurations.GetKernelName() == "bivariate_matern_parsimonious_profile") ||
(aConfigurations.GetKernelName() == "bivariate_matern_parsimonious2_profile")) {
LOGGER(setprecision(8) << variance1 << setprecision(8) << variance2)
Expand Down Expand Up @@ -340,12 +365,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
aData->SetMleIterations(aData->GetMleIterations() + 1);

// for experiments and benchmarking
accumulated_executed_time =
Results::GetInstance()->GetTotalModelingExecutionTime() + time_facto + logdet_calculate +
time_solve;
accumulated_executed_time = Results::GetInstance()->GetTotalModelingExecutionTime() + time_facto +
logdet_calculate + time_solve;
Results::GetInstance()->SetTotalModelingExecutionTime(accumulated_executed_time);
accumulated_flops =
Results::GetInstance()->GetTotalModelingFlops() + (flops / 1e9 / (time_facto + time_solve));
accumulated_flops = Results::GetInstance()->GetTotalModelingFlops() + (flops / 1e9 / (time_facto + time_solve));
Results::GetInstance()->SetTotalModelingFlops(accumulated_flops);

Results::GetInstance()->SetMLEIterations(iter_count + 1);
Expand All @@ -359,7 +382,6 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
return loglik;
}


template<typename T>
void HicmaImplementation<T>::ExaGeoStatLapackCopyTile(const UpperLower &aUpperLower, void *apA, void *apB) {
int status = HICMA_dlacpy_Tile(aUpperLower, (HICMA_desc_t *) apA, (HICMA_desc_t *) apB);
Expand Down Expand Up @@ -389,7 +411,6 @@ void HicmaImplementation<T>::ExaGeoStatPotrfTile(const common::UpperLower &aUppe
if (status != HICMA_SUCCESS) {
throw std::runtime_error("HICMA_dpotrf_Tile Failed, Matrix is not positive definite");
}

}

template<typename T>
Expand Down