diff --git a/src/Ions.cc b/src/Ions.cc index bc30417b..10baad10 100644 --- a/src/Ions.cc +++ b/src/Ions.cc @@ -1288,6 +1288,29 @@ void Ions::writeForces(HDFrestart& h5f_file) } } +void Ions::setLocalForces( + const std::vector& forces, const std::vector& names) +{ + assert(forces.size() == 3 * names.size()); + + // loop over global list of forces and atom names + std::vector::const_iterator s = names.begin(); + for (auto it = forces.begin(); it != forces.end(); it += 3) + { + // find possible matching ion + for (auto& ion : local_ions_) + { + if (ion->compareName(*s)) + { + ion->set_force(0, *it); + ion->set_force(1, *(it + 1)); + ion->set_force(2, *(it + 2)); + } + } + s++; + } +} + // Writes out the postions of the ions and the current forces on them by root void Ions::printForcesGlobal(std::ostream& os, const int root) const { @@ -2187,6 +2210,22 @@ void Ions::getLocalPositions(std::vector& tau) const } } +void Ions::getLocalNames(std::vector& names) const +{ + for (auto& ion : local_ions_) + { + names.push_back(ion->name()); + } +} + +void Ions::getNames(std::vector& names) const +{ + for (auto& ion : list_ions_) + { + names.push_back(ion->name()); + } +} + void Ions::getPositions(std::vector& tau) { std::vector tau_local(3 * local_ions_.size()); diff --git a/src/Ions.h b/src/Ions.h index b03972d1..33b8982b 100644 --- a/src/Ions.h +++ b/src/Ions.h @@ -285,11 +285,20 @@ class Ions const std::vector& tau, const std::vector& anumbers); void getLocalPositions(std::vector& tau) const; + void getLocalNames(std::vector& names) const; + void getNames(std::vector& names) const; void getPositions(std::vector& tau); void getAtomicNumbers(std::vector& atnumbers); void getForces(std::vector& forces); void getLocalForces(std::vector& tau) const; + + /*! + * set forces for ions in local_ions_ based on names matching + */ + void setLocalForces(const std::vector& forces, + const std::vector& names); + void syncData(const std::vector& sp); // void syncNames(const int nions, std::vector& local_names, // std::vector& names); diff --git a/tests/testIons.cc b/tests/testIons.cc index 6bfdd74a..e23e0e24 100644 --- a/tests/testIons.cc +++ b/tests/testIons.cc @@ -8,6 +8,8 @@ int main(int argc, char** argv) { + int status = 0; + int mpirc = MPI_Init(&argc, &argv); MPI_Comm comm = MPI_COMM_WORLD; @@ -92,7 +94,7 @@ int main(int argc, char** argv) if (ntotal != na) { std::cout << "ntotal = " << ntotal << std::endl; - return 1; + status = 1; } } MPI_Barrier(MPI_COMM_WORLD); @@ -141,16 +143,49 @@ int main(int argc, char** argv) MPI_Allreduce(&nlocal, &ntotal, 1, MPI_INT, MPI_SUM, comm); if (ntotal != na) { - std::cout << "ntotal = " << ntotal << std::endl; - return 1; + std::cerr << "ntotal = " << ntotal << std::endl; + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + } + } + + // get the names of all the ions + std::vector names; + ions.getNames(names); + if (myrank == 0) + for (auto& name : names) + std::cout << "name = " << name << std::endl; + if (names.size() != na) + { + std::cerr << "Incorrect count of names..." << std::endl; + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + } + MPI_Barrier(MPI_COMM_WORLD); + + std::vector forces(3 * na); + // arbitrary value + const double fval = 1.12; + for (auto& f : forces) + f = fval; + ions.setLocalForces(forces, names); + + int nlocal = ions.getNumLocIons(); + std::vector lforces(3 * nlocal); + ions.getLocalForces(lforces); + for (auto& f : lforces) + { + if (std::abs(f - fval) > 1.e-14) + { + std::cerr << "f = " << f << std::endl; + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); } } + mpirc = MPI_Finalize(); if (mpirc != MPI_SUCCESS) { std::cerr << "MPI Finalize failed!!!" << std::endl; - return 1; + status = 1; } - return 0; + return status; }