Skip to content

Commit cb90515

Browse files
Fix tlr bug fix om (#30)
* TLR issue resolved -- zgytlr changed to dgytlr and accuracy calculation fixed
* alignment and reformatting

---------

Co-authored-by: mahmoud <mahmoud.elkarargy@brightskiesinc.com>
1 parent a25a9d1 commit cb90515

File tree

1 file changed

+72
-51
lines changed

1 file changed

+72
-51
lines changed

src/linear-algebra-solvers/concrete/tlr/HicmaImplementation.cpp

Lines changed: 72 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,21 @@
77
* @file HicmaImplementation.cpp
88
* @brief Sets up the HiCMA descriptors needed for the tile low rank computations in ExaGeoStat.
99
* @version 1.1.0
10+
* @author Omar Marzouk
1011
* @author Mahmoud ElKarargy
1112
* @author Sameh Abdulah
1213
* @date 2024-02-04
13-
**/
14+
**/
1415

1516
#include <mkl_service.h>
1617

1718
#ifdef USE_MPI
18-
1919
#include <mpi.h>
20-
2120
#endif
2221

2322
#include <linear-algebra-solvers/concrete/hicma/tlr/HicmaImplementation.hpp>
2423

2524
using namespace std;
26-
2725
using namespace exageostat::linearAlgebra::tileLowRank;
2826
using namespace exageostat::common;
2927
using namespace exageostat::runtime;
@@ -34,11 +32,12 @@ using namespace exageostat::results;
3432

3533
int store_only_diagonal_tiles = 1;
3634
int use_scratch = 1;
37-
int global_check = 0; //used to create dense matrix for accuracy check
35+
int global_check = 0; // used to create dense matrix for accuracy check
3836

3937
template<typename T>
4038
void HicmaImplementation<T>::SetModelingDescriptors(std::unique_ptr<ExaGeoStatData<T>> &aData,
41-
configurations::Configurations &aConfigurations, const int &aP) {
39+
configurations::Configurations &aConfigurations,
40+
const int &aP) {
4241

4342
int full_problem_size = aConfigurations.GetProblemSize() * aP;
4443
int lts = aConfigurations.GetLowTileSize();
@@ -61,67 +60,84 @@ void HicmaImplementation<T>::SetModelingDescriptors(std::unique_ptr<ExaGeoStatDa
6160
int NBD = lts;
6261
int MD = full_problem_size;
6362
int ND = MBD;
63+
6464
#ifdef USE_MPI
6565
// Due to a bug in HiCMA with MPI, the first created descriptor is not seen.
66-
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr, float_point,
67-
lts, lts, lts * lts, full_problem_size, 1, 0, 0, full_problem_size, 1,
68-
p_grid, q_grid);
66+
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr,
67+
float_point, lts, lts, lts * lts, full_problem_size, 1,
68+
0, 0, full_problem_size, 1, p_grid, q_grid);
6969
#endif
70-
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr, float_point,
71-
MBD, NBD, MBD * NBD, MD, ND, 0, 0, MD, ND, p_grid, q_grid);
70+
71+
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr,
72+
float_point, MBD, NBD, MBD * NBD, MD, ND,
73+
0, 0, MD, ND, p_grid, q_grid);
74+
7275
int MBUV = lts;
7376
int NBUV = 2 * max_rank;
7477
int MUV;
7578
int N_over_lts_times_lts = full_problem_size / lts * lts;
79+
7680
if (N_over_lts_times_lts < full_problem_size) {
7781
MUV = N_over_lts_times_lts + lts;
7882
} else if (N_over_lts_times_lts == full_problem_size) {
7983
MUV = N_over_lts_times_lts;
8084
} else {
8185
throw runtime_error("This case can't happens, N need to be >= lts*lts");
8286
}
87+
8388
int expr = MUV / lts;
8489
int NUV = 2 * expr * max_rank;
85-
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CUV, is_OOC, nullptr, float_point,
86-
MBUV, NBUV, MBUV * NBUV, MUV, NUV, 0, 0, MUV, NUV, p_grid, q_grid);
90+
91+
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CUV, is_OOC, nullptr,
92+
float_point, MBUV, NBUV, MBUV * NBUV, MUV, NUV,
93+
0, 0, MUV, NUV, p_grid, q_grid);
94+
8795
auto *HICMA_descCUV = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
8896
DescriptorName::DESCRIPTOR_CUV).hicma_desc;
97+
8998
int MBrk = 1;
9099
int NBrk = 1;
91100
int Mrk = HICMA_descCUV->mt;
92101
int Nrk = HICMA_descCUV->mt;
93-
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CRK, is_OOC, nullptr, float_point,
94-
MBrk, NBrk, MBrk * NBrk, Mrk, Nrk, 0, 0, Mrk, Nrk, p_grid, q_grid);
102+
103+
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_CRK, is_OOC, nullptr,
104+
float_point, MBrk, NBrk, MBrk * NBrk, Mrk, Nrk,
105+
0, 0, Mrk, Nrk, p_grid, q_grid);
106+
107+
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z, is_OOC, nullptr,
108+
float_point, lts, lts, lts * lts, full_problem_size, 1,
109+
0, 0, full_problem_size, 1, p_grid, q_grid);
95110

