diff --git a/pydiso/mkl_solver.pyx b/pydiso/mkl_solver.pyx index 82a9179..4328010 100644 --- a/pydiso/mkl_solver.pyx +++ b/pydiso/mkl_solver.pyx @@ -4,9 +4,11 @@ cimport numpy as np from cython cimport numeric import warnings +from time import time import numpy as np import scipy.sparse as sp import os +from pathlib import Path ctypedef long long MKL_INT64 ctypedef unsigned long long MKL_UINT64 @@ -30,6 +32,7 @@ cdef extern from 'mkl.h': void mkl_get_version(MKLVersion* pv) void mkl_set_num_threads(int nth) + int mkl_domain_set_num_threads(int nt, int domain) int mkl_get_max_threads() int mkl_domain_get_max_threads(int domain) @@ -39,6 +42,10 @@ cdef extern from 'mkl.h': ctypedef void * _MKL_DSS_HANDLE_t + void pardiso_handle_store(_MKL_DSS_HANDLE_t pt, char *dirname, int *err) + + void pardiso_handle_restore(_MKL_DSS_HANDLE_t pt, char *dirname, int *err) + void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*, const int *, const int *, const void *, const int *, const int *, int *, const int *, int *, @@ -184,13 +191,16 @@ cdef class MKLPardisoSolver: cdef int_t _factored cdef size_t shape[2] cdef int_t _initialized + cdef char* call_flag_dir + cdef char* _flag_dir + cdef int_t _store cdef void * a cdef object _data_type cdef object _Adata #a reference to make sure the pointer "a" doesn't get destroyed - def __init__(self, A, matrix_type=None, factor=True, verbose=False): + def __init__(self, A, matrix_type=None, factor=True, verbose=False, store_factorization_dir=None): '''ParidsoSolver(A, matrix_type=None, factor=True, verbose=False) An interface to the intel MKL pardiso sparse matrix solver. @@ -305,6 +315,27 @@ cdef class MKLPardisoSolver: self._set_A(A.data) self._analyze() self._factored = False + + # check if we want to store the factorization + if store_factorization_dir is not None: + + # check if the flag files exist. If so delete them so factorization file get overwritten + check_file = Path(store_factorization_dir) / 'factorization_done.txt' + + if os.path.exists(check_file): + + second_file_to_remove = Path(store_factorization_dir) / "flagfile.txt" + os.remove(check_file) + os.remove(second_file_to_remove) + + self._store = True + flag_dir_ = bytes(store_factorization_dir, 'utf-8') + self._flag_dir = flag_dir_ + + else: + + self._store = False + if factor: self._factor() @@ -422,6 +453,11 @@ cdef class MKLPardisoSolver: else: self._par.iparm[i] = val + def store_factorization(self, directory=b'./'): + + self._store = True + self._flag_dir = directory + @property def nnz(self): return self.iparm[17] @@ -515,11 +551,47 @@ cdef class MKLPardisoSolver: cdef _factor(self): #phase = 22 self._factored = False + + if self._store: + try: - err = self._run_pardiso(22) - if err!=0: - raise PardisoError("Factor step error, "+_err_messages[err]) - self._factored = True + flag_file = self._flag_dir.decode("utf-8") + 'flagfile.txt' + + self.call_flag_dir = self._flag_dir + + with open(flag_file, 'x') as f: + f.write('inversion in progress') + + err = self._run_pardiso(22) + + self._pardiso_store(self.call_flag_dir) + + done_file = self._flag_dir.decode("utf-8") + 'factorization_done.txt' + + with open(done_file, 'w') as f2: + f2.write('done') + + self._factored = True + return + + except FileExistsError: + + # flag file exists, wait for "done" file and read in factorization + done_file = self._flag_dir.decode("utf-8") + 'factorization_done.txt' + + while not os.path.isfile(done_file): + time.sleep(1) + + # now read in the factorization from the file + self.call_flag_dir = self._flag_dir + self._pardiso_restore(self.call_flag_dir) + + else: + + err = self._run_pardiso(22) + if err!=0: + raise PardisoError("Factor step error, "+_err_messages[err]) + self._factored = True cdef _solve(self, void* b, void* x, int_t nrhs_in): #phase = 33 @@ -544,3 +616,16 @@ cdef class MKLPardisoSolver: &phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0], &self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64) return error64 + + cdef _pardiso_store(self, char *dir_name): + + cdef int_t error=0 + + pardiso_handle_store(self.handle, dir_name, &error) + + cdef _pardiso_restore(self, char *dir_name): + + cdef int_t error=0 + + pardiso_handle_restore(self.handle, dir_name, &error) + diff --git a/tests/test.py b/tests/test.py index c71f05e..ae73545 100644 --- a/tests/test.py +++ b/tests/test.py @@ -110,6 +110,44 @@ def test_multiple_RHS(): assert rel_err < 1E3*eps return rel_err +def test_multiple_RHS_store_factorization(): + A = A_real_dict["real_symmetric_positive_definite"] + x = np.c_[xr, xr] + b = A @ x + + solver = Solver(A, "real_symmetric_positive_definite", store_factorization_dir='./') + x2 = solver.solve(b) + + eps = np.finfo(np.float64).eps + rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x) + assert rel_err < 1E3*eps + return rel_err + +def test_multiple_RHS_store_factorization_clean_flag_files(): + A = A_real_dict["real_symmetric_positive_definite"] + x = np.c_[xr, xr] + b = A @ x + + solver = Solver(A, "real_symmetric_positive_definite", store_factorization_dir='./') + x2 = solver.solve(b) + + eps = np.finfo(np.float64).eps + rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x) + + assert rel_err < 1E3*eps + + # run again to make sure the created flag files are checked and removed and running again works + x3 = solver.solve(b) + + eps3 = np.finfo(np.float64).eps + rel_err3 = np.linalg.norm(x-x2)/np.linalg.norm(x) + + assert rel_err3 < 1E3*eps3 + + assert rel_err == rel_err3 + + return rel_err + def test_matrix_type_errors(): A = A_real_dict["real_symmetric_positive_definite"]