Skip to content

Commit

Permalink
Restructured to preserve compatibility with older MILC
Browse files Browse the repository at this point in the history
  • Loading branch information
leonhostetler committed Jan 25, 2025
1 parent 4f57280 commit 6d93109
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 23 deletions.
110 changes: 105 additions & 5 deletions include/quda_milc_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,40 @@ extern "C" {
HISQ fermions since the tadpole factor is
baked into the links during their construction */
double naik_epsilon; /** Naik epsilon parameter (HISQ fermions only).*/

QudaEigParam eig_param; /** To pass deflation-related stuff like eigenvector filename **/
double tol_restart;
QudaPrecision prec_eigensolver;
} QudaInvertArgs_t;

/**
 * Parameters related to deflated linear solvers.
 *
 * qudaInvertDeflatable() and qudaInvertMsrcDeflatable() copy these members
 * one by one into a QudaEigParam; tol_restart and prec_eigensolver are
 * instead forwarded to the QudaInvertParam used for the solve.
 * Deflation is only enabled by those functions when n_ev_deflate > 0 and
 * the solve is for the even parity.
 */
typedef struct {
double tol_restart; /** Forwarded to QudaInvertParam.tol_restart **/
QudaPrecision prec_eigensolver; /** Forwarded to QudaInvertParam.cuda_prec_eigensolver **/

int poly_deg; /** Degree of the Chebyshev polynomial **/
double a_min; /** Range used in polynomial acceleration **/
double a_max; /** Other end of the range used in polynomial acceleration (paired with a_min) **/
QudaBoolean preserve_evals; /** Whether to preserve the evals or recompute them **/
int n_ev; /** Size of the eigenvector search space **/
int n_kr; /** Total size of Krylov space **/
int n_conv; /** Number of requested converged eigenvectors **/
int n_ev_deflate; /** Number of requested converged eigenvectors to use in deflation **/
double tol; /** Tolerance on the least well known eigenvalue's residual **/
int max_restarts; /** For IRLM/IRAM, quit after n restarts **/
int batched_rotate; /** For the Ritz rotation, the maximal number of extra vectors the solver may allocate **/
int block_size; /** For block method solvers, the block size **/
char vec_infile[256]; /** Filename prefix where to load the null-space vectors; must be NUL-terminated because it is copied with strcpy() */
char vec_outfile[256]; /** Filename prefix for where to save the null-space vectors; must be NUL-terminated because it is copied with strcpy() */
QudaPrecision save_prec; /** The precision with which to save the vectors */
QudaBoolean partfile; /** Whether to save eigenvectors in QIO singlefile or partfile format */
QudaBoolean io_parity_inflate; /** Whether to inflate single-parity eigen-vector I/O **/
QudaBoolean use_norm_op; /** Copied to QudaEigParam.use_norm_op (see quda.h for its meaning) **/
QudaBoolean use_pc; /** Copied to QudaEigParam.use_pc (see quda.h for its meaning) **/

} QudaEigensolverArgs_t;


/**
* Parameters related to EigCG deflated solvers.
*/

