Skip to content

Commit

Permalink
update computes for env motif match and rmsd min with a wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
DomFijan committed Nov 12, 2024
1 parent e6fa0ba commit ccb8005
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
2 changes: 1 addition & 1 deletion freud/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,7 @@ def compute(self, system, threshold, cluster_neighbors=None,
def cluster_idx(self):
""":math:`\\left(N_{particles}\\right)` :class:`numpy.ndarray`: The
per-particle index indicating cluster membership."""
return self._cpp_obj.getClusterIdx().toNumpyArray()
return self._cpp_obj.getClusters().toNumpyArray()

@_Compute._computed_property
def num_clusters(self):
Expand Down
70 changes: 45 additions & 25 deletions freud/environment/export-MatchEnv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,65 @@ template<typename T, typename shape>
using nb_array = nanobind::ndarray<T, shape, nanobind::device::cpu, nanobind::c_contig>;

namespace wrap {
void compute(const std::shared_ptr<MatchEnv>& match_env,
void compute_env_motif_match(const std::shared_ptr<EnvironmentMotifMatch>& env_motif_match,
std::shared_ptr<locality::NeighborQuery> nq,
const nb_array<float, nanobind::shape<-1, 3>>& query_points,
const unsigned int n_query_points,
const nb_array<float, nanobind::shape<-1, 4>>& orientations,
std::shared_ptr<locality::NeighborList> nlist,
const locality::QueryArgs& qargs,
const unsigned int max_num_neighbors
const nb_array<float, nanobind::shape<-1, 3>>& motif,
const unsigned int motif_size,
const float threshold,
const bool registration
)
{
auto* query_points_data = reinterpret_cast<vec3<float>*>(query_points.data());
auto* orientations_data = reinterpret_cast<quat<float>*>(orientations.data());
match_env->compute(nq, query_points_data, n_query_points, orientations_data, nlist, qargs, max_num_neighbors);
}
{
auto* motif_data = reinterpret_cast<vec3<float>*>(motif.data());
env_motif_match->compute(nq, nlist, qargs, motif_data, motif_size, threshold, registration);
}

void compute_env_rmsd_min(const std::shared_ptr<EnvironmentRMSDMinimizer>& env_rmsd_min,
std::shared_ptr<locality::NeighborQuery> nq,
std::shared_ptr<locality::NeighborList> nlist,
const locality::QueryArgs& qargs,
const nb_array<float, nanobind::shape<-1, 3>>& motif,
const unsigned int motif_size,
const float threshold,
const bool registration
)
{
auto* motif_data = reinterpret_cast<vec3<float>*>(motif.data());
env_rmsd_min->compute(nq, nlist, qargs, motif_data, motif_size, threshold, registration);
}

};

namespace detail {

void export_MatchEnv(nb::module_& module)
{
// export minimizeRMSD function
// export isSimilar function
// export MatchEnv class
// export getPointEnvironments fn
// export minimizeRMSD function, move convenience to wrap? TODO
nb::function("minimizeRMSD", &minimizeRMSD);//carefull ths fn is overloaded for easier python interactivity. You should use the one that takes box etc in.
// export isSimilar function, move convenience to wrap? TODO
nb::function("isSimilar", &isSimilar); //carefull ths fn is overloaded for easier python interactivity. You should use the one that takes box etc in.

nb::class_<MatchEnv>(module, "MatchEnv")
.def(nb::init<>)
.def("getPointEnvironments", &MatchEnv::getPointEnvironments)
// export EnvironmentCluster class
// export compute fn
// export getClusterIdx fn
// export getClusterEnvironments fn
// export getNumClusters fn
// export EnvironmentMotifMatch class
// export compute fn
// export getMatches fn
// export EnvironmentRMSDMinimizer class
// export compute fn
// export getRMSDs fn

nb::class_<EnvironmentCluster>(module, "EnvironmentCluster")
.def(nb::init<>())
.def("compute", &EnvironmentCluster::compute)
.def("getClusters", &EnvironmentCluster::getClusterIdx)
.def("getClusterEnvironments", &EnvironmentCluster::getClusterEnvironments)
.def("getNumClusters", &EnvironmentCluster::getNumClusters)

nb::class_<EnvironmentMotifMatch>(module, "EnvironmentMotifMatch")
.def(nb::init<>())
.def("compute", &wrap::compute_env_motif_match, nb::arg("nq"), nb::arg("nlist"), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration"))
.def("getMatches", &EnvironmentMotifMatch::getMatches)

nb::class_<EnvironmentRMSDMinimizer>(module, "EnvironmentRMSDMinimizer")
.def(nb::init<>())
.def("compute", &wrap::compute_env_rmsd_min, nb::arg("nq"), nb::arg("nlist"), nb::arg("qargs"), nb::arg("motif"), nb::arg("motif_size"), nb::arg("threshold"), nb::arg("registration"))
.def("getRMSDs", &EnvironmentRMSDMinimizer::getRMSDs)

}

Expand Down

0 comments on commit ccb8005

Please sign in to comment.