96-
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z, is_OOC, nullptr, float_point,
97-
lts, lts, lts * lts, full_problem_size, 1, 0, 0, full_problem_size, 1,
98-
p_grid, q_grid);
111+
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z_COPY, is_OOC, nullptr,
112+
float_point, lts, lts, lts * lts, full_problem_size, 1,
113+
0, 0, full_problem_size, 1, p_grid, q_grid);
99114

100-
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_Z_COPY, is_OOC, nullptr, float_point,
101-
lts, lts, lts * lts, full_problem_size, 1, 0, 0, full_problem_size, 1,
102-
p_grid, q_grid);
103115
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_DETERMINANT, is_OOC, nullptr,
104-
float_point, lts, lts, lts * lts, 1, 1, 0, 0, 1, 1, p_grid, q_grid);
116+
float_point, lts, lts, lts * lts, 1, 1,
117+
0, 0, 1, 1, p_grid, q_grid);
105118

106119
if (aConfigurations.GetIsNonGaussian()) {
107120
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_PRODUCT, is_OOC, nullptr,
108-
float_point,
109-
lts, lts, lts * lts, 1, 1, 0, 0, 1, 1, p_grid, q_grid);
121+
float_point, lts, lts, lts * lts, 1, 1,
122+
0, 0, 1, 1, p_grid, q_grid);
110123
aData->GetDescriptorData()->SetDescriptor(common::HICMA_DESCRIPTOR, DESCRIPTOR_SUM, is_OOC, nullptr,
111-
float_point,
112-
lts, lts, lts * lts, 1, 1, 0, 0, 1, 1, p_grid, q_grid);
124+
float_point, lts, lts, lts * lts, 1, 1,
125+
0, 0, 1, 1, p_grid, q_grid);
113126
}
114127
}
115128

116129
template<typename T>
117130
T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &aData,
118-
configurations::Configurations &aConfigurations, const double *theta,
119-
T *apMeasurementsMatrix, const Kernel<T> &aKernel) {
131+
configurations::Configurations &aConfigurations,
132+
const double *theta,
133+
T *apMeasurementsMatrix,
134+
const Kernel<T> &aKernel) {
120135

121136
if (!aData->GetDescriptorData()->GetIsDescriptorInitiated()) {
122137
this->InitiateDescriptors(aConfigurations, *aData->GetDescriptorData(), aKernel.GetVariablesNumber(),
123138
apMeasurementsMatrix);
124139
}
140+
125141
// Create a Hicma sequence, if not initialized before through the same descriptors
126142
RUNTIME_request_t request_array[2] = {HICMA_REQUEST_INITIALIZER, HICMA_REQUEST_INITIALIZER};
127143
if (!aData->GetDescriptorData()->GetSequence()) {
@@ -132,8 +148,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
132148
}
133149
auto pSequence = (HICMA_sequence_t *) aData->GetDescriptorData()->GetSequence();
134150

135-
//Initialization
136-
T loglik, logdet, test_time, variance, variance1 = 1, variance2 = 1, variance3, dot_product, dot_product1, dot_product2, dot_product3, dzcpy_time, time_facto, time_solve, logdet_calculate, matrix_gen_time;
151+
// Initialization
152+
T loglik, logdet, test_time, variance, variance1 = 1, variance2 = 1, variance3;
153+
T dot_product, dot_product1, dot_product2, dot_product3;
154+
T dzcpy_time, time_facto, time_solve, logdet_calculate, matrix_gen_time;
137155
double accumulated_executed_time, accumulated_flops;
138156

139157
int NRHS, i;
@@ -150,6 +168,8 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
150168
if (iter_count == 0) {
151169
this->SetModelingDescriptors(aData, aConfigurations, aKernel.GetVariablesNumber());
152170
}
171+
172+
// Get descriptor pointers
153173
auto *HICMA_descCUV = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
154174
DescriptorName::DESCRIPTOR_CUV).hicma_desc;
155175
auto *HICMA_descC = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
@@ -172,6 +192,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
172192
DescriptorName::DESCRIPTOR_PRODUCT).chameleon_desc;
173193
auto *HICMA_desc_sum = aData->GetDescriptorData()->GetDescriptor(DescriptorType::HICMA_DESCRIPTOR,
174194
DescriptorName::DESCRIPTOR_SUM).chameleon_desc;
195+
175196
N = HICMA_descCUV->m;
176197
NRHS = HICMA_descZ->n;
177198
lts = HICMA_descZ->mb;
@@ -194,7 +215,8 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
194215
CHAMELEON_dlacpy_Tile(ChamUpperLower, CHAM_descZ, CHAM_descZcpy);
195216
}
196217
}
197-
//Matrix generation part.
218+
219+
// Matrix generation part.
198220
VERBOSE("LR:Generate New Covariance Matrix...")
199221
START_TIMING(matrix_gen_time);
200222

