Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Orientations api #303

Merged
merged 16 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/notebooks
Submodule notebooks updated 3 files
+15 −1 README.md
+880 −880 introduction.ipynb
+237 −200 orientations.ipynb
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ nav:
- Pathways: notebooks/paths.ipynb
- Percolation: notebooks/percolation.ipynb
- Multiple paths: notebooks/multiple_paths.ipynb
- Orientation tracking: notebooks/rotations.ipynb
- Orientation tracking: notebooks/orientations.ipynb
- Python API:
- gemdat: api/gemdat.md
- gemdat.collective: api/gemdat_collective.md
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ source = ["gemdat"]
[tool.pytest.ini_options]
testpaths = ["tests"]

[tool.ruff]
[tool.lint]
# Enable Pyflakes `E` and `F` codes by default.
select = [
"F", # Pyflakes
Expand All @@ -93,7 +93,7 @@ select = [

line-length = 110

[tool.ruff.isort]
[tool.lint.isort]
known-first-party=["gemdat"]
known-third-party = ["pymatgen"]
required-imports = ["from __future__ import annotations"]
Expand Down
2 changes: 1 addition & 1 deletion src/gemdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .io import load_known_material, read_cif
from .jumps import Jumps
from .rotations import Orientations
from .orientations import Orientations
stefsmeets marked this conversation as resolved.
Show resolved Hide resolved
from .shape import ShapeAnalyzer
from .simulation_metrics import SimulationMetrics
from .trajectory import Trajectory
Expand Down
92 changes: 40 additions & 52 deletions src/gemdat/rotations.py → src/gemdat/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pymatgen.symmetry.groups import PointGroup

from gemdat.trajectory import Trajectory
from gemdat.utils import fft_autocorrelation, cartesian_to_spherical


@dataclass
Expand All @@ -22,15 +23,12 @@ class Orientations:
Type of the central atoms
satellite_type: str
Type of the satellite atoms
nr_central_atoms: int
Number of central atoms, which corresponds to the number of cluster molecules
stefsmeets marked this conversation as resolved.
Show resolved Hide resolved
vectors: np.ndarray
Vectors representing rotation direction
Vectors representing orientation direction
"""
trajectory: Trajectory
center_type: str
satellite_type: str
nr_central_atoms: int
vectors: np.ndarray = field(init=False)
in_vectors: InitVar[np.ndarray | None] = None

Expand Down Expand Up @@ -66,11 +64,6 @@ def _trajectory_sat(self) -> Trajectory:
"""Return trajectory of satellite atoms."""
return self.trajectory.filter(self.satellite_type)

def _fractional_coordinates(self) -> tuple[np.ndarray, np.ndarray]:
"""Return fractional coordinates of central atoms and satellite
atoms."""
return self._trajectory_cent.positions, self._trajectory_sat.positions

@property
def _distances(self) -> np.ndarray:
"""Calculate distances between every central atom and all satellite
Expand Down Expand Up @@ -127,7 +120,11 @@ def _central_satellite_matrix(self, distance: np.ndarray,
combinations: np.ndarray
Matrix of combinations between central and satellite atoms
"""
index_central_atoms = np.arange(self.nr_central_atoms)
nr_central_atoms = frac_coord_cent.shape[1]

index_central_atoms = np.arange(nr_central_atoms)

# index_central_atoms = np.arange(self.nr_central_atoms)
matching_matrix = self._matching_matrix(distance, frac_coord_cent)
combinations = np.array([(i, j) for i in index_central_atoms
for j in matching_matrix[i, :]])
Expand All @@ -147,7 +144,9 @@ def _fractional_directions(self, distance: np.ndarray) -> np.ndarray:
direction: np.ndarray
Contains the direction between central atoms and their ligands.
"""
frac_coord_cent, frac_coord_sat = self._fractional_coordinates()
frac_coord_cent = self._trajectory_cent.positions
frac_coord_sat = self._trajectory_sat.positions

combinations = self._central_satellite_matrix(distance,
frac_coord_cent)

Expand Down Expand Up @@ -241,6 +240,36 @@ def transform(self, matrix: np.ndarray) -> Orientations:

return replace(self, in_vectors=vectors)

@property
def vectors_spherical(self) -> np.ndarray:
"""Return vectors in spherical coordinates in degrees.

Returns
-------
np.array
azimuth, elevation, length
"""
return cartesian_to_spherical(self.vectors)
stefsmeets marked this conversation as resolved.
Show resolved Hide resolved

def autocorrelation(self):
"""Compute the autocorrelation of the orientation vectors using FFT."""
return fft_autocorrelation(self.vectors)

def plot_rectilinear(self, **kwargs):
"""See [gemdat.plots.rectilinear][] for more info."""
from gemdat import plots
return plots.rectilinear(orientations=self, **kwargs)

def plot_bond_length_distribution(self, **kwargs):
"""See [gemdat.plots.bond_length_distribution][] for more info."""
from gemdat import plots
return plots.bond_length_distribution(orientations=self, **kwargs)

def plot_autocorrelation(self, **kwargs):
"""See [gemdat.plots.unit_vector_autocorrelation][] for more info."""
from gemdat import plots
return plots.autocorrelation(orientations=self, **kwargs)


def calculate_spherical_areas(shape: tuple[int, int],
radius: float = 1) -> np.ndarray:
Expand Down Expand Up @@ -273,44 +302,3 @@ def calculate_spherical_areas(shape: tuple[int, int],
#hacky way to get rid of singularity on poles
areas[0, :] = areas[-1, 0]
return areas


def mean_squared_angular_displacement(trajectory: np.ndarray) -> np.ndarray:
"""Compute the mean squared angular displacement using FFT.

Parameters
----------
trajectory : np.ndarray
The input signal in direct cartesian coordinates. It is expected
to have shape (n_times, n_particles, n_coordinates)

Returns
-------
msad:
The mean squared angular displacement
"""
n_times, n_particles, n_coordinates = trajectory.shape

msad = np.zeros((n_particles, n_times))
normalization = np.arange(n_times, 0, -1)

for c in range(n_coordinates):
signal = trajectory[:, :, c]

# Compute the FFT of the signal
fft_signal = np.fft.rfft(signal, n=2 * n_times - 1, axis=0)
# Compute the power spectral density in-place
np.square(np.abs(fft_signal), out=fft_signal)
# Compute the inverse FFT of the power spectral density
autocorr_c = np.fft.irfft(fft_signal, axis=0)

# Only keep the positive times
autocorr_c = autocorr_c[:n_times, :]

msad += autocorr_c.T / normalization

# Normalize the msad such that it starts from 1
# (this makes the normalization independent on the dimensions)
msad = msad / msad[:, 0, np.newaxis]

return msad
8 changes: 4 additions & 4 deletions src/gemdat/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

# Matplotlib plots
from .matplotlib import (
autocorrelation,
bond_length_distribution,
energy_along_path,
jumps_3d_animation,
path_on_grid,
radial_distribution,
rectilinear_plot,
rectilinear,
shape,
unit_vector_autocorrelation,
)

# Plotly plots (matplotlib version might be available)
Expand Down Expand Up @@ -45,10 +45,10 @@
'plot_3d',
'msd_per_element',
'radial_distribution',
'rectilinear_plot',
'rectilinear',
'shape',
'vibrational_amplitudes',
'energy_along_path',
'path_on_grid',
'unit_vector_autocorrelation',
'autocorrelation',
]
10 changes: 5 additions & 5 deletions src/gemdat/plots/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
path_on_grid,
)
from ._rdf import radial_distribution
from ._rotations import (
from ._orientations import (
autocorrelation,
bond_length_distribution,
rectilinear_plot,
unit_vector_autocorrelation,
rectilinear,
)
from ._shape import shape
from ._vibration import (
Expand All @@ -46,8 +46,8 @@
'msd_per_element',
'path_on_grid',
'radial_distribution',
'rectilinear_plot',
'rectilinear',
'shape',
'unit_vector_autocorrelation',
'autocorrelation',
'vibrational_amplitudes',
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
from scipy.optimize import curve_fit
from scipy.stats import skewnorm

from gemdat.rotations import (
from gemdat.orientations import (
Orientations,
calculate_spherical_areas,
mean_squared_angular_displacement,
)
from gemdat.utils import cartesian_to_spherical


def rectilinear_plot(*,
orientations: Orientations,
shape: tuple[int, int] = (90, 360),
normalize_histo: bool = True) -> plt.Figure:
def rectilinear(*,
orientations: Orientations,
shape: tuple[int, int] = (90, 360),
normalize_histo: bool = True) -> plt.Figure:
"""Plot a rectilinear projection of a spherical function. This function
uses the transformed trajectory.

Expand All @@ -34,11 +32,10 @@ def rectilinear_plot(*,
fig : matplotlib.figure.Figure
Output figure
"""
# Convert the trajectory to spherical coordinates
trajectory = cartesian_to_spherical(orientations.vectors, degrees=True)

az = trajectory[:, :, 0].flatten()
el = trajectory[:, :, 1].flatten()
# Convert the vectors to spherical coordinates
az, el, _ = orientations.vectors_spherical.T
az = az.flatten()
el = el.flatten()

hist, xedges, yedges = np.histogram2d(el, az, shape)

Expand Down Expand Up @@ -92,14 +89,11 @@ def bond_length_distribution(*,
fig : matplotlib.figure.Figure
Output figure
"""

# Convert the trajectory to spherical coordinates
trajectory = cartesian_to_spherical(orientations.vectors, degrees=True)
*_, bond_lengths = orientations.vectors_spherical.T
bond_lengths = bond_lengths.flatten()

fig, ax = plt.subplots()

bond_lengths = trajectory[:, :, 2].flatten()

# Plot the normalized histogram
hist, edges = np.histogram(bond_lengths, bins=bins, density=True)
bin_centers = (edges[:-1] + edges[1:]) / 2
Expand Down Expand Up @@ -135,7 +129,7 @@ def _skewnorm_fit(x):
return fig


def unit_vector_autocorrelation(
def autocorrelation(
*,
orientations: Orientations,
) -> plt.Figure:
Expand All @@ -151,11 +145,7 @@ def unit_vector_autocorrelation(
fig : matplotlib.figure.Figure
Output figure
"""

# The trajectory is expected to have shape (n_times, n_particles, n_coordinates)
trajectory = orientations.vectors

ac = mean_squared_angular_displacement(trajectory)
ac = orientations.autocorrelation()
ac_std = ac.std(axis=0)
ac_mean = ac.mean(axis=0)

Expand Down
Loading
Loading