typedef struct {
Expand Down Expand Up @@ -372,6 +398,42 @@ extern "C" {
double* const final_rel_resid,
int* num_iters);

/**
 * Solve Ax=b with deflation for an improved staggered operator. All
 * fields passed and returned are host (CPU) fields in MILC order. This
 * function requires that persistent gauge and clover fields have
 * been created prior. This interface is experimental.
 *
 * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single)
 * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single)
 * @param[in] mass Fermion mass parameter
 * @param[in] inv_args Struct setting some solver metadata
 * @param[in] eig_args Struct setting some eigensolver metadata
 * @param[in] target_residual Target residual
 * @param[in] target_fermilab_residual Target Fermilab residual
 * @param[in] milc_fatlink Fat-link field on the host
 * @param[in] milc_longlink Long-link field on the host
 * @param[in] source Right-hand side source field
 * @param[out] solution Solution spinor field
 * @param[out] final_resid True residual
 * @param[out] final_rel_resid True Fermilab residual
 * @param[out] num_iters Number of iterations taken
 */
void qudaInvertDeflatable(int external_precision,
int quda_precision,
double mass,
QudaInvertArgs_t inv_args,
QudaEigensolverArgs_t eig_args,
double target_residual,
double target_fermilab_residual,
const void* const milc_fatlink,
const void* const milc_longlink,
void* source,
void* solution,
double* const final_resid,
double* const final_rel_resid,
int* num_iters);

/**
* Prepare a staggered/HISQ multigrid solve with given fat and
* long links. All fields passed are host (CPU) fields
Expand Down Expand Up @@ -464,6 +526,44 @@ extern "C" {
int* num_iters,
int num_src);

/**
 * Solve Ax=b with deflation for an improved staggered operator with many right hand sides.
 * All fields passed and returned are host (CPU) fields in MILC order.
 * This function requires that persistent gauge and clover fields have
 * been created prior. This interface is experimental.
 *
 * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single)
 * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single)
 * @param[in] mass Fermion mass parameter
 * @param[in] inv_args Struct setting some solver metadata
 * @param[in] eig_args Struct setting some eigensolver metadata
 * @param[in] target_residual Target residual
 * @param[in] target_fermilab_residual Target Fermilab residual
 * @param[in] fatlink Fat-link field on the host
 * @param[in] longlink Long-link field on the host
 * @param[in] sourceArray array of right-hand side source fields
 * @param[out] solutionArray array of solution spinor fields
 * @param[out] final_residual True residual
 * @param[out] final_fermilab_residual True Fermilab residual
 * @param[out] num_iters Number of iterations taken
 * @param[in] num_src Number of source fields
 */
void qudaInvertMsrcDeflatable(int external_precision,
int quda_precision,
double mass,
QudaInvertArgs_t inv_args,
QudaEigensolverArgs_t eig_args,
double target_residual,
double target_fermilab_residual,
const void* const fatlink,
const void* const longlink,
void** sourceArray,
void** solutionArray,
double* const final_residual,
double* const final_fermilab_residual,
int* num_iters,
int num_src);

/**
* Solve for multiple shifts (e.g., masses) using an improved
* staggered operator. All fields are fields passed and returned
Expand Down
149 changes: 131 additions & 18 deletions lib/milc_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1153,14 +1153,72 @@ void qudaMultishiftInvert(int external_precision, int quda_precision, int num_of
qudamilc_called<false>(__func__, verbosity);
} // qudaMultiShiftInvert

/**
 * Wrapper around qudaInvertDeflatable() that keeps the original qudaInvert()
 * signature so old(er) MILC versions continue to link and run.
 *
 * A caller of this entry point is not asking QUDA to deflate, so the
 * eigensolver arguments forwarded here request zero deflation vectors.
 * All remaining arguments are passed through unchanged.
 */
void qudaInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args,
                double target_residual, double target_fermilab_residual, const void *const fatlink,
                const void *const longlink, void *source, void *solution, double *const final_residual,
                double *const final_fermilab_residual, int *num_iters)
{
  // Value-initialize the whole struct. qudaInvertDeflatable() reads every
  // member of eig_args even when deflation is off: it strcpy()s vec_infile and
  // vec_outfile (which therefore must be NUL-terminated) and forwards
  // tol_restart into the invert parameters. Leaving those members
  // indeterminate is undefined behavior and can overrun the destination
  // buffers.
  QudaEigensolverArgs_t eig_args = {};

  // Request 0 eigenvectors for deflation, which disables deflation downstream.
  eig_args.n_ev_deflate = 0;
  // Not used in the undeflated case, but zero is not a meaningful precision value,
  // so set it explicitly.
  eig_args.prec_eigensolver = QUDA_DOUBLE_PRECISION;

  qudaInvertDeflatable(external_precision, quda_precision, mass, inv_args, eig_args, target_residual,
                       target_fermilab_residual, fatlink, longlink, source, solution, final_residual,
                       final_fermilab_residual, num_iters);

} // qudaInvert


