From cdedc075d8f61d8b017b683685c1b80b2603212e Mon Sep 17 00:00:00 2001
From: leonhostetler
Date: Tue, 28 Jan 2025 15:56:22 -0600
Subject: [PATCH] Moved more eigensolver params into the small struct and added
 a setEigensolverParams function

---
 include/quda_milc_interface.h |  14 +++++
 lib/milc_interface.cpp        | 121 ++++++++++++++++++++++-------------------------
 2 files changed, 71 insertions(+), 64 deletions(-)

diff --git a/include/quda_milc_interface.h b/include/quda_milc_interface.h
index dfe06ae351..dd07cd3244 100644
--- a/include/quda_milc_interface.h
+++ b/include/quda_milc_interface.h
@@ -77,6 +77,20 @@ extern "C" {
     QudaBoolean io_parity_inflate; /** Whether to inflate single-parity eigen-vector I/O **/
     QudaBoolean use_norm_op;
     QudaBoolean use_pc;
+    QudaEigType eig_type; /** Type of eigensolver algorithm to employ **/
+    QudaEigSpectrumType spectrum; /** Which part of the spectrum to solve **/
+    double qr_tol; /** Tolerance on the QR iteration **/
+    QudaBoolean require_convergence; /** If true, the solver will error out if the convergence criteria are not met **/
+    int check_interval; /** For IRLM/IRAM, check every nth restart **/
+    QudaBoolean use_dagger; /** If use_dagger, use Mdag **/
+    QudaBoolean compute_gamma5; /** Performs the \gamma_5 OP solve by post multiplying the eigenvectors with \gamma_5 before computing the eigenvalues **/
+    QudaBoolean compute_svd; /** Performs an MdagM solve, then constructs the left and right SVD. **/
+    QudaBoolean use_eigen_qr; /** Use Eigen routines to eigensolve the upper Hessenberg via QR **/
+    QudaBoolean use_poly_acc; /** Use Polynomial Acceleration **/
+    QudaBoolean arpack_check; /** In the test function, cross check the device result against ARPACK **/
+    char arpack_logfile[512]; /** For Arpack cross check, name of the Arpack logfile **/
+    int compute_evals_batch_size; /** The batch size used when computing eigenvalues **/
+    QudaBoolean preserve_deflation; /** Whether to preserve the deflation space between solves **/
   } QudaEigensolverArgs_t;
 
 
diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp
index 7e834c7c82..0b1ee735d8 100644
--- a/lib/milc_interface.cpp
+++ b/lib/milc_interface.cpp
@@ -999,6 +999,61 @@ static void setGaugeParams(QudaGaugeParam &fat_param, QudaGaugeParam &long_param
 
 }
 
