From c90526223a8237fa07b6c6cac1a37ce12155454f Mon Sep 17 00:00:00 2001 From: Pablo Zubieta <8410335+pabloferz@users.noreply.github.com> Date: Tue, 7 Feb 2023 16:05:43 -0600 Subject: [PATCH 1/5] Update python versions for CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a0fc9b05..cc451a92 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-20.04, macos-latest] - python-version: [3.7, 3.9] + python-version: [3.8, 3.9] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} on ${{ matrix.os }} From 743ad45ca57379be0c8614213b534b18037aa5f7 Mon Sep 17 00:00:00 2001 From: Malgorzata Zimon <74198137+maggiezimon@users.noreply.github.com> Date: Thu, 19 Jan 2023 15:45:56 -0600 Subject: [PATCH 2/5] Add utils for quaternion based rotations --- LICENSE.md | 33 ++++++ pysages/utils/__init__.py | 14 ++- pysages/utils/core.py | 5 +- pysages/utils/transformations.py | 196 +++++++++++++++++++++++++++++++ 4 files changed, 246 insertions(+), 2 deletions(-) create mode 100644 pysages/utils/transformations.py diff --git a/LICENSE.md b/LICENSE.md index 0c3655d4..a918b7ab 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -48,3 +48,36 @@ Whereas the following applies to `pysages/methods/abf.py`: For a list of contributors to the SSAGES project visit . + +The code in `pysages/utils/transformations.py` was adapted from +. +The following applies to the original implementation: + +> Copyright (c) 2006-2022, Christoph Gohlke +> All rights reserved. +> +> Redistribution and use in source and binary forms, with or without +> modification, are permitted provided that the following conditions are met: +> +> 1. Redistributions of source code must retain the above copyright notice, +> this list of conditions and the following disclaimer. +> +> 2. Redistributions in binary form must reproduce the above copyright notice, +> this list of conditions and the following disclaimer in the documentation +> and/or other materials provided with the distribution. +> +> 3. Neither the name of the copyright holder nor the names of its +> contributors may be used to endorse or promote products derived from +> this software without specific prior written permission. +> +> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +> AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +> IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +> ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +> LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +> CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +> SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +> INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +> CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +> ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +> POSSIBILITY OF SUCH DAMAGE. diff --git a/pysages/utils/__init__.py b/pysages/utils/__init__.py index d9db7597..a4541a71 100644 --- a/pysages/utils/__init__.py +++ b/pysages/utils/__init__.py @@ -9,4 +9,16 @@ """ from .compat import JaxArray, check_device_array, solve_pos_def, try_import -from .core import Bool, Float, Int, Scalar, ToCPU, copy, dispatch, gaussian, identity +from .core import ( + Bool, + Float, + Int, + Scalar, + ToCPU, + copy, + dispatch, + eps, + gaussian, + identity, +) +from .transformations import quaternion_from_euler, quaternion_matrix diff --git a/pysages/utils/core.py b/pysages/utils/core.py index 1e1a6d20..73aaf1a1 100644 --- a/pysages/utils/core.py +++ b/pysages/utils/core.py @@ -13,7 +13,6 @@ # PySAGES main dispatcher dispatch = Dispatcher() - Bool = Union[JaxArray, bool] Float = Union[JaxArray, float] Int = Union[JaxArray, int] @@ -53,6 +52,10 @@ def identity(x): return x +def eps(T: type = np.zeros(0).dtype): + return np.finfo(T).eps + + def row_sum(x): """ Sum array `x` along each of its row (`axis = 1`), diff --git a/pysages/utils/transformations.py b/pysages/utils/transformations.py new file mode 100644 index 00000000..62702008 --- /dev/null +++ b/pysages/utils/transformations.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +# Code adapted from +# https://github.com/cgohlke/transformations/blob/v2022.9.26/transformations/transformations.py + +from typing import List, NamedTuple + +from jax import lax +from jax import numpy as np + +from pysages.utils.core import dispatch, eps + +# axes indices sequences for Euler angles +_NEXT_AXIS = [1, 2, 0, 1] + +# map axes strings to/from tuples of encodings for: +# ( +# "first axis": {"x": 0, "y": 1, "z": 2}, +# "axes ordering": {"right": 0, "left": 1}, +# "axes sequence": {"Proper Euler": 0, "Tait-Bryan": 1}, +# "frame/rotation": {"static/extrinsic": 0, "rotating/intrinsic": 1} +# ) +_AXES2TUPLE = { + "sxyz": (0, 0, 0, 0), + "sxyx": (0, 0, 1, 0), + "sxzy": (0, 1, 0, 0), + "sxzx": (0, 1, 1, 0), + "syzx": (1, 0, 0, 0), + "syzy": (1, 0, 1, 0), + "syxz": (1, 1, 0, 0), + "syxy": (1, 1, 1, 0), + "szxy": (2, 0, 0, 0), + "szxz": (2, 0, 1, 0), + "szyx": (2, 1, 0, 0), + "szyz": (2, 1, 1, 0), + "rzyx": (0, 0, 0, 1), + "rxyx": (0, 0, 1, 1), + "ryzx": (0, 1, 0, 1), + "rxzx": (0, 1, 1, 1), + "rxzy": (1, 0, 0, 1), + "ryzy": (1, 0, 1, 1), + "rzxy": (1, 1, 0, 1), + "ryxy": (1, 1, 1, 1), + "ryxz": (2, 0, 0, 1), + "rzxz": (2, 0, 1, 1), + "rxyz": (2, 1, 0, 1), + "rzyz": (2, 1, 1, 1), +} + +_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) + + +class EulerAnglesType(NamedTuple): + "Base class for Euler angles types." + pass + + +class ProperEuler(EulerAnglesType): + """ + Dispatch class for signaling + `Proper Euler angles `_. + """ + + pass + + +class TaitBryan(EulerAnglesType): + """ + Dispatch class for signaling + `Tait-Bryan angles `_. + """ + + pass + + +class RotationAxes: + """ + Handles the translation from string or tuple encodings for rotations to an + appropritate representation for the `quaternion_from_euler` implementation. + + Parameters + ---------- + axes: Union[str, truple] + One of 24 axis sequences as string or encoded tuple + """ + + class Parameters(NamedTuple): + sequence: EulerAnglesType + j_sign: int + permutation: List[int] + intrinsic: bool + + def __init__(self, axes): + self.params = self.process_axes(axes) + + @dispatch + def process_axes(self, axes: str): + return self.process_axes(_AXES2TUPLE[axes.lower()], validate=False) + + @dispatch + def process_axes(self, rotation_mode: tuple, validate=True): # noqa: F811 + if validate: + _TUPLE2AXES[rotation_mode] + + first_axis, left_ordering, proper_euler, intrinsic = rotation_mode + + j_sign = -1 if left_ordering else 1 + sequence = ProperEuler() if proper_euler else TaitBryan() + + o = left_ordering + i = first_axis + 1 + j = _NEXT_AXIS[i + o - 1] + 1 + k = _NEXT_AXIS[i - o] + 1 + invperm = (0, i, j, k) + permutation = [invperm[n] for n in invperm] + + return self.Parameters(sequence, j_sign, permutation, intrinsic) + + +def quaternion_from_euler(ai, aj, ak, axes=RotationAxes("sxyz")): + """ + Return a quaternion from Euler angles and axis sequence. + + Arguments + --------- + ai, aj, ak: numbers.Real + Euler's roll, pitch and yaw angles + + axes: RotationAxes + One of 24 axis sequences as string or tuple wrapped as a RotationAxes + """ + + @dispatch + def quaternion_entries(seq: ProperEuler, cj, sj, cc, ss, cs, sc, sgn): + v0 = cj * (cc - ss) + vi = cj * (cs + sc) + vj = sj * (cc + ss) + vk = sj * (cs - sc) + return (v0, vi, sgn * vj, vk) + + @dispatch + def quaternion_entries(seq: TaitBryan, cj, sj, cc, ss, cs, sc, sgn): # noqa: F811 + v0 = cj * cc + sj * ss + vi = cj * sc - sj * cs + vj = cj * ss + sj * cc + vk = cj * cs - sj * sc + return (v0, vi, sgn * vj, vk) + + def _quaternion_from_euler(ai, aj, ak, sequence, j_sign, permutation): + ai /= 2.0 + aj /= 2.0 + ak /= 2.0 + ci = np.cos(ai) + si = np.sin(ai) + cj = np.cos(aj) + sj = np.sin(aj) + ck = np.cos(ak) + sk = np.sin(ak) + cc = ci * ck + cs = ci * sk + sc = si * ck + ss = si * sk + + v = quaternion_entries(sequence, cj, sj, cc, ss, cs, sc, j_sign) + return np.array([v[i] for i in permutation]) + + sequence, j_sign, permutation, intrinsic = axes.params + angles = (ak, j_sign * aj, ai) if intrinsic else (ai, j_sign * aj, ak) + return _quaternion_from_euler(*angles, sequence, j_sign, permutation) + + +def quaternion_matrix(quaternion, dtype: type = np.zeros(0).dtype): + """ + Return the homogeneous rotation matrix from a quaternion. + """ + + def _identity_matrix(*_): + return np.identity(4, dtype=dtype) + + def _quaternion_matrix(q, n): + q *= np.sqrt(2.0 / n) + Q = np.outer(q, q) + return np.array( + [ + [1.0 - Q[2, 2] - Q[3, 3], Q[1, 2] - Q[3, 0], Q[1, 3] + Q[2, 0], 0.0], + [Q[1, 2] + Q[3, 0], 1.0 - Q[1, 1] - Q[3, 3], Q[2, 3] - Q[1, 0], 0.0], + [Q[1, 3] - Q[2, 0], Q[2, 3] + Q[1, 0], 1.0 - Q[1, 1] - Q[2, 2], 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + q = np.array(quaternion, dtype=dtype) + n = np.dot(q, q) + + return lax.cond(n < 4 * eps(), _identity_matrix, _quaternion_matrix, q, n) From 835a38b96eefe7c8e40fd2737b2a6199259727dd Mon Sep 17 00:00:00 2001 From: Malgorzata Zimon <74198137+maggiezimon@users.noreply.github.com> Date: Tue, 8 Nov 2022 11:55:42 -0600 Subject: [PATCH 3/5] Add dill, jax-md and jaxopt as a dependencies for CI --- .github/workflows/ci.yml | 4 ++-- Dockerfile | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cc451a92..cfc18f85 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: Install python dependecies run: | python -m pip install --upgrade pip - pip install jaxlib pytest matplotlib + pip install dill jaxlib jax-md jaxopt pytest matplotlib - name: Install pysages run: pip install . @@ -60,7 +60,7 @@ jobs: - name: Install python dependecies run: | python -m pip install --upgrade pip - pip install jaxlib pytest pylint flake8 + pip install dill jaxlib jax-md jaxopt pytest pylint flake8 pip install -r docs/requirements.txt - name: Install pysages run: pip install . diff --git a/Dockerfile b/Dockerfile index cf7b1100..42a375dc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,8 +5,9 @@ WORKDIR / RUN python -m pip install --upgrade pip RUN python -m pip install ase gsd matplotlib "pyparsing<3" -# Install JAX +# Install JAX and JAX-MD RUN python -m pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +RUN python -m pip install --upgrade jax-md jaxopt COPY . /PySAGES RUN pip install /PySAGES/ From a74e495d1e3df891f90d46379180308b22078aa8 Mon Sep 17 00:00:00 2001 From: Malgorzata Zimon <74198137+maggiezimon@users.noreply.github.com> Date: Thu, 27 Oct 2022 11:59:32 +0100 Subject: [PATCH 4/5] Add GeM --- pysages/colvars/__init__.py | 1 + pysages/colvars/patterns.py | 375 ++++++++++++++++++++++++++++++++++++ tests/test_pickle.py | 32 ++- 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 pysages/colvars/patterns.py diff --git a/pysages/colvars/__init__.py b/pysages/colvars/__init__.py index b067bb57..2e85766b 100644 --- a/pysages/colvars/__init__.py +++ b/pysages/colvars/__init__.py @@ -12,6 +12,7 @@ from .angles import Angle, DihedralAngle from .coordinates import Component, Displacement, Distance +from .patterns import GeM from .shape import ( Acylindricity, Asphericity, diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py new file mode 100644 index 00000000..16ab87dd --- /dev/null +++ b/pysages/colvars/patterns.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: MIT +# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES + +import time + +from jax import lax +from jax import numpy as np +from jax import random, vmap +from jax.numpy import linalg +from jax_md.partition import space +from jaxopt import GradientDescent as minimize + +from pysages.colvars.core import CollectiveVariable +from pysages.utils import gaussian, quaternion_from_euler, quaternion_matrix + + +def rotate_pattern_with_quaternions(rot_q, pattern): + return np.transpose(np.dot(quaternion_matrix(rot_q)[:3, :3], np.transpose(pattern))) + + +def func_to_optimise(Q, modified_pattern, local_pattern): + return np.linalg.norm(rotate_pattern_with_quaternions(Q, modified_pattern) - local_pattern) + + +# Main class implementing the GeM CV +class Pattern: + """ + For determining nearest neighbors, + [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html) + neighborlist library is utilized. This requires the user + to define the indices of all the atoms in the system and a JAX MD + neighbor list callable for updating the state. + """ + + def __init__( + self, + simulation_box, + fractional_coords, + reference, + neighborlist, + characteristic_distance, + centre_j_id, + standard_deviation, + mesh_size, + ): + + self.characteristic_distance = characteristic_distance + self.reference = reference + self.neighborlist = neighborlist + self.simulation_box = simulation_box + self.centre_j_id = centre_j_id + # This is added to handle neighborlists with fractional coordinates + # (needed for NPT simulations) + if fractional_coords: + self.positions = self.neighborlist.reference_position * np.diag(self.simulation_box) + else: + self.positions = self.neighborlist.reference_position + self.centre_j_coords = self.positions[self.centre_j_id] + self.standard_deviation = standard_deviation + self.mesh_size = mesh_size + + def comp_pair_distance_squared(self, pos1): + displacement_fn, shift_fn = space.periodic(np.diag(self.simulation_box)) + mic_vector = displacement_fn(self.centre_j_coords, pos1) + mic_norm = linalg.norm(mic_vector) + return mic_norm, mic_vector + + def _generate_neighborhood(self): + self._neighborhood = [] + + positions_of_all_nbrs = self.positions[self.neighborlist.idx[self.centre_j_id]] + distances, mic_vectors = vmap(self.comp_pair_distance_squared)(positions_of_all_nbrs) + # remove the same atom from the neighborhood + distances = np.where(distances != 0.0, distances, 1e5) + # remove the number of atoms from the list of indices + distances = np.where( + self.neighborlist.idx[self.centre_j_id] != len(self.neighborlist.idx), distances, 1e5 + ) + + ids_of_neighbors = np.argsort(distances)[: len(self.reference)] + + coordinates = mic_vectors[ids_of_neighbors] + self.centre_j_coords + # Step 1: Translate to origin; + coordinates = coordinates.at[:].set(coordinates - np.mean(coordinates, axis=0)) + for vec_id, mic_vector in enumerate(mic_vectors[ids_of_neighbors]): + neighbor = { + "id": ids_of_neighbors[vec_id], + "coordinates": coordinates[vec_id], + "mic_vector": mic_vector, + "pos_wrt_j": self.centre_j_coords - mic_vector, + "distance": distances[ids_of_neighbors[vec_id]], + } + self._neighborhood.append(neighbor) + + self._neighbor_coords = np.array([n["coordinates"] for n in self._neighborhood]) + self._orig_neighbor_coords = positions_of_all_nbrs[ids_of_neighbors] + + def compute_score(self, optim_reference): + r = self._neighbor_coords - optim_reference + return np.prod(gaussian(1, self.standard_deviation, r)) + + def rotate_reference(self, random_euler_point): + # Perform rotation of the reference pattern; + # Using Euler angles in radians construct a quaternion base; + # Convert the quaternion to a 3x3 rotation matrix. + rot_q = quaternion_from_euler(*random_euler_point) + return rotate_pattern_with_quaternions(rot_q, self.reference) + + def resort(self, pattern_to_resort, key): + # This subroutine shuffles randomly the input local pattern + # and resorts the reference indices in order to "minimise" + # the distance of the corresponding sites + + random_indices = random.permutation( + key, np.arange(len(self._neighborhood)), axis=0, independent=False + ) + random_neighbor_coords = self._neighbor_coords[random_indices] + + def find_closest(carry, neighbor_coords): + ref_positions = carry + distances = [np.linalg.norm(vector - neighbor_coords) for vector in ref_positions] + min_index = np.argmin(np.array(distances)) + positions = ref_positions.at[min_index].set(np.array([-1e10, -1e10, -1e10])) + new_ref_positions = ref_positions[min_index] + return positions, new_ref_positions + + _, closest_reference = lax.scan(find_closest, pattern_to_resort, random_neighbor_coords) + # Reorder the reference to match the positions of the neighbors + reshuffled_reference = np.zeros_like(closest_reference) + reshuffled_reference = reshuffled_reference.at[random_indices].set(closest_reference) + return reshuffled_reference, random_indices + + def check_settled_status(self, modified_reference): + def mark_close_sites(_, reference_pos): + def return_close(_, n): + return lax.cond( + np.linalg.norm(n - reference_pos) < self.characteristic_distance / 2.0, + lambda x: (None, x + 1), + lambda x: (None, x), + 0, + ) + + _, close_sites_per_reference = lax.scan(return_close, None, self._neighbor_coords) + return None, close_sites_per_reference + + _, close_sites = lax.scan(mark_close_sites, None, modified_reference) + _, indices = lax.scan( + lambda _, sites: ( + None, + lax.cond(np.sum(sites) == 1, lambda s: s, lambda s: np.zeros_like(s), sites), + ), + None, + close_sites, + ) + # Return the locations of settled nighbours in the neighborhood; + # Settlled site should have a unique neighbor + settled_neighbor_indices = np.where(np.sum(indices, axis=0) >= 1, 1, 0) + return settled_neighbor_indices + + def driver_match(self, number_of_rotations, number_of_opt_steps, num): + + self._generate_neighborhood() + + """Step2: Scale the reference so that the spread matches + with the current local pattern""" + local_distance = 0.0 + reference_distance = 0.0 + for n_index, neighbor in enumerate(self._neighborhood): + local_distance += np.dot(neighbor["coordinates"], neighbor["coordinates"]) + reference_distance += np.dot(self.reference[n_index], self.reference[n_index]) + + self.reference *= np.sqrt(local_distance / reference_distance) + + """Step3: mesh-loop -> Define angles in reduced Euler domain, + and for each rotate, resort and score the pattern + + The implementation below follows the article Martelli et al. 2018 + + + (a) Randomly with uniform probability pick a point in the Euler domain, + (b) Rotate the reference + (c) Resort the local pattern and assign the closest reference sites, + (d) Perform the optimisation step (conjugate gradient), + and (e) store the score with (f) the final settled status""" + + def get_all_scores(newkey, euler_point): + # b. Rotate the reference pattern + rotated_reference = self.rotate_reference(euler_point) + # c. Resort; shuffle the local pattern + # and assign ids to the closest reference sites + newkey, newsubkey = random.split(random.PRNGKey(newkey)) + reshuffled_reference, random_indices = self.resort(rotated_reference, newsubkey) + # d. Find the best rotation that aligns the settled sites + # in both patterns; + # Here, ‘optimal’ or ‘best’ is in terms of least squares errors + solver = minimize(fun=func_to_optimise, maxiter=number_of_opt_steps) + # We are fixing the initial guess for the quaternions; + # different starting conditions are obtained by working + # with a rotated reference (this can be changed) + optim = solver.run( + init_params=np.array([0.1, 0.0, 0.0, 0.1]), + modified_pattern=reshuffled_reference, + local_pattern=self._neighbor_coords, + ) + optimal_reference = rotate_pattern_with_quaternions(optim.params, reshuffled_reference) + # e. Compute and store the score + score = self.compute_score(optimal_reference) + result = dict( + score=score, + rotated_pattern=rotated_reference, + random_indices=random_indices, + reshuffled_pattern=reshuffled_reference, + pattern=optimal_reference, + quaternions=optim.params, + ) + return result + + # a. Randomly pick a point in the Euler domain + + key, subkey = random.split(random.PRNGKey(num)) + mesh_size = self.mesh_size + grid_dimension = np.pi / mesh_size + euler_angles = np.arange( + 0, 0.125 * np.pi + (mesh_size / 2 + 1) * grid_dimension, grid_dimension + ) + random_points = random.randint( + subkey, (number_of_rotations, 3), minval=0.0, maxval=mesh_size + ) + # Excute find_max_score for each angle + # and store the result with the highest score + + scoring_results = vmap(get_all_scores)( + num + np.arange(number_of_rotations), euler_angles[random_points] + ) + optimal_case = np.argmax(scoring_results["score"]) + + # f. Check how many settled sites there are + settled_neighbor_ids = self.check_settled_status(scoring_results["pattern"][optimal_case]) + + # Storing all the data is only for validation and analysis; + # For FFS, only score will be returned, i.e., optimal_result['score']; + # This then can be removed + optimal_result = dict( + score=scoring_results["score"][optimal_case], + rotated_pattern=scoring_results["rotated_pattern"][optimal_case], + random_indices=scoring_results["random_indices"][optimal_case], + reshuffled_pattern=scoring_results["reshuffled_pattern"][optimal_case], + pattern=scoring_results["pattern"][optimal_case], + quaternions=scoring_results["quaternions"][optimal_case], + settled=settled_neighbor_ids, + centre_atom=self.centre_j_coords, + neighborhood=self._neighbor_coords, + neighborhood_orig=self._orig_neighbor_coords, + ) + return optimal_result + + +def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params): + + if params.fractional_coords: + update_neighborlist = neighborlist.update(np.divide(all_positions, np.diag(simulation_box))) + else: + update_neighborlist = neighborlist.update(all_positions) + + """'Step1: Move the reference and + local patterns so that their centers coincide with the origin""" + + reference_positions = params.reference_positions.at[:].set( + params.reference_positions - np.mean(params.reference_positions, axis=0) + ) + + # Calculate scores + seed = np.int64(time.process_time() * 1e5) + optimal_results = vmap( + lambda i: Pattern( + params.box, + params.fractional_coords, + reference_positions, + update_neighborlist, + params.characteristic_distance, + i, + params.standard_deviation, + params.mesh_size, + ).driver_match( + params.number_of_rotations, + params.number_of_opt_it, + seed + i * params.number_of_rotations, + ) + )(np.arange(len(all_positions), dtype=np.int64)) + average_score = np.sum(optimal_results["score"]) / len(all_positions) + return average_score + + +class GeM(CollectiveVariable): + """ + This CV implements a Geometry Matching (GeM) Local Order Metric (LOM). + The algorithm enabling the measurement of order in the neighborhood of + an atomic or a molecular site is described in + [Martelli2018](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.97.064105). + + Given a pattern, the algorithm is returning an average score (from 0 to 1), + denoting how closely the atomic neighbors resemble the reference. + + For determining nearest neighbors, + [JAX MD](https://jax-md.readthedocs.io/en/main/jax_md.partition.html) + neighborlist library is utilized. This requires the user + to define the indices of all the atoms in the system and a JAX MD + neighbor list callable for updating the state. + + Matching a neighborhood to the pattern is an optimization process. + Based on the number of initial rotations of the reference structure + and opt. iterations, we aim to find a rotation matrix Q + that minimizes norm(a-Q*b), where a is the neighborhood + and b denotes the reference. This is defined in `func_to_optimise`. + Optimization is performed using [JAXopt](https://github.com/google/jaxopt). + + Parts of the code related to JAX compatible 3d transformations + (e.g., quaternion_matrix) are taken from + [jax_transformations3d](https://github.com/cpgoodri/jax_transformations3d). + + Parameters + ---------- + indices: list + List of indices of all atoms required for updating neighbor list. + reference_positions: JaxArray + box: JaxArray + Definition of the simulation box. + number_of_rotations: integer + Number of initial rotated structures for the optimization study. + number_of_opt_it: iteger + Number of iterations for gradient descent. + standard_deviation: float + Parameter that controls the spread of the Gaussian function. + mesh_size: integer + Defines the size of the angular grid from which we draw + random Euler angles. + nbrs: callable + JAX MD neighbor list function to update the neighbor list. + fractional_coords: Bool + Set to True if NPT simulation is considered and the box size + changes; use periodic_general for constructing the neighborlist. + Returns + ------- + calculate_lom: float + Average score defining the degree of overlap + with the reference structure. It's a measure of the global order. + """ + + def __init__( + self, + indices, + reference_positions, + box, + number_of_rotations, + number_of_opt_it, + standard_deviation, + mesh_size, + nbrs, + fractional_coords, + ): + super().__init__(indices, group_length=None) + + self.reference_positions = np.asarray(reference_positions) + self.box = np.asarray(box) + self.number_of_rotations = number_of_rotations + self.number_of_opt_it = number_of_opt_it + self.standard_deviation = standard_deviation + self.characteristic_distance = standard_deviation * 2 + self.mesh_size = mesh_size + self.nbrs = nbrs + self.fractional_coords = fractional_coords + + @property + def function(self): + return lambda rs: calculate_lom(rs, self.nbrs, self.box, self) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 7131db08..541e9de5 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -1,7 +1,8 @@ import inspect -import pickle import tempfile +import dill as pickle +import jax_md as jmd import numpy as np import pysages @@ -11,6 +12,20 @@ pi = np.pi +def build_neighbor_list(box_size, positions, r_cutoff, capacity_multiplier): + """Helper function to generate a jax-md neighbor list""" + displacement_fn, shift_fn = jmd.space.periodic(box_size) + neighbor_list_fn = jmd.partition.neighbor_list( + displacement_fn, + box_size, + r_cutoff, + capacity_multiplier=capacity_multiplier, + format=jmd.partition.NeighborListFormat.Dense, + ) + neighbors = neighbor_list_fn.allocate(positions) + return neighbors + + METHODS_ARGS = { "HarmonicBias": {"cvs": [pysages.colvars.Component([0], 0)], "kspring": 15.0, "center": 0.7}, "Unbiased": {"cvs": [pysages.colvars.Component([0], 0)]}, @@ -111,6 +126,21 @@ def test_pickle_methods(): "Component": {"indices": [0, 1, 2, 3], "axis": 0}, "Distance": {"indices": [0, 1]}, "Displacement": {"indices": [[0], [1]]}, + "GeM": { + "indices": np.arange(20), + "reference_positions": np.array( + [[1.0, 1.0, 1.0], [-1.0, -1.0, 1.0], [-1.0, 1.0, -1.0], [1.0, -1.0, -1.0]] + ), + "box": 2 * np.eye(3), + "number_of_rotations": 20, + "number_of_opt_it": 10, + "standard_deviation": 0.125, + "mesh_size": 30, + "nbrs": build_neighbor_list( + 2.0, positions=np.random.randn(20, 3), r_cutoff=1.5, capacity_multiplier=1.0 + ), + "fractional_coords": True, + }, } From 6ce962c808e0d0537d33ac5e610784b791fd3c51 Mon Sep 17 00:00:00 2001 From: Malgorzata Zimon <74198137+maggiezimon@users.noreply.github.com> Date: Tue, 17 Jan 2023 15:37:32 -0600 Subject: [PATCH 5/5] Conditionally enable GeM --- pysages/colvars/__init__.py | 10 +++++++++- pysages/colvars/patterns.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pysages/colvars/__init__.py b/pysages/colvars/__init__.py index 2e85766b..0ccbf330 100644 --- a/pysages/colvars/__init__.py +++ b/pysages/colvars/__init__.py @@ -12,7 +12,6 @@ from .angles import Angle, DihedralAngle from .coordinates import Component, Displacement, Distance -from .patterns import GeM from .shape import ( Acylindricity, Asphericity, @@ -21,3 +20,12 @@ ShapeAnisotropy, ) from .utils import get_periods, wrap + +# Conditionally export GeM if both `jax_md` and `jaxopt` are available +try: + import jax_md + import jaxopt + + from .patterns import GeM +except ImportError: + pass diff --git a/pysages/colvars/patterns.py b/pysages/colvars/patterns.py index 16ab87dd..eb0d2f60 100644 --- a/pysages/colvars/patterns.py +++ b/pysages/colvars/patterns.py @@ -262,7 +262,7 @@ def calculate_lom(all_positions: np.array, neighborlist, simulation_box, params) else: update_neighborlist = neighborlist.update(all_positions) - """'Step1: Move the reference and + """Step1: Move the reference and local patterns so that their centers coincide with the origin""" reference_positions = params.reference_positions.at[:].set(