77 * @file HicmaImplementation.cpp 
88 * @brief Sets up the HiCMA descriptors needed for the tile low rank computations in ExaGeoStat. 
99 * @version 1.1.0 
10+  * @author Omar Marzouk 
1011 * @author Mahmoud ElKarargy 
1112 * @author Sameh Abdulah 
1213 * @date 2024-02-04 
13- **/  
14+   **/ 
1415
1516#include  < mkl_service.h> 
1617
1718#ifdef  USE_MPI
18- 
1919#include  < mpi.h> 
20- 
2120#endif 
2221
2322#include  < linear-algebra-solvers/concrete/hicma/tlr/HicmaImplementation.hpp> 
2423
2524using  namespace  std ; 
26- 
2725using  namespace  exageostat ::linearAlgebra::tileLowRank; 
2826using  namespace  exageostat ::common; 
2927using  namespace  exageostat ::runtime; 
@@ -34,11 +32,12 @@ using namespace exageostat::results;
3432
3533int  store_only_diagonal_tiles = 1 ;
3634int  use_scratch = 1 ;
37- int  global_check = 0 ;  // used to create dense matrix for accuracy check
35+ int  global_check = 0 ;  //   used to create dense matrix for accuracy check
3836
3937template <typename  T>
4038void  HicmaImplementation<T>::SetModelingDescriptors(std::unique_ptr<ExaGeoStatData<T>> &aData,
41-                                                     configurations::Configurations &aConfigurations, const  int  &aP) {
39+                                                     configurations::Configurations &aConfigurations, 
40+                                                     const  int  &aP) {
4241
4342    int  full_problem_size = aConfigurations.GetProblemSize () * aP;
4443    int  lts = aConfigurations.GetLowTileSize ();
@@ -61,67 +60,84 @@ void HicmaImplementation<T>::SetModelingDescriptors(std::unique_ptr<ExaGeoStatDa
6160    int  NBD = lts;
6261    int  MD = full_problem_size;
6362    int  ND = MBD;
63+ 
6464#ifdef  USE_MPI
6565    //  Due to a bug in HiCMA with MPI, the first created descriptor is not seen.
66-     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr , float_point, 
67-                                               lts, lts, lts * lts, full_problem_size, 1 , 0 ,  0 , full_problem_size,  1 , 
68-                                               p_grid, q_grid);
66+     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr , 
67+                                               float_point,  lts, lts, lts * lts, full_problem_size, 1 , 
68+                                               0 ,  0 , full_problem_size,  1 ,  p_grid, q_grid);
6969#endif 
70-     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr , float_point,
71-                                               MBD, NBD, MBD * NBD, MD, ND, 0 , 0 , MD, ND, p_grid, q_grid);
70+ 
71+     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CD, is_OOC, nullptr , 
72+                                               float_point, MBD, NBD, MBD * NBD, MD, ND, 
73+                                               0 , 0 , MD, ND, p_grid, q_grid);
74+ 
7275    int  MBUV = lts;
7376    int  NBUV = 2  * max_rank;
7477    int  MUV;
7578    int  N_over_lts_times_lts = full_problem_size / lts * lts;
79+     
7680    if  (N_over_lts_times_lts < full_problem_size) {
7781        MUV = N_over_lts_times_lts + lts;
7882    } else  if  (N_over_lts_times_lts == full_problem_size) {
7983        MUV = N_over_lts_times_lts;
8084    } else  {
8185        throw  runtime_error (" This case can't happens, N need to be >= lts*lts" 
8286    }
87+     
8388    int  expr = MUV / lts;
8489    int  NUV = 2  * expr * max_rank;
85-     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CUV, is_OOC, nullptr , float_point,
86-                                               MBUV, NBUV, MBUV * NBUV, MUV, NUV, 0 , 0 , MUV, NUV, p_grid, q_grid);
90+     
91+     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CUV, is_OOC, nullptr , 
92+                                               float_point, MBUV, NBUV, MBUV * NBUV, MUV, NUV, 
93+                                               0 , 0 , MUV, NUV, p_grid, q_grid);
94+ 
8795    auto  *HICMA_descCUV = aData->GetDescriptorData ()->GetDescriptor (DescriptorType::HICMA_DESCRIPTOR,
8896                                                                    DescriptorName::DESCRIPTOR_CUV).hicma_desc ;
97+ 
8998    int  MBrk = 1 ;
9099    int  NBrk = 1 ;
91100    int  Mrk = HICMA_descCUV->mt ;
92101    int  Nrk = HICMA_descCUV->mt ;
93-     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CRK, is_OOC, nullptr , float_point,
94-                                               MBrk, NBrk, MBrk * NBrk, Mrk, Nrk, 0 , 0 , Mrk, Nrk, p_grid, q_grid);
102+     
103+     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_CRK, is_OOC, nullptr , 
104+                                               float_point, MBrk, NBrk, MBrk * NBrk, Mrk, Nrk, 
105+                                               0 , 0 , Mrk, Nrk, p_grid, q_grid);
106+ 
107+     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_Z, is_OOC, nullptr , 
108+                                               float_point, lts, lts, lts * lts, full_problem_size, 1 , 
109+                                               0 , 0 , full_problem_size, 1 , p_grid, q_grid);
95110
96-     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_Z , is_OOC, nullptr , float_point, 
97-                                               lts, lts, lts * lts, full_problem_size, 1 , 0 ,  0 , full_problem_size,  1 , 
98-                                               p_grid, q_grid);
111+     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_Z_COPY , is_OOC, nullptr , 
112+                                               float_point,  lts, lts, lts * lts, full_problem_size, 1 , 
113+                                               0 ,  0 , full_problem_size,  1 ,  p_grid, q_grid);
99114
100-     aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_Z_COPY, is_OOC, nullptr , float_point,
101-                                               lts, lts, lts * lts, full_problem_size, 1 , 0 , 0 , full_problem_size, 1 ,
102-                                               p_grid, q_grid);
103115    aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_DETERMINANT, is_OOC, nullptr ,
104-                                               float_point, lts, lts, lts * lts, 1 , 1 , 0 , 0 , 1 , 1 , p_grid, q_grid);
116+                                               float_point, lts, lts, lts * lts, 1 , 1 , 
117+                                               0 , 0 , 1 , 1 , p_grid, q_grid);
105118
106119    if  (aConfigurations.GetIsNonGaussian ()) {
107120        aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_PRODUCT, is_OOC, nullptr ,
108-                                                   float_point,
109-                                                   lts, lts, lts * lts,  1 ,  1 ,  0 , 0 , 1 , 1 , p_grid, q_grid);
121+                                                   float_point, lts, lts, lts * lts,  1 ,  1 ,  
122+                                                   0 , 0 , 1 , 1 , p_grid, q_grid);
110123        aData->GetDescriptorData ()->SetDescriptor (common::HICMA_DESCRIPTOR, DESCRIPTOR_SUM, is_OOC, nullptr ,
111-                                                   float_point,
112-                                                   lts, lts, lts * lts,  1 ,  1 ,  0 , 0 , 1 , 1 , p_grid, q_grid);
124+                                                   float_point, lts, lts, lts * lts,  1 ,  1 ,  
125+                                                   0 , 0 , 1 , 1 , p_grid, q_grid);
113126    }
114127}
115128
116129template <typename  T>
117130T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &aData,
118-                                             configurations::Configurations &aConfigurations, const  double  *theta,
119-                                             T *apMeasurementsMatrix, const  Kernel<T> &aKernel) {
131+                                             configurations::Configurations &aConfigurations, 
132+                                             const  double  *theta,
133+                                             T *apMeasurementsMatrix, 
134+                                             const  Kernel<T> &aKernel) {
120135
121136    if  (!aData->GetDescriptorData ()->GetIsDescriptorInitiated ()) {
122137        this ->InitiateDescriptors (aConfigurations, *aData->GetDescriptorData (), aKernel.GetVariablesNumber (),
123138                                  apMeasurementsMatrix);
124139    }
140+ 
125141    //  Create a Hicma sequence, if not initialized before through the same descriptors
126142    RUNTIME_request_t request_array[2 ] = {HICMA_REQUEST_INITIALIZER, HICMA_REQUEST_INITIALIZER};
127143    if  (!aData->GetDescriptorData ()->GetSequence ()) {
@@ -132,8 +148,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
132148    }
133149    auto  pSequence = (HICMA_sequence_t *) aData->GetDescriptorData ()->GetSequence ();
134150
135-     // Initialization
136-     T loglik, logdet, test_time, variance, variance1 = 1 , variance2 = 1 , variance3, dot_product, dot_product1, dot_product2, dot_product3, dzcpy_time, time_facto, time_solve, logdet_calculate, matrix_gen_time;
151+     //  Initialization
152+     T loglik, logdet, test_time, variance, variance1 = 1 , variance2 = 1 , variance3;
153+     T dot_product, dot_product1, dot_product2, dot_product3;
154+     T dzcpy_time, time_facto, time_solve, logdet_calculate, matrix_gen_time;
137155    double  accumulated_executed_time, accumulated_flops;
138156
139157    int  NRHS, i;
@@ -150,6 +168,8 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
150168    if  (iter_count == 0 ) {
151169        this ->SetModelingDescriptors (aData, aConfigurations, aKernel.GetVariablesNumber ());
152170    }
171+ 
172+     //  Get descriptor pointers
153173    auto  *HICMA_descCUV = aData->GetDescriptorData ()->GetDescriptor (DescriptorType::HICMA_DESCRIPTOR,
154174                                                                    DescriptorName::DESCRIPTOR_CUV).hicma_desc ;
155175    auto  *HICMA_descC = aData->GetDescriptorData ()->GetDescriptor (DescriptorType::HICMA_DESCRIPTOR,
@@ -172,6 +192,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
172192                                                                        DescriptorName::DESCRIPTOR_PRODUCT).chameleon_desc ;
173193    auto  *HICMA_desc_sum = aData->GetDescriptorData ()->GetDescriptor (DescriptorType::HICMA_DESCRIPTOR,
174194                                                                     DescriptorName::DESCRIPTOR_SUM).chameleon_desc ;
195+ 
175196    N = HICMA_descCUV->m ;
176197    NRHS = HICMA_descZ->n ;
177198    lts = HICMA_descZ->mb ;
@@ -194,7 +215,8 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
194215            CHAMELEON_dlacpy_Tile (ChamUpperLower, CHAM_descZ, CHAM_descZcpy);
195216        }
196217    }
197-     // Matrix generation part.
218+ 
219+     //  Matrix generation part.
198220    VERBOSE (" LR:Generate New Covariance Matrix..." 
199221    START_TIMING (matrix_gen_time);
200222
@@ -203,9 +225,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
203225    hicma_problem.noise  = 1e-4 ;
204226    hicma_problem.ndim  = 2 ;
205227
206-     hicma_problem.kernel_type  =
207-             aConfigurations.GetDistanceMetric () == common::GREAT_CIRCLE_DISTANCE ? STARSH_SPATIAL_MATERN2_GCD
208-                                                                                  : STARSH_SPATIAL_MATERN2_SIMD;
228+     hicma_problem.kernel_type  = aConfigurations.GetDistanceMetric () == common::GREAT_CIRCLE_DISTANCE 
229+                                 ? STARSH_SPATIAL_MATERN2_GCD
230+                                 : STARSH_SPATIAL_MATERN2_SIMD;
231+ 
209232    int  hicma_data_type;
210233    if  (aConfigurations.GetIsNonGaussian ()) {
211234        hicma_data_type = HICMA_STARSH_PROB_GEOSTAT_NON_GAUSSIAN;
@@ -216,31 +239,31 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
216239    HICMA_zgenerate_problem (hicma_data_type, ' S' 0 , N, lts, HICMA_descCUV->mt , HICMA_descCUV->nt ,
217240                            &hicma_problem);
218241    int  compress_diag = 0 ;
219-     HICMA_zgytlr_Tile (EXAGEOSTAT_LOWER, HICMA_descCUV, HICMA_descCD, HICMA_descCrk, 0 , max_rank, pow ( 10 , - 1.0  * acc), 
220-                       compress_diag, HICMA_descC);
242+     HICMA_dgytlr_Tile (EXAGEOSTAT_LOWER, HICMA_descCUV, HICMA_descCD, HICMA_descCrk, 0 , max_rank, 
243+                       pow ( 10 , - 1.0  * acc),  compress_diag, HICMA_descC);
221244
222245    STOP_TIMING (matrix_gen_time);
223246    VERBOSE (" Done." 
224-     // ******************************
247+ 
248+     //  ******************************
225249    VERBOSE (" LR: re-Copy z..." 
226250    START_TIMING (test_time);
227-     // re-store old Z
251+     //   re-store old Z
228252    this ->ExaGeoStatLapackCopyTile (EXAGEOSTAT_UPPER_LOWER, HICMA_descZcpy, HICMA_descZ);
229253    STOP_TIMING (test_time);
230254    VERBOSE (" Done." 
231255
232-     // Calculate Cholesky Factorization (C=LL-1)
256+     //   Calculate Cholesky Factorization (C=LL-1)
233257    VERBOSE (" LR: Cholesky factorization of Sigma..." 
234258    START_TIMING (time_facto);
235259
236-     this ->ExaGeoStatPotrfTile (EXAGEOSTAT_LOWER, HICMA_descCUV, 0 , HICMA_descCD, HICMA_descCrk, max_rank,
237-                               pow (10 , -1.0  * acc));
260+     this ->ExaGeoStatPotrfTile (EXAGEOSTAT_LOWER, HICMA_descCUV, 0 , HICMA_descCD, HICMA_descCrk, max_rank, acc);
238261
239262    STOP_TIMING (time_facto);
240263    flops = flops + flops_dpotrf (N);
241264    VERBOSE (" Done." 
242265
243-     // Calculate log(|C|) --> log(square(|L|))
266+     //   Calculate log(|C|) --> log(square(|L|))
244267    VERBOSE (" LR:Calculating the log determinant ..." 
245268    START_TIMING (logdet_calculate);
246269    RuntimeFunctions<T>::ExaGeoStatMeasureDetTileAsync (aConfigurations.GetComputation (), HICMA_descCD, pSequence,
@@ -276,10 +299,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
276299        VERBOSE ("  Done." 
277300    }
278301
279-     // Solving Linear System (L*X=Z)--->inv(L)*Z
302+     //   Solving Linear System (L*X=Z)--->inv(L)*Z
280303    VERBOSE (" LR:Solving the linear system ..." 
281304    START_TIMING (time_solve);
282-     // Compute triangular solve LC*X = Z
305+     //   Compute triangular solve LC*X = Z
283306    this ->ExaGeoStatTrsmTile (EXAGEOSTAT_LEFT, EXAGEOSTAT_LOWER, EXAGEOSTAT_NO_TRANS, EXAGEOSTAT_NON_UNIT, 1 ,
284307                             HICMA_descCUV, HICMA_descCD, HICMA_descCrk, HICMA_descZ, max_rank);
285308    STOP_TIMING (time_solve);
@@ -293,6 +316,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
293316    CHAMELEON_dgemm_Tile (ChamTrans, ChamNoTrans, 1 , CHAM_descZ, CHAM_descZ, 0 , CHAM_desc_product);
294317    dot_product = *product;
295318    loglik = -0.5  * dot_product - 0.5  * logdet;
319+     
296320    if  (aConfigurations.GetIsNonGaussian ()) {
297321        loglik = loglik - *sum - N * log (theta[3 ]) - (double ) (N / 2.0 ) * log (2.0  * PI);
298322    } else  {
@@ -304,6 +328,7 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
304328    if  (aConfigurations.GetLogger ()) {
305329        fprintf (aConfigurations.GetFileLogPath (), " \t  %d- Model Parameters (" 1 );
306330    }
331+ 
307332    if  ((aConfigurations.GetKernelName () == " bivariate_matern_parsimonious_profile" 
308333        (aConfigurations.GetKernelName () == " bivariate_matern_parsimonious2_profile" 
309334        LOGGER (setprecision (8 ) << variance1 << setprecision (8 ) << variance2)
@@ -340,12 +365,10 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
340365    aData->SetMleIterations (aData->GetMleIterations () + 1 );
341366
342367    //  for experiments and benchmarking
343-     accumulated_executed_time =
344-             Results::GetInstance ()->GetTotalModelingExecutionTime () + time_facto + logdet_calculate +
345-             time_solve;
368+     accumulated_executed_time = Results::GetInstance ()->GetTotalModelingExecutionTime () + time_facto + 
369+                                 logdet_calculate + time_solve;
346370    Results::GetInstance ()->SetTotalModelingExecutionTime (accumulated_executed_time);
347-     accumulated_flops =
348-             Results::GetInstance ()->GetTotalModelingFlops () + (flops / 1e9  / (time_facto + time_solve));
371+     accumulated_flops = Results::GetInstance ()->GetTotalModelingFlops () + (flops / 1e9  / (time_facto + time_solve));
349372    Results::GetInstance ()->SetTotalModelingFlops (accumulated_flops);
350373
351374    Results::GetInstance ()->SetMLEIterations (iter_count + 1 );
@@ -359,7 +382,6 @@ T HicmaImplementation<T>::ExaGeoStatMLETile(std::unique_ptr<ExaGeoStatData<T>> &
359382    return  loglik;
360383}
361384
362- 
363385template <typename  T>
364386void  HicmaImplementation<T>::ExaGeoStatLapackCopyTile(const  UpperLower &aUpperLower, void  *apA, void  *apB) {
365387    int  status = HICMA_dlacpy_Tile (aUpperLower, (HICMA_desc_t *) apA, (HICMA_desc_t *) apB);
@@ -389,7 +411,6 @@ void HicmaImplementation<T>::ExaGeoStatPotrfTile(const common::UpperLower &aUppe
389411    if  (status != HICMA_SUCCESS) {
390412        throw  std::runtime_error (" HICMA_dpotrf_Tile Failed, Matrix is not positive definite" 
391413    }
392- 
393414}
394415
395416template <typename  T>
0 commit comments