+/**
+ * Copy a NUL-terminated string into a fixed-size character array without
+ * overrunning it. The copy is truncated if necessary and the destination is
+ * always NUL-terminated.
+ */
+template <size_t N> static void copyEigString(char (&dst)[N], const char *src)
+{
+  strncpy(dst, src, N - 1);
+  dst[N - 1] = '\0';
+}
+
+/**
+ * Transfer every eigensolver setting supplied through the MILC interface
+ * struct into the corresponding field of a QudaEigParam.
+ *
+ * @param[in] eig_args Eigensolver arguments supplied by the caller
+ * @param[out] qep Eigensolver parameter struct to populate; must not be null
+ */
+static void setEigensolverParams(const QudaEigensolverArgs_t &eig_args, QudaEigParam *qep)
+{
+  qep->block_size = eig_args.block_size;
+  qep->n_conv = eig_args.n_conv;
+  qep->n_ev_deflate = eig_args.n_ev_deflate;
+  qep->n_ev = eig_args.n_ev;
+  qep->n_kr = eig_args.n_kr;
+  qep->tol = eig_args.tol;
+  qep->max_restarts = eig_args.max_restarts;
+  qep->poly_deg = eig_args.poly_deg;
+  qep->a_min = eig_args.a_min;
+  qep->a_max = eig_args.a_max;
+  copyEigString(qep->vec_infile, eig_args.vec_infile);
+  copyEigString(qep->vec_outfile, eig_args.vec_outfile);
+  qep->preserve_evals = eig_args.preserve_evals;
+  qep->batched_rotate = eig_args.batched_rotate;
+  qep->save_prec = eig_args.save_prec;
+  qep->partfile = eig_args.partfile;
+  qep->io_parity_inflate = eig_args.io_parity_inflate;
+  qep->use_norm_op = eig_args.use_norm_op;
+  qep->use_pc = eig_args.use_pc;
+  qep->eig_type = eig_args.eig_type;
+  qep->spectrum = eig_args.spectrum;
+  qep->qr_tol = eig_args.qr_tol;
+  qep->require_convergence = eig_args.require_convergence;
+  qep->check_interval = eig_args.check_interval;
+  qep->use_dagger = eig_args.use_dagger;
+  qep->compute_gamma5 = eig_args.compute_gamma5;
+  qep->compute_svd = eig_args.compute_svd;
+  qep->use_eigen_qr = eig_args.use_eigen_qr;
+  qep->use_poly_acc = eig_args.use_poly_acc;
+  qep->arpack_check = eig_args.arpack_check;
+  copyEigString(qep->arpack_logfile, eig_args.arpack_logfile); // forward the log file name used by the ARPACK cross check
+  qep->compute_evals_batch_size = eig_args.compute_evals_batch_size;
+  qep->preserve_deflation = eig_args.preserve_deflation;
+}
+
 static void setColorSpinorParams(const int dim[4], QudaPrecision precision, ColorSpinorParam *param)
 {
   param->nColor = 3;
@@ -1186,38 +1241,7 @@ void qudaInvertDeflatable(int external_precision, int quda_precision, double mas
 
   // parameters for the eigensolve/deflation
   QudaEigParam qep = newQudaEigParam();
-  qep.block_size = eig_args.block_size;
-  qep.n_conv = eig_args.n_conv;
-  qep.n_ev_deflate = eig_args.n_ev_deflate;
-  qep.n_ev = eig_args.n_ev;
-  qep.n_kr = eig_args.n_kr;
-  qep.tol = eig_args.tol;
-  qep.max_restarts = eig_args.max_restarts;
-  qep.poly_deg = eig_args.poly_deg;
-  qep.a_min = eig_args.a_min;
-  qep.a_max = eig_args.a_max;
-  strcpy( qep.vec_infile, eig_args.vec_infile );
-  strcpy( qep.vec_outfile, eig_args.vec_outfile );
-  qep.preserve_evals = eig_args.preserve_evals;
-  qep.batched_rotate = eig_args.batched_rotate;
-  qep.save_prec = eig_args.save_prec;
-  qep.partfile = eig_args.partfile;
-  qep.io_parity_inflate = eig_args.io_parity_inflate;
-  qep.use_norm_op = eig_args.use_norm_op;
-  qep.use_pc = eig_args.use_pc;
-  qep.eig_type = ( qep.block_size > 1 ) ? QUDA_EIG_BLK_TR_LANCZOS : QUDA_EIG_TR_LANCZOS; /* or QUDA_EIG_IR_ARNOLDI, QUDA_EIG_BLK_IR_ARNOLDI */
-  qep.spectrum = QUDA_SPECTRUM_SR_EIG; /* Smallest Real. Other options: LM, SM, LR, SR, LI, SI */
-  qep.qr_tol = qep.tol;
-  qep.require_convergence = QUDA_BOOLEAN_TRUE;
-  qep.check_interval = 10;
-  qep.use_dagger = QUDA_BOOLEAN_FALSE;
-  qep.compute_gamma5 = QUDA_BOOLEAN_FALSE;
-  qep.compute_svd = QUDA_BOOLEAN_FALSE;
-  qep.use_eigen_qr = QUDA_BOOLEAN_TRUE;
-  qep.use_poly_acc = QUDA_BOOLEAN_TRUE;
-  qep.arpack_check = QUDA_BOOLEAN_FALSE;
-  qep.compute_evals_batch_size = 16;
-  qep.preserve_deflation = QUDA_BOOLEAN_TRUE;
+  setEigensolverParams(eig_args, &qep);
 
   if (target_fermilab_residual == 0 && target_residual == 0)
     errorQuda("qudaInvert: requesting zero residual\n");
@@ -1488,38 +1512,7 @@ void qudaInvertMsrcDeflatable(int external_precision, int quda_precision, double
 
   // parameters for the eigensolve/deflation
   QudaEigParam qep = newQudaEigParam();
-  qep.block_size = eig_args.block_size;
-  qep.n_conv = eig_args.n_conv;
-  qep.n_ev_deflate = eig_args.n_ev_deflate;
-  qep.n_ev = eig_args.n_ev;
-  qep.n_kr = eig_args.n_kr;
-  qep.tol = eig_args.tol;
-  qep.max_restarts = eig_args.max_restarts;
-  qep.poly_deg = eig_args.poly_deg;
-  qep.a_min = eig_args.a_min;
-  qep.a_max = eig_args.a_max;
-  strcpy( qep.vec_infile, eig_args.vec_infile );
-  strcpy( qep.vec_outfile, eig_args.vec_outfile );
-  qep.preserve_evals = eig_args.preserve_evals;
-  qep.batched_rotate = eig_args.batched_rotate;
-  qep.save_prec = eig_args.save_prec;
-  qep.partfile = eig_args.partfile;
-  qep.io_parity_inflate = eig_args.io_parity_inflate;
-  qep.use_norm_op = eig_args.use_norm_op;
-  qep.use_pc = eig_args.use_pc;
-  qep.eig_type = ( qep.block_size > 1 ) ? QUDA_EIG_BLK_TR_LANCZOS : QUDA_EIG_TR_LANCZOS; /* or QUDA_EIG_IR_ARNOLDI, QUDA_EIG_BLK_IR_ARNOLDI */
-  qep.spectrum = QUDA_SPECTRUM_SR_EIG; /* Smallest Real. Other options: LM, SM, LR, SR, LI, SI */
-  qep.qr_tol = qep.tol;
-  qep.require_convergence = QUDA_BOOLEAN_TRUE;
-  qep.check_interval = 10;
-  qep.use_dagger = QUDA_BOOLEAN_FALSE;
-  qep.compute_gamma5 = QUDA_BOOLEAN_FALSE;
-  qep.compute_svd = QUDA_BOOLEAN_FALSE;
-  qep.use_eigen_qr = QUDA_BOOLEAN_TRUE;
-  qep.use_poly_acc = QUDA_BOOLEAN_TRUE;
-  qep.arpack_check = QUDA_BOOLEAN_FALSE;
-  qep.compute_evals_batch_size = 16;
-  qep.preserve_deflation = QUDA_BOOLEAN_TRUE;
+  setEigensolverParams(eig_args, &qep);
 
   if (target_fermilab_residual == 0 && target_residual == 0)
     errorQuda("qudaInvert: requesting zero residual\n");