Skip to content

Commit

Permalink
Add particle neighbor kernel method (#4662)
Browse files Browse the repository at this point in the history
Closes #4620

Description of changes:
- add a short-range loop method to generate a list of the particle ids of all particles within interaction range of a central particle (to be used in machine-learned potentials based on descriptors of a particle's neighborhood)
- expose `cell_structure.max_range()` to the python interface for unit testing
- refactor short-range loop methods (dependency inversion and separation of concerns)
  • Loading branch information
kodiakhq[bot] authored Feb 28, 2023
2 parents 1945d5b + f929cb7 commit e6fdd8a
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 78 deletions.
138 changes: 71 additions & 67 deletions src/core/cells.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
#include "integrate.hpp"
#include "particle_node.hpp"

#include <utils/Vector.hpp>
#include <utils/math/sqr.hpp>
#include <utils/mpi/gather_buffer.hpp>

#include <boost/mpi/collectives/all_reduce.hpp>
#include <boost/range/algorithm/min_element.hpp>
Expand All @@ -61,19 +61,18 @@ CellStructure cell_structure{box_geo};
* filter criterion.
*
* It uses link_cell to get pairs out of the cellsystem
* by a simple distance criterion and
* by a simple distance criterion and applies the filter on both particles.
*
* Pairs are sorted so that first.id < second.id
*/
template <class Filter>
std::vector<std::pair<int, int>> get_pairs_filtered(double const distance,
Filter filter) {
std::vector<std::pair<int, int>> ret;
on_observable_calc();
auto const cutoff2 = distance * distance;
auto pair_kernel = [&ret, &cutoff2, &filter](Particle const &p1,
Particle const &p2,
Distance const &d) {
auto const cutoff2 = Utils::sqr(distance);
auto const pair_kernel = [cutoff2, &filter, &ret](Particle const &p1,
Particle const &p2,
Distance const &d) {
if (d.dist2 < cutoff2 and filter(p1) and filter(p2))
ret.emplace_back(p1.id(), p2.id());
};
Expand All @@ -89,109 +88,114 @@ std::vector<std::pair<int, int>> get_pairs_filtered(double const distance,
return ret;
}

namespace boost {
namespace serialization {
template <class Archive>
void serialize(Archive &ar, PairInfo &p, const unsigned int /* version */) {
ar &p.id1;
ar &p.id2;
ar &p.pos1;
ar &p.pos2;
ar &p.vec21;
ar &p.node;
}
} // namespace serialization
} // namespace boost

namespace detail {
static void search_distance_sanity_check(double const distance) {
static auto get_max_neighbor_search_range() {
return *boost::min_element(cell_structure.max_range());
}
static void search_distance_sanity_check_max_range(double const distance) {
/* get_pairs_filtered() finds pairs via the non_bonded_loop. The maximum
* finding range is therefore limited by the decomposition that is used.
*/
auto const range = *boost::min_element(cell_structure.max_range());
if (distance > range) {
auto const max_range = get_max_neighbor_search_range();
if (distance > max_range) {
throw std::domain_error("pair search distance " + std::to_string(distance) +
" bigger than the decomposition range " +
std::to_string(range));
std::to_string(max_range));
}
}
static void search_neighbors_sanity_check(double const distance) {
search_distance_sanity_check(distance);
static void search_distance_sanity_check_cell_structure(double const distance) {
if (cell_structure.decomposition_type() ==
CellStructureType::CELL_STRUCTURE_HYBRID) {
throw std::runtime_error("Cannot search for neighbors in the hybrid "
"decomposition cell system");
}
}
static void search_neighbors_sanity_checks(double const distance) {
search_distance_sanity_check_max_range(distance);
search_distance_sanity_check_cell_structure(distance);
}
} // namespace detail

boost::optional<std::vector<int>>
mpi_get_short_range_neighbors_local(int const pid, double const distance,
bool run_sanity_checks) {

if (run_sanity_checks) {
detail::search_neighbors_sanity_check(distance);
}
on_observable_calc();

auto const p = cell_structure.get_local_particle(pid);
if (not p or p->is_ghost()) {
return {};
}

get_short_range_neighbors(int const pid, double const distance) {
detail::search_neighbors_sanity_checks(distance);
std::vector<int> ret;
auto const cutoff2 = distance * distance;
auto kernel = [&ret, cutoff2](Particle const &p1, Particle const &p2,
Utils::Vector3d const &vec) {
auto const cutoff2 = Utils::sqr(distance);
auto const kernel = [cutoff2, &ret](Particle const &, Particle const &p2,
Utils::Vector3d const &vec) {
if (vec.norm2() < cutoff2) {
ret.emplace_back(p2.id());
}
};
cell_structure.run_on_particle_short_range_neighbors(*p, kernel);
return {ret};
auto const p = ::cell_structure.get_local_particle(pid);
if (p and not p->is_ghost()) {
::cell_structure.run_on_particle_short_range_neighbors(*p, kernel);
return {ret};
}
return {};
}

REGISTER_CALLBACK_ONE_RANK(mpi_get_short_range_neighbors_local)

std::vector<int> mpi_get_short_range_neighbors(int const pid,
double const distance) {
detail::search_neighbors_sanity_check(distance);
return mpi_call(::Communication::Result::one_rank,
mpi_get_short_range_neighbors_local, pid, distance, false);
/**
* @brief Get pointers to all interacting neighbors of a central particle.
*/
static auto get_interacting_neighbors(Particle const &p) {
auto const distance = *boost::min_element(::cell_structure.max_range());
detail::search_neighbors_sanity_checks(distance);
std::vector<Particle const *> ret;
auto const cutoff2 = Utils::sqr(distance);
auto const kernel = [cutoff2, &ret](Particle const &, Particle const &p2,
Utils::Vector3d const &vec) {
if (vec.norm2() < cutoff2) {
ret.emplace_back(&p2);
}
};
::cell_structure.run_on_particle_short_range_neighbors(p, kernel);
return ret;
}

std::vector<std::pair<int, int>> get_pairs(double const distance) {
detail::search_distance_sanity_check(distance);
auto pairs =
get_pairs_filtered(distance, [](Particle const &) { return true; });
Utils::Mpi::gather_buffer(pairs, comm_cart);
return pairs;
detail::search_neighbors_sanity_checks(distance);
return get_pairs_filtered(distance, [](Particle const &) { return true; });
}

std::vector<std::pair<int, int>>
get_pairs_of_types(double const distance, std::vector<int> const &types) {
detail::search_distance_sanity_check(distance);
auto pairs = get_pairs_filtered(distance, [types](Particle const &p) {
detail::search_neighbors_sanity_checks(distance);
return get_pairs_filtered(distance, [types](Particle const &p) {
return std::any_of(types.begin(), types.end(),
// NOLINTNEXTLINE(bugprone-exception-escape)
[p](int const type) { return p.type() == type; });
});
Utils::Mpi::gather_buffer(pairs, comm_cart);
return pairs;
}

std::vector<PairInfo> non_bonded_loop_trace() {
std::vector<PairInfo> non_bonded_loop_trace(int const rank) {
std::vector<PairInfo> pairs;
auto pair_kernel = [&pairs](Particle const &p1, Particle const &p2,
Distance const &d) {
pairs.emplace_back(p1.id(), p2.id(), p1.pos(), p2.pos(), d.vec21,
comm_cart.rank());
auto const pair_kernel = [&pairs, rank](Particle const &p1,
Particle const &p2,
Distance const &d) {
pairs.emplace_back(p1.id(), p2.id(), p1.pos(), p2.pos(), d.vec21, rank);
};
cell_structure.non_bonded_loop(pair_kernel);
Utils::Mpi::gather_buffer(pairs, comm_cart);
return pairs;
}

std::vector<NeighborPIDs> get_neighbor_pids() {
std::vector<NeighborPIDs> ret;
auto kernel = [&ret](Particle const &p,
std::vector<Particle const *> const &neighbors) {
std::vector<int> neighbor_pids;
neighbor_pids.reserve(neighbors.size());
for (auto const &neighbor : neighbors) {
neighbor_pids.emplace_back(neighbor->id());
}
ret.emplace_back(p.id(), neighbor_pids);
};
for (auto const &p : ::cell_structure.local_particles()) {
kernel(p, get_interacting_neighbors(p));
}
return ret;
}

void set_hybrid_decomposition(std::set<int> n_square_types,
double cutoff_regular) {
cell_structure.set_hybrid_decomposition(comm_cart, cutoff_regular, box_geo,
Expand Down
48 changes: 43 additions & 5 deletions src/core/cells.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@

#include "Particle.hpp"

#include <utils/Vector.hpp>

#include <boost/optional.hpp>

#include <utility>
Expand Down Expand Up @@ -114,10 +116,32 @@ void check_resort_particles();
* @brief Get ids of particles that are within a certain distance
* of another particle.
*/
std::vector<int> mpi_get_short_range_neighbors(int pid, double distance);
boost::optional<std::vector<int>>
mpi_get_short_range_neighbors_local(int pid, double distance,
bool run_sanity_checks);
boost::optional<std::vector<int>> get_short_range_neighbors(int pid,
double distance);

struct NeighborPIDs {
NeighborPIDs() = default;
NeighborPIDs(int _pid, std::vector<int> _neighbor_pids)
: pid{_pid}, neighbor_pids{std::move(_neighbor_pids)} {}

int pid;
std::vector<int> neighbor_pids;
};

namespace boost {
namespace serialization {
template <class Archive>
void serialize(Archive &ar, NeighborPIDs &n, unsigned int const /* version */) {
ar &n.pid;
ar &n.neighbor_pids;
}
} // namespace serialization
} // namespace boost

/**
* @brief Returns pairs of particle ids and neighbor particle id lists.
*/
std::vector<NeighborPIDs> get_neighbor_pids();

/**
* @brief Find the cell in which a particle is stored.
Expand All @@ -144,10 +168,24 @@ class PairInfo {
int node;
};

namespace boost {
namespace serialization {
template <class Archive>
void serialize(Archive &ar, PairInfo &p, unsigned int const /* version */) {
ar &p.id1;
ar &p.id2;
ar &p.pos1;
ar &p.pos2;
ar &p.vec21;
ar &p.node;
}
} // namespace serialization
} // namespace boost

/**
* @brief Returns pairs of particle ids, positions and distance as seen by the
* non-bonded loop.
*/
std::vector<PairInfo> non_bonded_loop_trace();
std::vector<PairInfo> non_bonded_loop_trace(int rank);

#endif
3 changes: 2 additions & 1 deletion src/core/reaction_methods/ReactionAlgorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,9 @@ void ReactionAlgorithm::check_exclusion_range(int p_id, int p_type) {
all_ids.end());
particle_ids = all_ids;
} else {
on_observable_calc();
auto const local_ids =
mpi_get_short_range_neighbors_local(p_id, m_max_exclusion_range, true);
get_short_range_neighbors(p_id, m_max_exclusion_range);
assert(p1_ptr == nullptr or !!local_ids);
if (local_ids) {
particle_ids = std::move(*local_ids);
Expand Down
10 changes: 10 additions & 0 deletions src/python/espressomd/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ class Analysis(ScriptInterfaceHelper):
(N,) array_like of :obj:`int`
The neighbouring particle ids.
particle_neighbor_pids()
Get a list of all short-range neighbors for each particle.
Returns
-------
:obj: `dict`
A dictionary where each item is a pair of a particle id and
its respective neighboring particle ids.
calc_re()
Calculate the mean end-to-end distance of chains and its
standard deviation, as well as mean square end-to-end distance of
Expand Down Expand Up @@ -351,6 +360,7 @@ class Analysis(ScriptInterfaceHelper):
"linear_momentum",
"center_of_mass",
"nbhood",
"particle_neighbor_pids",
"calc_re",
"calc_rg",
"calc_rh",
Expand Down
16 changes: 16 additions & 0 deletions src/script_interface/analysis/Analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

#include "core/analysis/statistics.hpp"
#include "core/analysis/statistics_chain.hpp"
#include "core/cells.hpp"
#include "core/dpd.hpp"
#include "core/energy.hpp"
#include "core/event.hpp"
#include "core/grid.hpp"
#include "core/nonbonded_interactions/nonbonded_interaction_data.hpp"
#include "core/partCfg_global.hpp"
Expand All @@ -32,6 +34,7 @@

#include <utils/Vector.hpp>
#include <utils/contains.hpp>
#include <utils/mpi/gather_buffer.hpp>

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -81,6 +84,19 @@ Variant Analysis::do_call_method(std::string const &name,
auto const local = particle_short_range_energy_contribution(pid);
return mpi_reduce_sum(context()->get_comm(), local);
}
if (name == "particle_neighbor_pids") {
on_observable_calc();
std::unordered_map<int, std::vector<int>> dict;
context()->parallel_try_catch([&]() {
auto neighbor_pids = get_neighbor_pids();
Utils::Mpi::gather_buffer(neighbor_pids, context()->get_comm());
std::for_each(neighbor_pids.begin(), neighbor_pids.end(),
[&dict](NeighborPIDs const &neighbor_pid) {
dict[neighbor_pid.pid] = neighbor_pid.neighbor_pids;
});
});
return make_unordered_map_of_variants(dict);
}
if (not context()->is_head_node()) {
return {};
}
Expand Down
Loading

0 comments on commit e6fdd8a

Please sign in to comment.