Skip to content

Commit

Permalink
Update QUDA to lattice/quda#1489.
Browse files Browse the repository at this point in the history
  • Loading branch information
SaltyChiang committed Oct 12, 2024
1 parent 55eff52 commit c1c8baf
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 55 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Python wrapper for [QUDA](https://github.com/lattice/quda) written in Cython.

This project aims to benefit from the optimized linear algebra library [CuPy](https://cupy.dev/) in Python based on CUDA. CuPy and QUDA will allow us to perform most lattice QCD research operations with high performance. [PyTorch](https://pytorch.org/) is an alternative option.

This project is based on the latest QUDA `develop` branch. PyQUDA should be compatible with any commit of QUDA after 2024, but leave some features disabled.
This project is based on the latest QUDA `develop` branch. PyQUDA should be compatible with any commit of QUDA after https://github.com/lattice/quda/pull/1489, but leave some features disabled.

## Feature

Expand Down
2 changes: 2 additions & 0 deletions pyquda/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
LatticeGauge,
LatticeMom,
LatticeFermion,
MultiLatticeFermion,
LatticeStaggeredFermion,
MultiLatticeStaggeredFermion,
LatticePropagator,
LatticeStaggeredPropagator,
lexico,
Expand Down
28 changes: 28 additions & 0 deletions pyquda/dirac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
QudaGaugeSmearParam,
QudaGaugeObservableParam,
invertQuda,
invertMultiSrcQuda,
invertMultiShiftQuda,
MatQuda,
MatDagMatQuda,
dslashQuda,
dslashMultiSrcQuda,
newMultigridQuda,
updateMultigridQuda,
destroyMultigridQuda,
Expand Down Expand Up @@ -232,6 +234,19 @@ def dslash(self, x: LatticeFermion, parity: QudaParity):
dslashQuda(b.data_ptr, x.data_ptr, self.invert_param, parity)
return b

def invertMultiSrc(self, b: MultiLatticeFermion):
self.invert_param.num_src = b.L5
x = MultiLatticeFermion(b.latt_info, b.L5)
invertMultiSrcQuda(x.data_ptrs, b.data_ptrs, self.invert_param)
self.performance()
return x

def dslashMultiSrc(self, x: MultiLatticeFermion, parity: QudaParity):
self.invert_param.num_src = x.L5
b = MultiLatticeFermion(x.latt_info, x.L5)
dslashMultiSrcQuda(b.data_ptrs, x.data_ptrs, self.invert_param, parity)
return b

def _invertMultiShiftParam(self, offset: List[float], residue: List[float], norm: float = None):
assert len(offset) == len(residue)
num_offset = len(offset)
Expand Down Expand Up @@ -349,6 +364,19 @@ def dslash(self, x: LatticeStaggeredFermion, parity: QudaParity):
dslashQuda(b.data_ptr, x.data_ptr, self.invert_param, parity)
return b

def invertMultiSrc(self, b: MultiLatticeStaggeredFermion):
self.invert_param.num_src = b.L5
x = MultiLatticeStaggeredFermion(b.latt_info, b.L5)
invertMultiSrcQuda(x.data_ptrs, b.data_ptrs, self.invert_param)
self.performance()
return x

def dslashMultiSrc(self, x: MultiLatticeStaggeredFermion, parity: QudaParity):
self.invert_param.num_src = x.L5
b = MultiLatticeStaggeredFermion(x.latt_info, x.L5)
dslashMultiSrcQuda(b.data_ptrs, x.data_ptrs, self.invert_param, parity)
return b

def invertMultiShiftPC(
self, b: LatticeStaggeredFermion, offset: List[float], residue: List[float], norm: float = None
) -> Union[LatticeStaggeredFermion, MultiLatticeStaggeredFermion]:
Expand Down
1 change: 1 addition & 0 deletions pyquda/dirac/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def newQudaMultigridParam(
mg_param.n_block_ortho = [1] * QUDA_MAX_MG_LEVEL

mg_param.setup_inv_type = [QudaInverterType.QUDA_CGNR_INVERTER] * QUDA_MAX_MG_LEVEL
mg_param.n_vec_batch = [1] * QUDA_MAX_MG_LEVEL
mg_param.num_setup_iter = [1] * QUDA_MAX_MG_LEVEL
mg_param.setup_tol = [setup_tol] * QUDA_MAX_MG_LEVEL
mg_param.setup_maxiter = [setup_maxiter] * QUDA_MAX_MG_LEVEL
Expand Down
14 changes: 7 additions & 7 deletions pyquda/enum_quda.in.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,9 @@
This number may be changed if need be.
"""

QUDA_MAX_BLOCK_SRC = 64
QUDA_MAX_MULTI_SRC = 128
"""
Maximum number of sources that can be supported by the block solver
"""

QUDA_MAX_ARRAY_SIZE = max(QUDA_MAX_MULTI_SHIFT, QUDA_MAX_BLOCK_SRC)
"""
Maximum array length used in QudaInvertParam arrays
Maximum number of sources that can be supported by the multi-src solver
"""

QUDA_MAX_DWF_LS = 32
Expand Down Expand Up @@ -97,6 +92,11 @@ class QudaGaugeFixed(IntEnum):


class QudaDslashType(IntEnum):
"""
Note: make sure QudaDslashType has corresponding entries in
tests/utils/misc.cpp
"""

pass


Expand Down
81 changes: 63 additions & 18 deletions pyquda/pyquda.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ from .enum_quda import ( # noqa: F401
QUDA_MAX_DIM,
QUDA_MAX_GEOMETRY,
QUDA_MAX_MULTI_SHIFT,
QUDA_MAX_BLOCK_SRC,
QUDA_MAX_ARRAY_SIZE,
QUDA_MAX_MULTI_SRC,
QUDA_MAX_DWF_LS,
QUDA_MAX_MG_LEVEL,
qudaError_t,
Expand Down Expand Up @@ -277,10 +276,10 @@ class QudaInvertParam:

compute_true_res: int
"""Whether to compute the true residual post solve"""
true_res: double
"""Actual L2 residual norm achieved in solver"""
true_res_hq: double
"""Actual heavy quark residual norm achieved in solver"""
true_res: List[double, QUDA_MAX_MULTI_SRC]
"""Actual L2 residual norm achieved in the solver"""
true_res_hq: List[double, QUDA_MAX_MULTI_SRC]
"""Actual heavy quark residual norm achieved in the solver"""
maxiter: int
"""Maximum number of iterations in the linear solver"""
reliable_delta: double
Expand Down Expand Up @@ -473,6 +472,14 @@ class QudaInvertParam:
"""The Gflops rate of the solver"""
secs: double
"""The time taken by the solver"""
energy: double
"""The energy consumed by the solver"""
power: double
"""The mean power of the solver"""
temp: double
"""The mean temperature of the device for the duration of the solve"""
clock: double
"""The mean clock frequency of the device for the duration of the solve"""

tune: QudaTune
"""Enable auto-tuning? (default = QUDA_TUNE_YES)"""
Expand Down Expand Up @@ -763,6 +770,8 @@ class QudaEigParam:
"""For the Ritz rotation, the maximal number of extra vectors the solver may allocate"""
block_size: int
"""For block method solvers, the block size"""
compute_evals_batch_size: int
"""The batch size used when computing eigenvalues"""
max_ortho_attempts: int
"""For block method solvers, quit after n attempts at block orthonormalisation"""
ortho_block_size: int
Expand Down Expand Up @@ -815,12 +824,6 @@ class QudaEigParam:
partfile: QudaBoolean
"""Whether to save eigenvectors in QIO singlefile or partfile format"""

gflops: double
"""The Gflops rate of the eigensolver setup"""

secs: double
"""The time taken by the eigensolver setup"""

extlib_type: QudaExtLibType
"""Which external library to use in the deflation operations (Eigen)"""

Expand Down Expand Up @@ -868,6 +871,9 @@ class QudaMultigridParam:
setup_inv_type: List[QudaInverterType, QUDA_MAX_MG_LEVEL]
"""Inverter to use in the setup phase"""

n_vec_batch: List[int, QUDA_MAX_MG_LEVEL]
"""Solver batch size to use in the setup phase"""

num_setup_iter: List[int, QUDA_MAX_MG_LEVEL]
"""Number of setup iterations"""

Expand Down Expand Up @@ -1022,12 +1028,6 @@ class QudaMultigridParam:
preserve_deflation: QudaBoolean
"""Whether to preserve the deflation space during MG update"""

gflops: double
"""The Gflops rate of the multigrid solver setup"""

secs: double
"""The time taken by the multigrid solver setup"""

mu_factor: List[double, QUDA_MAX_MG_LEVEL]
"""Multiplicative factor for the mu parameter"""

Expand Down Expand Up @@ -1364,6 +1364,25 @@ def invertQuda(h_x: Pointer, h_b: Pointer, param: QudaInvertParam) -> None:
Contains all metadata regarding host and device
storage and solver parameters
"""

def invertMultiSrcQuda(_hp_x: Pointers, _hp_b: Pointers, param: QudaInvertParam) -> None:
"""
Perform the solve like @invertQuda but for multiple rhs by spliting the comm grid into
sub-partitions: each sub-partition invert one or more rhs'.
The QudaInvertParam object specifies how the solve should be performed on each sub-partition.
Unlike @invertQuda, the interface also takes the host side gauge as input. The gauge pointer and
gauge_param are used if for inv_param split_grid[0] * split_grid[1] * split_grid[2] * split_grid[3]
is larger than 1, in which case gauge field is not required to be loaded beforehand; otherwise
this interface would just work as @invertQuda, which requires gauge field to be loaded beforehand,
and the gauge field pointer and gauge_param are not used.
@param _hp_x:
Array of solution spinor fields
@param _hp_b:
Array of source spinor fields
@param param:
Contains all metadata regarding host and device storage and solver parameters
"""
...

def invertMultiShiftQuda(_hp_x: Pointers, _hp_b: Pointer, param: QudaInvertParam) -> None:
Expand Down Expand Up @@ -1456,6 +1475,24 @@ def dslashQuda(h_out: Pointer, h_in: Pointer, inv_param: QudaInvertParam, parity
"""
...

def dslashMultiSrcQuda(_hp_x: Pointers, _hp_b: Pointers, param: QudaInvertParam, parity: QudaParity) -> None:
"""
Perform the solve like @dslashQuda but for multiple rhs by spliting the comm grid into
sub-partitions: each sub-partition does one or more rhs'.
The QudaInvertParam object specifies how the solve should be performed on each sub-partition.
Unlike @invertQuda, the interface also takes the host side gauge as
input - gauge field is not required to be loaded beforehand.
@param _hp_x:
Array of solution spinor fields
@param _hp_b:
Array of source spinor fields
@param param:
Contains all metadata regarding host and device storage and solver parameters
@param parity:
Parity to apply dslash on
"""

def cloverQuda(h_out: Pointer, h_in: Pointer, inv_param: QudaInvertParam, parity: QudaParity, inverse: int) -> None:
"""
Apply the clover operator or its inverse.
Expand Down Expand Up @@ -2074,6 +2111,14 @@ class QudaQuarkSmearParam:
"""Time taken for the smearing operations"""
gflops: double
"""Flops count for the smearing operations"""
energy: double
"""The energy consumed by the smearing operations"""
power: double
"""The mean power of the smearing operations"""
temp: double
"""The mean temperature of the device for the duration of the smearing operations"""
clock: double
"""The mean clock frequency of the device for the duration of the smearing operations"""

def performTwoLinkGaussianSmearNStep(h_in: Pointer, smear_param: QudaQuarkSmearParam) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions pyquda/quda/include/enum_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ typedef enum QudaGaugeFixed_s {
// Types used in QudaInvertParam
//

// Note: make sure QudaDslashType has corresponding entries in
// tests/utils/misc.cpp
typedef enum QudaDslashType_s {
QUDA_WILSON_DSLASH,
QUDA_CLOVER_WILSON_DSLASH,
Expand Down
29 changes: 15 additions & 14 deletions pyquda/quda/include/quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ extern "C" {
double tol_hq; /**< Solver tolerance in the heavy quark residual norm */

int compute_true_res; /** Whether to compute the true residual post solve */
double true_res; /**< Actual L2 residual norm achieved in solver */
double true_res_hq; /**< Actual heavy quark residual norm achieved in solver */
double true_res[QUDA_MAX_MULTI_SRC]; /**< Actual L2 residual norm achieved in the solver */
double true_res_hq[QUDA_MAX_MULTI_SRC]; /**< Actual heavy quark residual norm achieved in the solver */
int maxiter; /**< Maximum number of iterations in the linear solver */
double reliable_delta; /**< Reliable update tolerance */
double reliable_delta_refinement; /**< Reliable update tolerance used in post multi-shift solver refinement */
Expand Down Expand Up @@ -278,6 +278,10 @@ extern "C" {
int iter; /**< The number of iterations performed by the solver */
double gflops; /**< The Gflops rate of the solver */
double secs; /**< The time taken by the solver */
double energy; /**< The energy consumed by the solver */
double power; /**< The mean power of the solver */
double temp; /**< The mean temperature of the device for the duration of the solve */
double clock; /**< The mean clock frequency of the device for the duration of the solve */

QudaTune tune; /**< Enable auto-tuning? (default = QUDA_TUNE_YES) */

Expand Down Expand Up @@ -550,6 +554,8 @@ extern "C" {
int batched_rotate;
/** For block method solvers, the block size **/
int block_size;
/** The batch size used when computing eigenvalues **/
int compute_evals_batch_size;
/** For block method solvers, quit after n attempts at block orthonormalisation **/
int max_ortho_attempts;
/** For hybrid modifeld Gram-Schmidt orthonormalisations **/
Expand Down Expand Up @@ -602,12 +608,6 @@ extern "C" {
/** Whether to save eigenvectors in QIO singlefile or partfile format */
QudaBoolean partfile;

/** The Gflops rate of the eigensolver setup */
double gflops;

/**< The time taken by the eigensolver setup */
double secs;

/** Which external library to use in the deflation operations (Eigen) */
QudaExtLibType extlib_type;
//-------------------------------------------------
Expand Down Expand Up @@ -655,6 +655,9 @@ extern "C" {
/** Inverter to use in the setup phase */
QudaInverterType setup_inv_type[QUDA_MAX_MG_LEVEL];

/** Solver batch size to use in the setup phase */
int n_vec_batch[QUDA_MAX_MG_LEVEL];

/** Number of setup iterations */
int num_setup_iter[QUDA_MAX_MG_LEVEL];

Expand Down Expand Up @@ -805,12 +808,6 @@ extern "C" {
/** Whether to preserve the deflation space during MG update */
QudaBoolean preserve_deflation;

/** The Gflops rate of the multigrid solver setup */
double gflops;

/**< The time taken by the multigrid solver setup */
double secs;

/** Multiplicative factor for the mu parameter */
double mu_factor[QUDA_MAX_MG_LEVEL];

Expand Down Expand Up @@ -1819,6 +1816,10 @@ extern "C" {
double secs;
/** Flops count for the smearing operations **/
double gflops;
double energy; /**< The energy consumed by the smearing operations */
double power; /**< The mean power of the smearing operations */
double temp; /**< The mean temperature of the device for the duration of the smearing operations */
double clock; /**< The mean clock frequency of the device for the duration of the smearing operations */

} QudaQuarkSmearParam;

Expand Down
10 changes: 2 additions & 8 deletions pyquda/quda/include/quda_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,9 @@

/**
* @def QUDA_MAX_BLOCK_SRC
* @brief Maximum number of sources that can be supported by the block solver
* @brief Maximum number of sources that can be supported by the multi-src solver
*/
#define QUDA_MAX_BLOCK_SRC 64

/**
* @def QUDA_MAX_ARRAY
* @brief Maximum array length used in QudaInvertParam arrays
*/
#define QUDA_MAX_ARRAY_SIZE (QUDA_MAX_MULTI_SHIFT > QUDA_MAX_BLOCK_SRC ? QUDA_MAX_MULTI_SHIFT : QUDA_MAX_BLOCK_SRC)
#define QUDA_MAX_MULTI_SRC 128

/**
* @def QUDA_MAX_DWF_LS
Expand Down
10 changes: 4 additions & 6 deletions pyquda/src/pyquda.in.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,8 @@ def eigensolveQuda(Pointers h_evecs, ndarray[double_complex, ndim=1] h_evals, Qu
def invertQuda(Pointer h_x, Pointer h_b, QudaInvertParam param):
quda.invertQuda(h_x.ptr, h_b.ptr, &param.param)

# def invertMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, Pointer h_gauge, QudaGaugeParam gauge_param)
# def invertMultiSrcStaggeredQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, Pointer milc_fatlinks, Pointer milc_longlinks, QudaGaugeParam gauge_param)
# def invertMultiSrcCloverQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, Pointer h_gauge, QudaGaugeParam gauge_param, Pointer h_clover, Pointer h_clovinv)
def invertMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param):
quda.invertMultiSrcQuda(_hp_x.ptrs, _hp_b.ptrs, &param.param)

def invertMultiShiftQuda(Pointers _hp_x, Pointer _hp_b, QudaInvertParam param):
quda.invertMultiShiftQuda(_hp_x.ptrs, _hp_b.ptr, &param.param)
Expand All @@ -276,9 +275,8 @@ def dumpMultigridQuda(Pointer mg_instance, QudaMultigridParam param):
def dslashQuda(Pointer h_out, Pointer h_in, QudaInvertParam inv_param, quda.QudaParity parity):
quda.dslashQuda(h_out.ptr, h_in.ptr, &inv_param.param, parity)

# def dslashMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, QudaParity parity, Pointer h_gauge, QudaGaugeParam gauge_param)
# def dslashMultiSrcStaggeredQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, QudaParity parity, Pointers milc_fatlinks, Pointers milc_longlinks, QudaGaugeParam gauge_param)
# def dslashMultiSrcCloverQuda(Pointers_hp_x, Pointers_hp_b, QudaInvertParam param, QudaParity parity, Pointer h_gauge, QudaGaugeParam gauge_param, Pointer h_clover, Pointer h_clovinv)
def dslashMultiSrcQuda(Pointers _hp_x, Pointers _hp_b, QudaInvertParam param, quda.QudaParity parity):
quda.dslashMultiSrcQuda(_hp_x.ptrs, _hp_b.ptrs, &param.param, parity)

def cloverQuda(Pointer h_out, Pointer h_in, QudaInvertParam inv_param, quda.QudaParity parity, int inverse):
quda.cloverQuda(h_out.ptr, h_in.ptr, &inv_param.param, parity, inverse)
Expand Down
2 changes: 1 addition & 1 deletion pyquda/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.77"
__version__ = "0.8.0"

0 comments on commit c1c8baf

Please sign in to comment.