Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/mrhs solvers #1489

Merged
merged 114 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
4951f11
Fix warning in spin taste and minor cleanup
maddyscientist Jun 21, 2024
7c4793b
Some cleanup of CG interface
maddyscientist Jun 21, 2024
9a2190c
Add MRHS interface for all solvers, and mandate source vector is const
maddyscientist Jun 21, 2024
fda4669
Optimize DiracWilson: vectorize the prepare/reconstruct functions
maddyscientist Jun 26, 2024
44fb98a
Small cleanup to block_transpose.in.cu
maddyscientist Jun 26, 2024
fdd40fb
Add new parameter: QudaMultigridParam::n_vec_batch which is the batch…
maddyscientist Jun 26, 2024
fa64adf
Vectorize DiracCoarsePC prepare/reconstruct
maddyscientist Jun 26, 2024
fef58e8
Ensure we don't enable large arg support for pre Volta architecture
maddyscientist Jun 26, 2024
8b8cd99
Create vector variants of create_alias
maddyscientist Jun 27, 2024
ee6fd26
Add some more scalar wrappers: this facilitates us making the vector …
maddyscientist Jun 29, 2024
ac23c73
Supress annoying warning with Eigen
maddyscientist Jun 29, 2024
2b0763b
Add default copy/move constructors/assignment operator for XUpdateBat…
maddyscientist Jun 29, 2024
70a94df
Add some useful overloads to vector class to facilitate writing batch…
maddyscientist Jun 29, 2024
f792a33
Add explicit casting to double in anticipation of making the cast ope…
maddyscientist Jun 29, 2024
57ba15e
First pass at enabling MRHS for CG, MR and SD solvers. To better fin…
maddyscientist Jul 1, 2024
224bdb2
Accelerate MG::verify by using batch blas where applicable
maddyscientist Jul 1, 2024
6256391
Fix bug in MRE solver
maddyscientist Jul 1, 2024
7962dc3
Apply MRHS optimization to MRE solver
maddyscientist Jul 1, 2024
075cfb8
Remove complex.h inclusion
maddyscientist Jul 3, 2024
7cbab27
Vectorize all remaining Dirac prepare/reconstruct functions
maddyscientist Jul 4, 2024
d488607
Fix bug in GammaApply with introduced in #1416
maddyscientist Jul 4, 2024
6d1bafe
Fix issue with CG::hq_solve
maddyscientist Jul 4, 2024
d526544
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-s…
maddyscientist Jul 9, 2024
902d8ab
Fix bug with Clover Hasenbsusch operator (wrong braces)
maddyscientist Jul 9, 2024
02eecaa
Fix bug with DiracCoarsePC::reconstruct when using odd solve
maddyscientist Jul 11, 2024
8b067d4
Fix bug with counting bytes with clover operator
maddyscientist Jul 11, 2024
c682cae
Default inner GCR solver to use L2 residual to enable early exit if p…
maddyscientist Jul 11, 2024
1c5baef
Initial work to prepare for multi-rhs solver exposure: move the body …
maddyscientist Jul 21, 2024
faf4658
Fix flops counters for blas and reduce functions
maddyscientist Jul 21, 2024
b2f9849
Move remainder of invertQuda body into new MRHS solve wrapper that is…
maddyscientist Jul 21, 2024
c51f6e6
Fix true residual computation: QudaInvertParam::true_res and QudaInve…
maddyscientist Jul 25, 2024
c877352
invert_test and staggered_invert_test now respect --nsrc-tile flag fo…
maddyscientist Jul 25, 2024
19d3348
Add some size checks to P and R
maddyscientist Jul 26, 2024
3583ef8
Use batched blas in DiracCoarse
maddyscientist Jul 26, 2024
f64a9ac
Set verbosity in solve()
maddyscientist Jul 26, 2024
fa26c4e
GCR, CA-GCR and PreconditionedSolver are now MRHS aware
maddyscientist Jul 26, 2024
f3a3d8e
Multigrid solver is now MRHS aware
maddyscientist Jul 26, 2024
eefe8c7
Remove some legacy debug code from multigrid
maddyscientist Jul 26, 2024
f58bc3b
Add rescaling to coarse dslash with MMA - the code still needs cleanup.
hummingtree Jul 27, 2024
958bf12
Add tensor core support for 32/64 MG coarsening
maddyscientist Jul 28, 2024
21b4c08
Add striped signifier to packing kernel tune key
maddyscientist Jul 28, 2024
fce329d
Fix multi-RHS deflation
maddyscientist Jul 28, 2024
d9efb9c
Augmentation of state reporting to report the power, energy, temperat…
maddyscientist Jul 29, 2024
a781103
We should probably use MPI_THREAD_FUNNELED given we have threads now....
maddyscientist Jul 29, 2024
8d3b59e
Report energy when running the solver now
maddyscientist Jul 29, 2024
56ddf51
Fix Ampere+ mma kernels
maddyscientist Jul 29, 2024
20d33dd
Fix staggered MG bug
maddyscientist Jul 29, 2024
5f8f398
Clean up the coarse dslash MMA code:
hummingtree Jul 29, 2024
1b220b7
cvector -> vector
maddyscientist Jul 29, 2024
ab3f667
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
hummingtree Jul 29, 2024
7af1ade
Fix MPI bug
maddyscientist Jul 30, 2024
a20bf0e
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
maddyscientist Jul 31, 2024
414260a
Fix deflateSVD for block deflation
maddyscientist Jul 31, 2024
181a52e
Add some sanity checking when using split grid
maddyscientist Aug 28, 2024
4447320
If communicator is not found, do not call errorQuda (which causes an …
maddyscientist Aug 28, 2024
994bdd8
Fix some verbosity aspects of tuning
maddyscientist Aug 29, 2024
49c6c8a
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-s…
maddyscientist Aug 30, 2024
ceb84fa
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-s…
maddyscientist Aug 30, 2024
35e17db
Check set sizes match when copying between them
maddyscientist Sep 3, 2024
661a2a1
Multi-RHS solvers should check to see if their state needs to be resized
maddyscientist Sep 3, 2024
2dd2502
Add iterator-pair constructor for quda::vector class
maddyscientist Sep 4, 2024
3ad7a57
MRHS optimizations for eigensolver: exposed new parameter QudaInvertP…
maddyscientist Sep 5, 2024
a386654
Preserve eigen space when running multi-src deflated solves
maddyscientist Sep 5, 2024
2eb1289
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-s…
maddyscientist Sep 5, 2024
0ed1fe9
Fix CI warnings (one of which was a real bug)
maddyscientist Sep 6, 2024
75a3b4c
More CI warnings
maddyscientist Sep 6, 2024
37a6ae9
Fix some outstanding CI issues
maddyscientist Sep 6, 2024
fc0762d
renaming as suggested in CI
maddyscientist Sep 6, 2024
44d8a2a
Use std::vector iterator constructor
maddyscientist Sep 6, 2024
e89be7d
Revert change made in this branch
maddyscientist Sep 6, 2024
dcd0d43
Cleanup of DiracCloverHasenbuschTwistPC
maddyscientist Sep 6, 2024
f0d9f3c
Cleanup ero source checking in the solvers
maddyscientist Sep 6, 2024
36c9b6d
CGNR and CGNE are now MRHS
maddyscientist Sep 10, 2024
d5c6708
CG3 is now MRHS
maddyscientist Sep 10, 2024
2e52dad
Remove derived CGNR and CGNE specialziations for CG/CA-CG/CG3: we now…
maddyscientist Sep 11, 2024
0803993
Optimize HQ in CA-GCR
maddyscientist Sep 11, 2024
e1589e5
CA-CG is now MRHS
maddyscientist Sep 11, 2024
a3186a0
BiCGStab is now MRHS
maddyscientist Sep 12, 2024
cbf5943
BiCGStab(l) is now multi-RHS
maddyscientist Sep 16, 2024
dce23c6
Fix typo. Closes #1492
maddyscientist Sep 17, 2024
827700d
Use fine-grain parallelization for CopySpinor
maddyscientist Sep 18, 2024
ade0e16
Ensure that mg_eig_evals_batch_size in test code has sensible default
maddyscientist Sep 18, 2024
36ba139
Updated MILC interface to batched CG, hq tolerance bugfix in CG itself
weinbe2 Sep 23, 2024
50e7879
PCG is now multi-RHS ready. Improve robustness of Solver convergence…
maddyscientist Sep 26, 2024
c7f05b5
Fix default asan options which broke when the separate test library w…
maddyscientist Sep 26, 2024
0073449
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-s…
maddyscientist Sep 26, 2024
2f3faab
Conditionally print energy information
maddyscientist Sep 26, 2024
5af9d01
Reduce memory for clover force (use smaller halo for extended field
maddyscientist Sep 27, 2024
764fb78
Fix compiler warning
maddyscientist Sep 27, 2024
7523f3a
Fix bug in CA CG
maddyscientist Sep 27, 2024
e46b436
Work arounds for NVSHMEM due to coarse grained synchronization used i…
maddyscientist Sep 27, 2024
aa598b6
Merge branch 'develop' of github.com:lattice/quda into feature/mrhs-s…
maddyscientist Sep 27, 2024
093e0e6
Split grid true residuals now correctly returned in QudaInvertParam s…
maddyscientist Oct 1, 2024
a3e889d
invert_test should never call invertMultiSrcQuda is multishift is int…
maddyscientist Oct 1, 2024
9933e4d
Some MRHS HQ related solver fixes
maddyscientist Oct 1, 2024
da38153
More MRHS solver fixes
maddyscientist Oct 1, 2024
e5d74e4
Add QudaInvertParam::energy/power/temp/clock to check_param
maddyscientist Oct 1, 2024
3bfcc68
Wilson ctest invert_test now uses multi-RHS
maddyscientist Oct 1, 2024
d3217ad
Fix memory freeing with chrono predictor
maddyscientist Oct 1, 2024
1e28221
Fix CG3 for MRHS
maddyscientist Oct 2, 2024
9936354
Fix clover force test
maddyscientist Oct 2, 2024
a0184d6
Heterogeneous reductions now break up the device-local partial read a…
maddyscientist Oct 3, 2024
e51c59c
ctest should use mrhs for asqtad solver test
maddyscientist Oct 3, 2024
70a3b75
Fix typo
maddyscientist Oct 4, 2024
9d4abe9
Fix for QudaMultigridParam::dslash_use_mma so that it respects the co…
maddyscientist Oct 7, 2024
a9ef50b
Apply clang format
maddyscientist Oct 7, 2024
5c3192a
Updated the MILC HISQ MG interface for setting batch sizes
weinbe2 Oct 8, 2024
7903288
Set QudaMultigridParam::n_vec_batch to invalid to force user to set t…
maddyscientist Oct 8, 2024
e862afd
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
maddyscientist Oct 8, 2024
7a5fb37
Made nvec_batch more robust in the MILC HISQ MG interface
weinbe2 Oct 8, 2024
4af762d
Merge branch 'feature/mrhs-solvers' of https://github.com/lattice/qud…
weinbe2 Oct 8, 2024
2d56bfd
bump CPM (silences some warnings with newer cmake)
mathiaswagner Oct 9, 2024
05b2bc6
Fix typo
maddyscientist Oct 9, 2024
4cef59f
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
maddyscientist Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmake/CPM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors

set(CPM_DOWNLOAD_VERSION 0.38.5)
set(CPM_HASH_SUM "192aa0ccdc57dfe75bd9e4b176bf7fb5692fd2b3e3f7b09c74856fc39572b31c")
set(CPM_DOWNLOAD_VERSION 0.40.2)
set(CPM_HASH_SUM "c8cdc32c03816538ce22781ed72964dc864b2a34a310d3b7104812a5ca2d835d")

if(CPM_SOURCE_CACHE)
set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake")
Expand Down
9 changes: 7 additions & 2 deletions include/accelerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ namespace quda
* @param out Solution vector.
* @param in Right-hand side.
*/
virtual void operator()(ColorSpinorField &out, ColorSpinorField &in)
virtual void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) override
maddyscientist marked this conversation as resolved.
Show resolved Hide resolved
{
for (auto i = 0u; i < in.size(); i++) operator()(out[i], in[i]);
}

void operator()(ColorSpinorField &out, const ColorSpinorField &in)
{
if (transformer.trained) {
transformer.apply(*base_solver, out, in);
Expand All @@ -64,7 +69,7 @@ namespace quda
* @param null Solver to solve for null vectors.
* @param in meta color spinor field.
*/
virtual void train_param(Solver &null, ColorSpinorField &in)
virtual void train_param(Solver &null, const ColorSpinorField &in) override
{
if (!active_training && !transformer.trained) {
active_training = true;
Expand Down
102 changes: 96 additions & 6 deletions include/blas_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace quda {

inline void copy(cvector_ref<ColorSpinorField> &dst, cvector_ref<const ColorSpinorField> &src)
{
if (dst.size() != src.size()) errorQuda("Mismatched vector sets %lu != %lu", dst.size(), src.size());
for (auto i = 0u; i < src.size(); i++) { dst[i].copy(src[i]); }
}

Expand Down Expand Up @@ -293,7 +294,7 @@ namespace quda {

inline array<double, 2> max_deviation(const ColorSpinorField &x, const ColorSpinorField &y)
weinbe2 marked this conversation as resolved.
Show resolved Hide resolved
{
return max_deviation(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(y));
return max_deviation(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(y))[0];
}

/**
Expand All @@ -302,13 +303,15 @@ namespace quda {
*/
cvector<double> norm1(cvector_ref<const ColorSpinorField> &x);

inline double norm1(const ColorSpinorField &x) { return norm1(cvector_ref<const ColorSpinorField> {x})[0]; }

/**
@brief Compute the L2 norm (||x||^2) of a field
@param[in] x The field we are reducing
*/
cvector<double> norm2(cvector_ref<const ColorSpinorField> &x);

inline double norm2(const ColorSpinorField &x) { return norm2(cvector_ref<const ColorSpinorField> {x}); }
inline double norm2(const ColorSpinorField &x) { return norm2(cvector_ref<const ColorSpinorField> {x})[0]; }

/**
@brief Compute y += a * x and then (x, y)
Expand All @@ -319,6 +322,11 @@ namespace quda {
cvector<double> axpyReDot(cvector<double> &a, cvector_ref<const ColorSpinorField> &x,
cvector_ref<ColorSpinorField> &y);

inline double axpyReDot(double a, const ColorSpinorField &x, ColorSpinorField &y)
{
return axpyReDot(cvector<double>(a), cvector_ref<const ColorSpinorField>(x), y)[0];
}

/**
@brief Compute the real-valued inner product (x, y)
@param[in] x input vector
Expand All @@ -328,7 +336,7 @@ namespace quda {

inline double reDotProduct(const ColorSpinorField &x, const ColorSpinorField &y)
{
return reDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y});
return reDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
}

/**
Expand All @@ -342,6 +350,12 @@ namespace quda {
cvector<double> axpbyzNorm(cvector<double> &a, cvector_ref<const ColorSpinorField> &x, cvector<double> &b,
cvector_ref<const ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z);

inline double axpbyzNorm(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y,
ColorSpinorField &z)
{
return axpbyzNorm(cvector<double>(a), cvector_ref<const ColorSpinorField>(x), cvector<double>(b), y, z)[0];
}

/**
@brief Compute y += a * x and then ||y||^2
@param[in] a scalar multiplier
Expand All @@ -354,6 +368,11 @@ namespace quda {
return axpbyzNorm(a, x, 1.0, y, y);
}

inline double axpyNorm(double a, const ColorSpinorField &x, ColorSpinorField &y)
{
return axpyNorm(a, cvector_ref<const ColorSpinorField> {x}, cvector_ref<ColorSpinorField> {y})[0];
}

/**
@brief Compute the complex-valued inner product (x, y)
@param[in] x input vector
Expand All @@ -363,7 +382,7 @@ namespace quda {

inline Complex cDotProduct(const ColorSpinorField &x, const ColorSpinorField &y)
{
return cDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y});
return cDotProduct(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
}

/**
Expand All @@ -373,6 +392,11 @@ namespace quda {
*/
cvector<double4> cDotProductNormAB(cvector_ref<const ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &y);

inline double4 cDotProductNormAB(const ColorSpinorField &x, const ColorSpinorField &y)
{
return cDotProductNormAB(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
}

/**
@brief Return complex-valued inner product (x,y) and ||x||^2
@param[in] x input vector
Expand All @@ -387,6 +411,11 @@ namespace quda {
return a;
}

inline double3 cDotProductNormA(const ColorSpinorField &x, const ColorSpinorField &y)
{
return cDotProductNormA(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
}

/**
@brief Return complex-valued inner product (x,y) and ||y||^2
@param[in] x input vector
Expand All @@ -401,6 +430,11 @@ namespace quda {
return a;
}

inline double3 cDotProductNormB(const ColorSpinorField &x, const ColorSpinorField &y)
{
return cDotProductNormB(cvector_ref<const ColorSpinorField> {x}, cvector_ref<const ColorSpinorField> {y})[0];
}

/**
@brief Apply the operation z += a * x + b * y, y -= b * w,
compute complex-valued inner product (u, y) and ||y||^2
Expand All @@ -418,6 +452,14 @@ namespace quda {
cvector_ref<const ColorSpinorField> &w,
cvector_ref<const ColorSpinorField> &u);

inline double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, const ColorSpinorField &x, const Complex &b,
ColorSpinorField &y, ColorSpinorField &z, const ColorSpinorField &w,
const ColorSpinorField &u)
{
return caxpbypzYmbwcDotProductUYNormY(cvector<Complex>(a), cvector_ref<const ColorSpinorField>(x), b, y, z, w,
u)[0];
}

/**
@brief Compute y = a * x + b * y and then ||y||^2
@param[in] a scalar multiplier
Expand All @@ -440,6 +482,11 @@ namespace quda {
return caxpbyNorm(a, x, 1.0, y);
}

inline double caxpyNorm(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y)
{
return caxpyNorm(a, cvector_ref<const ColorSpinorField> {x}, cvector_ref<ColorSpinorField> {y})[0];
}

/**
@brief Compute y -= x and then ||y||^2
@param[in] x input vector
Expand All @@ -450,6 +497,11 @@ namespace quda {
return caxpbyNorm(1.0, x, -1.0, y);
}

inline double xmyNorm(const ColorSpinorField &x, ColorSpinorField &y)
{
return xmyNorm(cvector_ref<const ColorSpinorField> {x}, cvector_ref<ColorSpinorField> {y})[0];
}

/**
@brief Compute z = a * b * x + y, x = a * x, and then ||z||^2
@param[in] a scalar multiplier
Expand All @@ -461,6 +513,12 @@ namespace quda {
cvector<double> cabxpyzAxNorm(cvector<double> &a, cvector<Complex> &b, cvector_ref<ColorSpinorField> &x,
cvector_ref<const ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z);

inline double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, const ColorSpinorField &y,
ColorSpinorField &z)
{
return cabxpyzAxNorm(cvector<double>(a), cvector<Complex>(b), cvector_ref<ColorSpinorField>(x), y, z)[0];
}

/**
@brief Compute y += a * x and the resulting complex-valued inner product (z, y)
@param[in] a scalar multiplier
Expand All @@ -471,6 +529,11 @@ namespace quda {
cvector<Complex> caxpyDotzy(cvector<Complex> &a, cvector_ref<const ColorSpinorField> &x,
cvector_ref<ColorSpinorField> &y, cvector_ref<const ColorSpinorField> &z);

inline Complex caxpyDotzy(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z)
{
return caxpyDotzy(cvector<Complex>(a), cvector_ref<const ColorSpinorField>(x), y, z)[0];
}

/**
@brief Compute y += a * x and then compute ||y||^2 and
real-valued inner product (y_out, y_out-y_in)
Expand All @@ -481,6 +544,11 @@ namespace quda {
cvector<double2> axpyCGNorm(cvector<double> &a, cvector_ref<const ColorSpinorField> &x,
cvector_ref<ColorSpinorField> &y);

inline double2 axpyCGNorm(double a, const ColorSpinorField &x, ColorSpinorField &y)
{
return axpyCGNorm(cvector<double>(a), cvector_ref<const ColorSpinorField>(x), y)[0];
}

/**
@brief Computes ||x||^2, ||r||^2 and the MILC/FNAL heavy quark
residual norm
Expand All @@ -492,7 +560,7 @@ namespace quda {

inline double3 HeavyQuarkResidualNorm(const ColorSpinorField &x, const ColorSpinorField &r)
{
return HeavyQuarkResidualNorm(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(r));
return HeavyQuarkResidualNorm(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(r))[0];
}

/**
Expand All @@ -510,7 +578,7 @@ namespace quda {
const ColorSpinorField &r)
{
return xpyHeavyQuarkResidualNorm(cvector_ref<const ColorSpinorField>(x), cvector_ref<const ColorSpinorField>(y),
cvector_ref<const ColorSpinorField>(r));
cvector_ref<const ColorSpinorField>(r))[0];
}

/**
Expand All @@ -522,6 +590,11 @@ namespace quda {
cvector<double3> tripleCGReduction(cvector_ref<const ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &y,
cvector_ref<const ColorSpinorField> &z);

inline double3 tripleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z)
{
return tripleCGReduction(cvector_ref<const ColorSpinorField>(x), y, z)[0];
}

/**
@brief Computes ||x||^2, ||y||^2, the real-valued inner product (y, z), and ||z||^2
@param[in] x input vector
Expand All @@ -531,6 +604,11 @@ namespace quda {
cvector<double4> quadrupleCGReduction(cvector_ref<const ColorSpinorField> &x, cvector_ref<const ColorSpinorField> &y,
cvector_ref<const ColorSpinorField> &z);

inline double4 quadrupleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z)
{
return quadrupleCGReduction(cvector_ref<const ColorSpinorField>(x), y, z)[0];
}

/**
@brief Computes z = x, w = y, x += a * y, y -= a * v and ||y||^2
@param[in] a scalar multiplier
Expand All @@ -544,6 +622,12 @@ namespace quda {
cvector_ref<ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z,
cvector_ref<ColorSpinorField> &w, cvector_ref<const ColorSpinorField> &v);

inline double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z,
ColorSpinorField &w, const ColorSpinorField &v)
{
return quadrupleCG3InitNorm(cvector<double>(a), cvector_ref<ColorSpinorField>(x), y, z, w, v)[0];
}

/**
@brief Computes x = b * (x + a * y) + ( 1 - b) * z,
y = b * (y + a * v) + (1 - b) * w, z = x_in, w = y_in, and
Expand All @@ -560,6 +644,12 @@ namespace quda {
cvector_ref<ColorSpinorField> &y, cvector_ref<ColorSpinorField> &z,
cvector_ref<ColorSpinorField> &w, cvector_ref<const ColorSpinorField> &v);

inline double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y,
ColorSpinorField &z, ColorSpinorField &w, const ColorSpinorField &v)
{
return quadrupleCG3UpdateNorm(cvector<double>(a), b, cvector_ref<ColorSpinorField>(x), y, z, w, v)[0];
}

namespace block
{

Expand Down
30 changes: 30 additions & 0 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,36 @@ namespace quda
void resize(std::vector<ColorSpinorField> &v, size_t new_size, QudaFieldCreate create,
const ColorSpinorField &src = ColorSpinorField());

/**
@brief Create a vector of fields that aliases another vector of
fields' storage. The alias field can use a different precision
than this field, though it cannot be greater. This
functionality is useful for the case where we have multiple
temporaries in different precisions, but do not need them
simultaneously. Use this functionality with caution.
@param[out] alias The vector of aliased fields
@param[in] v The vector of fields to alias
@param[in] param Parameters for the alias field
*/
void create_alias(cvector_ref<ColorSpinorField> &alias, cvector_ref<const ColorSpinorField> &v,
const ColorSpinorParam &param = ColorSpinorParam());

/**
@brief Create a vector of fields that aliases another vector of
fields' storage. The alias field can use a different precision
than this field, though it cannot be greater. This functionality
is useful for the case where we have multiple temporaries in
different precisions, but do not need them simultaneously. This
variant is used with std::vector as opposed to vector_ref, and
allows for correct resizing. Use this functionality with
caution.
@param[out] alias The vector of aliased fields
@param[in] v The vector of fields to alias
@param[in] param Parameters for the alias field
*/
void create_alias(std::vector<ColorSpinorField> &alias, cvector_ref<const ColorSpinorField> &v,
const ColorSpinorParam &param = ColorSpinorParam());

void copyGenericColorSpinor(ColorSpinorField &dst, const ColorSpinorField &src, QudaFieldLocation location,
void *Dst = nullptr, const void *Src = nullptr);

Expand Down
5 changes: 0 additions & 5 deletions include/dirac_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ namespace quda {
DiracWilson(const DiracWilson &dirac);
DiracWilson(const DiracParam &param, const int nDims); // to correctly adjust face for DW and non-deg twisted mass

virtual ~DiracWilson();
bjoo marked this conversation as resolved.
Show resolved Hide resolved
DiracWilson& operator=(const DiracWilson &dirac);

virtual void Dslash(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
Expand Down Expand Up @@ -592,7 +591,6 @@ namespace quda {
public:
DiracWilsonPC(const DiracParam &param);
DiracWilsonPC(const DiracWilsonPC &dirac);
virtual ~DiracWilsonPC();
DiracWilsonPC& operator=(const DiracWilsonPC &dirac);

void M(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override;
Expand Down Expand Up @@ -1962,7 +1960,6 @@ namespace quda {
@param[in] param Parameters defining this operator
*/
DiracCoarse(const DiracCoarse &dirac, const DiracParam &param);
virtual ~DiracCoarse();

virtual bool isCoarse() const override { return true; }

Expand Down Expand Up @@ -2108,8 +2105,6 @@ namespace quda {
*/
DiracCoarsePC(const DiracCoarse &dirac, const DiracParam &param);

virtual ~DiracCoarsePC();

/**
@brief Apply preconditioned Dslash out = (D * in)
@param[out] out Output field
Expand Down
Loading
Loading