void qudaInvertDeflatable(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, QudaEigensolverArgs_t eig_args,
double target_residual, double target_fermilab_residual, const void *const fatlink,
const void *const longlink, void *source, void *solution, double *const final_residual,
double *const final_fermilab_residual, int *num_iters)
{
static const QudaVerbosity verbosity = getVerbosity();
qudamilc_called<true>(__func__, verbosity);

QudaParity local_parity = inv_args.evenodd;

// parameters for the eigensolve/deflation
QudaEigParam qep = newQudaEigParam();
qep.block_size = eig_args.block_size;
qep.n_conv = eig_args.n_conv;
qep.n_ev_deflate = eig_args.n_ev_deflate;
qep.n_ev = eig_args.n_ev;
qep.n_kr = eig_args.n_kr;
qep.tol = eig_args.tol;
qep.max_restarts = eig_args.max_restarts;
qep.poly_deg = eig_args.poly_deg;
qep.a_min = eig_args.a_min;
qep.a_max = eig_args.a_max;
strcpy( qep.vec_infile, eig_args.vec_infile );
strcpy( qep.vec_outfile, eig_args.vec_outfile );
qep.preserve_evals = eig_args.preserve_evals;
qep.batched_rotate = eig_args.batched_rotate;
qep.save_prec = eig_args.save_prec;
qep.partfile = eig_args.partfile;
qep.io_parity_inflate = eig_args.io_parity_inflate;
qep.use_norm_op = eig_args.use_norm_op;
qep.use_pc = eig_args.use_pc;
qep.eig_type = ( qep.block_size > 1 ) ? QUDA_EIG_BLK_TR_LANCZOS : QUDA_EIG_TR_LANCZOS; /* or QUDA_EIG_IR_ARNOLDI, QUDA_EIG_BLK_IR_ARNOLDI */
qep.spectrum = QUDA_SPECTRUM_SR_EIG; /* Smallest Real. Other options: LM, SM, LR, SR, LI, SI */
qep.qr_tol = qep.tol;
qep.require_convergence = QUDA_BOOLEAN_TRUE;
qep.check_interval = 10;
qep.use_dagger = QUDA_BOOLEAN_FALSE;
qep.compute_gamma5 = QUDA_BOOLEAN_FALSE;
qep.compute_svd = QUDA_BOOLEAN_FALSE;
qep.use_eigen_qr = QUDA_BOOLEAN_TRUE;
qep.use_poly_acc = QUDA_BOOLEAN_TRUE;
qep.arpack_check = QUDA_BOOLEAN_FALSE;
qep.compute_evals_batch_size = 16;
qep.preserve_deflation = QUDA_BOOLEAN_TRUE;

if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaInvert: requesting zero residual\n");

QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION;
Expand Down Expand Up @@ -1198,26 +1256,25 @@ void qudaInvert(int external_precision, int quda_precision, double mass, QudaInv

QudaInvertParam invertParam = newQudaInvertParam();

QudaParity local_parity = inv_args.evenodd;
const double reliable_delta = 1e-1;

setInvertParams(host_precision, device_precision, device_precision_sloppy, mass, target_residual,
target_fermilab_residual, inv_args.max_iter, reliable_delta, local_parity, verbosity,
QUDA_CG_INVERTER, &invertParam);

// Deflation for even parity solves when desired
invertParam.eig_param = (local_parity == QUDA_EVEN_PARITY)&&(inv_args.eig_param.n_ev_deflate>0) ? &inv_args.eig_param : nullptr;
invertParam.tol_restart = inv_args.tol_restart;
invertParam.eig_param = (local_parity == QUDA_EVEN_PARITY)&&(qep.n_ev_deflate>0) ? &qep : nullptr;
invertParam.tol_restart = eig_args.tol_restart;

// Eigensolver precision
invertParam.cuda_prec_eigensolver = inv_args.prec_eigensolver;
invertParam.cuda_prec_eigensolver = eig_args.prec_eigensolver;

// Preserve deflation space
static bool deflation_init = false;
if (invertParam.eig_param && inv_args.eig_param.preserve_deflation) {
if (invertParam.eig_param && qep.preserve_deflation) {
if (deflation_init) {
if (!preserved_deflation_space) errorQuda("Unexpected nullptr for preserved deflation space");
inv_args.eig_param.preserve_deflation_space = preserved_deflation_space;
qep.preserve_deflation_space = preserved_deflation_space;
}
}

Expand All @@ -1239,8 +1296,8 @@ void qudaInvert(int external_precision, int quda_precision, double mass, QudaInv

invertQuda(static_cast<char *>(solution) + quark_offset, static_cast<char *>(source) + quark_offset, &invertParam);

if (invertParam.eig_param && inv_args.eig_param.preserve_deflation) {
preserved_deflation_space = inv_args.eig_param.preserve_deflation_space;
if (invertParam.eig_param && qep.preserve_deflation) {
preserved_deflation_space = qep.preserve_deflation_space;
deflation_init = true; // signal that we have deflation space preserved
}

Expand All @@ -1252,7 +1309,7 @@ void qudaInvert(int external_precision, int quda_precision, double mass, QudaInv
if (!create_quda_gauge) invalidateGaugeQuda();

qudamilc_called<false>(__func__, verbosity);
} // qudaInvert
} // qudaInvertDeflatable


void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv_args, const void *const fatlink,
Expand Down Expand Up @@ -1399,13 +1456,70 @@ void qudaSpinTaste(int external_precision, int quda_precision, const void *const
qudamilc_called<false>(__func__, verbosity);
} // qudaSpinTaste

/**
 * Wrapper around qudaInvertMsrcDeflatable() that keeps the original
 * qudaInvertMsrc() signature so old(er) MILC versions continue to link and run.
 *
 * A caller of this entry point is not asking QUDA to deflate, so the
 * eigensolver arguments forwarded here request zero deflation vectors.
 * All remaining arguments are passed through unchanged.
 */
void qudaInvertMsrc(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args,
                    double target_residual, double target_fermilab_residual, const void *const fatlink,
                    const void *const longlink, void **sourceArray, void **solutionArray,
                    double *const final_residual, double *const final_fermilab_residual, int *num_iters,
                    int num_src)
{
  // Value-initialize the whole struct. qudaInvertMsrcDeflatable() reads every
  // member of eig_args even when deflation is off: it strcpy()s vec_infile and
  // vec_outfile (which therefore must be NUL-terminated) and forwards
  // tol_restart into the invert parameters. Leaving those members
  // indeterminate is undefined behavior and can overrun the destination
  // buffers.
  QudaEigensolverArgs_t eig_args = {};

  // Request 0 eigenvectors for deflation, which disables deflation downstream.
  eig_args.n_ev_deflate = 0;
  // Not used in the undeflated case, but zero is not a meaningful precision value,
  // so set it explicitly.
  eig_args.prec_eigensolver = QUDA_DOUBLE_PRECISION;

  qudaInvertMsrcDeflatable(external_precision, quda_precision, mass, inv_args, eig_args, target_residual,
                           target_fermilab_residual, fatlink, longlink, sourceArray, solutionArray,
                           final_residual, final_fermilab_residual, num_iters, num_src);

} // qudaInvertMsrc

void qudaInvertMsrcDeflatable(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, QudaEigensolverArgs_t eig_args,
double target_residual, double target_fermilab_residual, const void *const fatlink,
const void *const longlink, void **sourceArray, void **solutionArray, double *const final_residual,
double *const final_fermilab_residual, int *num_iters, int num_src)
{
static const QudaVerbosity verbosity = getVerbosity();
qudamilc_called<true>(__func__, verbosity);

QudaParity local_parity = inv_args.evenodd;

// parameters for the eigensolve/deflation
QudaEigParam qep = newQudaEigParam();
qep.block_size = eig_args.block_size;
qep.n_conv = eig_args.n_conv;
qep.n_ev_deflate = eig_args.n_ev_deflate;
qep.n_ev = eig_args.n_ev;
qep.n_kr = eig_args.n_kr;
qep.tol = eig_args.tol;
qep.max_restarts = eig_args.max_restarts;
qep.poly_deg = eig_args.poly_deg;
qep.a_min = eig_args.a_min;
qep.a_max = eig_args.a_max;
strcpy( qep.vec_infile, eig_args.vec_infile );
strcpy( qep.vec_outfile, eig_args.vec_outfile );
qep.preserve_evals = eig_args.preserve_evals;
qep.batched_rotate = eig_args.batched_rotate;
qep.save_prec = eig_args.save_prec;
qep.partfile = eig_args.partfile;
qep.io_parity_inflate = eig_args.io_parity_inflate;
qep.use_norm_op = eig_args.use_norm_op;
qep.use_pc = eig_args.use_pc;
qep.eig_type = ( qep.block_size > 1 ) ? QUDA_EIG_BLK_TR_LANCZOS : QUDA_EIG_TR_LANCZOS; /* or QUDA_EIG_IR_ARNOLDI, QUDA_EIG_BLK_IR_ARNOLDI */
qep.spectrum = QUDA_SPECTRUM_SR_EIG; /* Smallest Real. Other options: LM, SM, LR, SR, LI, SI */
qep.qr_tol = qep.tol;
qep.require_convergence = QUDA_BOOLEAN_TRUE;
qep.check_interval = 10;
qep.use_dagger = QUDA_BOOLEAN_FALSE;
qep.compute_gamma5 = QUDA_BOOLEAN_FALSE;
qep.compute_svd = QUDA_BOOLEAN_FALSE;
qep.use_eigen_qr = QUDA_BOOLEAN_TRUE;
qep.use_poly_acc = QUDA_BOOLEAN_TRUE;
qep.arpack_check = QUDA_BOOLEAN_FALSE;
qep.compute_evals_batch_size = 16;
qep.preserve_deflation = QUDA_BOOLEAN_TRUE;

if (target_fermilab_residual == 0 && target_residual == 0) errorQuda("qudaInvert: requesting zero residual\n");

Expand All @@ -1427,7 +1541,6 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud

QudaInvertParam invertParam = newQudaInvertParam();

QudaParity local_parity = inv_args.evenodd;
const double reliable_delta = 1e-1;

setInvertParams(host_precision, device_precision, device_precision_sloppy, mass, target_residual,
Expand All @@ -1436,18 +1549,18 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud
invertParam.num_src = num_src;

// Deflation for even parity solves when desired
invertParam.eig_param = (local_parity == QUDA_EVEN_PARITY)&&(inv_args.eig_param.n_ev_deflate>0) ? &inv_args.eig_param : nullptr;
invertParam.tol_restart = inv_args.tol_restart;
invertParam.eig_param = (local_parity == QUDA_EVEN_PARITY)&&(qep.n_ev_deflate>0) ? &qep : nullptr;
invertParam.tol_restart = eig_args.tol_restart;

// Eigensolver precision
invertParam.cuda_prec_eigensolver = inv_args.prec_eigensolver;
invertParam.cuda_prec_eigensolver = eig_args.prec_eigensolver;

// Preserve deflation space
static bool deflation_init = false;
if (invertParam.eig_param && inv_args.eig_param.preserve_deflation) {
if (invertParam.eig_param && qep.preserve_deflation) {
if (deflation_init) {
if (!preserved_deflation_space) errorQuda("Unexpected nullptr for preserved deflation space");
inv_args.eig_param.preserve_deflation_space = preserved_deflation_space;
qep.preserve_deflation_space = preserved_deflation_space;
}
}

Expand Down Expand Up @@ -1477,8 +1590,8 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud
host_free(sln_pointer);
host_free(src_pointer);

if (invertParam.eig_param && inv_args.eig_param.preserve_deflation) {
preserved_deflation_space = inv_args.eig_param.preserve_deflation_space;
if (invertParam.eig_param && qep.preserve_deflation) {
preserved_deflation_space = qep.preserve_deflation_space;
deflation_init = true; // signal that we have deflation space preserved
}

Expand All @@ -1499,7 +1612,7 @@ void qudaInvertMsrc(int external_precision, int quda_precision, double mass, Qud
if (!create_quda_gauge) invalidateGaugeQuda();

qudamilc_called<false>(__func__, verbosity);
} // qudaInvert
} // qudaInvertMsrcDeflatable

void qudaEigCGInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args,
double target_residual, double target_fermilab_residual, const void *const fatlink,
Expand Down

0 comments on commit 6d93109

Please sign in to comment.