Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/Ions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,29 @@ void Ions::writeForces(HDFrestart& h5f_file)
}
}

void Ions::setLocalForces(
const std::vector<double>& forces, const std::vector<std::string>& names)
{
assert(forces.size() == 3 * names.size());

// loop over global list of forces and atom names
std::vector<std::string>::const_iterator s = names.begin();
for (auto it = forces.begin(); it != forces.end(); it += 3)
{
// find possible matching ion
for (auto& ion : local_ions_)
{
if (ion->compareName(*s))
{
ion->set_force(0, *it);
ion->set_force(1, *(it + 1));
ion->set_force(2, *(it + 2));
}
}
s++;
}
}

// Writes out the postions of the ions and the current forces on them by root
void Ions::printForcesGlobal(std::ostream& os, const int root) const
{
Expand Down Expand Up @@ -2187,6 +2210,22 @@ void Ions::getLocalPositions(std::vector<double>& tau) const
}
}

void Ions::getLocalNames(std::vector<std::string>& names) const
{
for (auto& ion : local_ions_)
{
names.push_back(ion->name());
}
}

void Ions::getNames(std::vector<std::string>& names) const
{
for (auto& ion : list_ions_)
{
names.push_back(ion->name());
}
}

void Ions::getPositions(std::vector<double>& tau)
{
std::vector<double> tau_local(3 * local_ions_.size());
Expand Down
9 changes: 9 additions & 0 deletions src/Ions.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,20 @@ class Ions
const std::vector<double>& tau, const std::vector<short>& anumbers);

void getLocalPositions(std::vector<double>& tau) const;
void getLocalNames(std::vector<std::string>& names) const;
void getNames(std::vector<std::string>& names) const;
void getPositions(std::vector<double>& tau);
void getAtomicNumbers(std::vector<short>& atnumbers);

void getForces(std::vector<double>& forces);
void getLocalForces(std::vector<double>& tau) const;

/*!
* set forces for ions in local_ions_ based on names matching
*/
void setLocalForces(const std::vector<double>& forces,
const std::vector<std::string>& names);

void syncData(const std::vector<Species>& sp);
// void syncNames(const int nions, std::vector<std::string>& local_names,
// std::vector<std::string>& names);
Expand Down
45 changes: 40 additions & 5 deletions tests/testIons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

int main(int argc, char** argv)
{
int status = 0;

int mpirc = MPI_Init(&argc, &argv);

MPI_Comm comm = MPI_COMM_WORLD;
Expand Down Expand Up @@ -92,7 +94,7 @@ int main(int argc, char** argv)
if (ntotal != na)
{
std::cout << "ntotal = " << ntotal << std::endl;
return 1;
status = 1;
}
}
MPI_Barrier(MPI_COMM_WORLD);
Expand Down Expand Up @@ -141,16 +143,49 @@ int main(int argc, char** argv)
MPI_Allreduce(&nlocal, &ntotal, 1, MPI_INT, MPI_SUM, comm);
if (ntotal != na)
{
std::cout << "ntotal = " << ntotal << std::endl;
return 1;
std::cerr << "ntotal = " << ntotal << std::endl;
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}
}

// get the names of all the ions
std::vector<std::string> names;
ions.getNames(names);
if (myrank == 0)
for (auto& name : names)
std::cout << "name = " << name << std::endl;
if (names.size() != na)
{
std::cerr << "Incorrect count of names..." << std::endl;
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}
MPI_Barrier(MPI_COMM_WORLD);

std::vector<double> forces(3 * na);
// arbitrary value
const double fval = 1.12;
for (auto& f : forces)
f = fval;
ions.setLocalForces(forces, names);

int nlocal = ions.getNumLocIons();
std::vector<double> lforces(3 * nlocal);
ions.getLocalForces(lforces);
for (auto& f : lforces)
{
if (std::abs(f - fval) > 1.e-14)
{
std::cerr << "f = " << f << std::endl;
MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
}
}

mpirc = MPI_Finalize();
if (mpirc != MPI_SUCCESS)
{
std::cerr << "MPI Finalize failed!!!" << std::endl;
return 1;
status = 1;
}

return 0;
return status;
}