@@ -203,9 +225,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
203225
hicma_problem.noise = 1e-4;
204226
hicma_problem.ndim = 2;
205227

206-
hicma_problem.kernel_type =
207-
aConfigurations.GetDistanceMetric() == common::GREAT_CIRCLE_DISTANCE ? STARSH_SPATIAL_MATERN2_GCD
208-
: STARSH_SPATIAL_MATERN2_SIMD;
228+
hicma_problem.kernel_type = aConfigurations.GetDistanceMetric() == common::GREAT_CIRCLE_DISTANCE
229+
? STARSH_SPATIAL_MATERN2_GCD
230+
: STARSH_SPATIAL_MATERN2_SIMD;
231+
209232
int hicma_data_type;
210233
if (aConfigurations.GetIsNonGaussian()) {
211234
hicma_data_type = HICMA_STARSH_PROB_GEOSTAT_NON_GAUSSIAN;
@@ -216,31 +239,31 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
216239
HICMA_zgenerate_problem(hicma_data_type, 'S', 0, N, lts, HICMA_descCUV->mt, HICMA_descCUV->nt,
217240
&hicma_problem);
218241
int compress_diag = 0;
219-
HICMA_zgytlr_Tile(EXAGEOSTAT_LOWER, HICMA_descCUV, HICMA_descCD, HICMA_descCrk, 0, max_rank, pow(10, -1.0 * acc),
220-
compress_diag, HICMA_descC);
242+
HICMA_dgytlr_Tile(EXAGEOSTAT_LOWER, HICMA_descCUV, HICMA_descCD, HICMA_descCrk, 0, max_rank,
243+
pow(10, -1.0 * acc), compress_diag, HICMA_descC);
221244

222245
STOP_TIMING(matrix_gen_time);
223246
VERBOSE("Done.")
224-
//******************************
247+
248+
// ******************************
225249
VERBOSE("LR: re-Copy z...")
226250
START_TIMING(test_time);
227-
//re-store old Z
251+
// re-store old Z
228252
this->ExaGeoStatLapackCopyTile(EXAGEOSTAT_UPPER_LOWER, HICMA_descZcpy, HICMA_descZ);
229253
STOP_TIMING(test_time);
230254
VERBOSE("Done.")
231255

232-
//Calculate Cholesky Factorization (C=LL-1)
256+
// Calculate Cholesky Factorization (C=LL-1)
233257
VERBOSE("LR: Cholesky factorization of Sigma...")
234258
START_TIMING(time_facto);
235259

236-
this->ExaGeoStatPotrfTile(EXAGEOSTAT_LOWER, HICMA_descCUV, 0, HICMA_descCD, HICMA_descCrk, max_rank,
237-
pow(10, -1.0 * acc));
260+
this->ExaGeoStatPotrfTile(EXAGEOSTAT_LOWER, HICMA_descCUV, 0, HICMA_descCD, HICMA_descCrk, max_rank, acc);
238261

239262
STOP_TIMING(time_facto);
240263
flops = flops + flops_dpotrf(N);
241264
VERBOSE("Done.")
242265

243-
//Calculate log(|C|) --> log(square(|L|))
266+
// Calculate log(|C|) --> log(square(|L|))
244267
VERBOSE("LR:Calculating the log determinant ...")
245268
START_TIMING(logdet_calculate);
246269
RuntimeFunctions<T>::ExaGeoStatMeasureDetTileAsync(aConfigurations.GetComputation(), HICMA_descCD, pSequence,
@@ -276,10 +299,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
276299
VERBOSE(" Done.")
277300
}
278301

279-
//Solving Linear System (L*X=Z)--->inv(L)*Z
302+
// Solving Linear System (L*X=Z)--->inv(L)*Z
280303
VERBOSE("LR:Solving the linear system ...")
281304
START_TIMING(time_solve);
282-
//Compute triangular solve LC*X = Z
305+
// Compute triangular solve LC*X = Z
283306
this->ExaGeoStatTrsmTile(EXAGEOSTAT_LEFT, EXAGEOSTAT_LOWER, EXAGEOSTAT_NO_TRANS, EXAGEOSTAT_NON_UNIT, 1,
284307
HICMA_descCUV, HICMA_descCD, HICMA_descCrk, HICMA_descZ, max_rank);
285308
STOP_TIMING(time_solve);
@@ -293,6 +316,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
293316
CHAMELEON_dgemm_Tile(ChamTrans, ChamNoTrans, 1, CHAM_descZ, CHAM_descZ, 0, CHAM_desc_product);
294317
dot_product = *product;
295318
loglik = -0.5 * dot_product - 0.5 * logdet;
319+
296320
if (aConfigurations.GetIsNonGaussian()) {
297321
loglik = loglik - *sum - N * log(theta[3]) - (double) (N / 2.0) * log(2.0 * PI);
298322
} else {
@@ -304,6 +328,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
304328
if (aConfigurations.GetLogger()) {
305329
fprintf(aConfigurations.GetFileLogPath(), "\t %d- Model Parameters (", iter_count + 1);
306330
}
331+
307332
if ((aConfigurations.GetKernelName() == "bivariate_matern_parsimonious_profile") ||
308333
(aConfigurations.GetKernelName() == "bivariate_matern_parsimonious2_profile")) {
309334
LOGGER(setprecision(8) << variance1 << setprecision(8) << variance2)
@@ -340,12 +365,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
340365
aData->SetMleIterations(aData->GetMleIterations() + 1);
341366

342367
// for experiments and benchmarking
343-
accumulated_executed_time =
344-
Results::GetInstance()->GetTotalModelingExecutionTime() + time_facto + logdet_calculate +
345-
time_solve;
368+
accumulated_executed_time = Results::GetInstance()->GetTotalModelingExecutionTime() + time_facto +
369+
logdet_calculate + time_solve;
346370
Results::GetInstance()->SetTotalModelingExecutionTime(accumulated_executed_time);
347-
accumulated_flops =
348-
Results::GetInstance()->GetTotalModelingFlops() + (flops / 1e9 / (time_facto + time_solve));
371+
accumulated_flops = Results::GetInstance()->GetTotalModelingFlops() + (flops / 1e9 / (time_facto + time_solve));
349372
Results::GetInstance()->SetTotalModelingFlops(accumulated_flops);
350373

351374
Results::GetInstance()->SetMLEIterations(iter_count + 1);
@@ -359,7 +382,6 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
359382
return loglik;
360383
}
361384

362-
363385
template<typename T>
364386
void HicmaImplementation<T>::ExaGeoStatLapackCopyTile(const UpperLower &aUpperLower, void *apA, void *apB) {
365387
int status = HICMA_dlacpy_Tile(aUpperLower, (HICMA_desc_t *) apA, (HICMA_desc_t *) apB);
@@ -389,7 +411,6 @@ void HicmaImplementation<T>::ExaGeoStatPotrfTile(const common::UpperLower &aUppe
389411
if (status != HICMA_SUCCESS) {
390412
throw std::runtime_error("HICMA_dpotrf_Tile Failed, Matrix is not positive definite");
391413
}
392-
393414
}
394415

395416
template<typename T>

0 commit comments

Comments
 (0)