Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Pathway class #305

Merged
merged 25 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4c59fd6
Skip sites check because it cannot be None
stefsmeets Apr 15, 2024
40fb253
Match volume with Path
stefsmeets Apr 15, 2024
34c514c
Refactor Pathway
stefsmeets Apr 15, 2024
061365a
Add method to calculate path length
stefsmeets Apr 15, 2024
bb028c3
Return PeriodicSite from path_over_structure
stefsmeets Apr 15, 2024
d1faa1e
Use np.testing.assert_allclose
stefsmeets Apr 15, 2024
6ada1a2
Fix bug
stefsmeets Apr 15, 2024
8c8316b
Update plot to work with PeriodicSites
stefsmeets Apr 15, 2024
d0073f0
Add shortcuts to path plots as methods
stefsmeets Apr 15, 2024
291fae6
Optimize api for generating Pathways from Volumes
stefsmeets Apr 16, 2024
1d41a25
Subclass to FreeEnergyVolume
stefsmeets Apr 16, 2024
68fc290
Remove redundant date.shape
stefsmeets Apr 16, 2024
e65be69
Add shortcut for multiple/percolating_paths on volume
stefsmeets Apr 16, 2024
334df24
Update function names
stefsmeets Apr 16, 2024
2c044c4
Simplify percolate argument
stefsmeets Apr 16, 2024
b5dd846
Fix plot and get rid of Volume requirement for paths
stefsmeets Apr 16, 2024
cea568e
Fix tests
stefsmeets Apr 16, 2024
ce915e9
Refactor plotting code
stefsmeets Apr 16, 2024
0e892c7
Yapf it
stefsmeets Apr 16, 2024
06e00d2
Update src/gemdat/plots/matplotlib/_paths.py
stefsmeets Apr 22, 2024
cb31005
Merge branch 'main' into pathway275
stefsmeets Apr 22, 2024
03a2ea3
Always wrap for frac_sites
stefsmeets Apr 22, 2024
b76956e
Rename file for consistency
stefsmeets Apr 22, 2024
d8fc099
Clean percolate function
stefsmeets Apr 22, 2024
12559b4
Remove whitespace
stefsmeets Apr 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 124 additions & 68 deletions src/gemdat/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@

from __future__ import annotations

from collections.abc import Collection
from dataclasses import dataclass
from typing import Literal
from itertools import pairwise
from typing import TYPE_CHECKING, Literal

import networkx as nx
import numpy as np
from pymatgen.core import Structure
from pymatgen.core.units import FloatWithUnit

from gemdat.volume import Volume
from gemdat.volume import FreeEnergyVolume

from .utils import nearest_structure_reference

if TYPE_CHECKING:
from pymatgen.core import Lattice, PeriodicSite


@dataclass
class Pathway:
Expand All @@ -25,16 +31,21 @@ class Pathway:
List of voxel coordinates of the sites defining the path
energy: list[float]
List of the energy along the path
dims: [int, int, int] | None
Voxel dimensions of bounding box. If set (usually to `Volume.dims`),
enable some site transformations.
"""

sites: list[tuple[int, int, int]]
energy: list[float]
dims: tuple[int, int, int] | None = None

def __repr__(self):
s = (
f'Path: {self.start_site} -> {self.stop_site}',
f'Steps: {len(self.sites)}',
f'Total energy: {self.total_energy:.3f} eV',
f'Dimensions: {self.dims}',
)

return '\n'.join(s)
Expand All @@ -44,42 +55,77 @@ def total_energy(self):
"""Return total energy for path."""
return sum(self.energy)

def wrap(self, dims: tuple[int, int, int]):
"""Wrap path in periodic boundary conditions in-place.
def total_length(self, lattice: Lattice) -> FloatWithUnit:
"""Return total length of pathway in Ångstrom.

Parameters
----------
F: np.ndarray
Grid in which the path sites will be wrapped
lattice : Lattice
Lattice parameters

Returns
-------
length : FloatWithUnit
Total distance in Ångstrom
"""
length = 0
for a, b in pairwise(self.frac_sites()):
dist, _ = lattice.get_distance_and_image(a, b)
length += dist
return FloatWithUnit(length, 'ang')

def wrapped_sites(self) -> list[tuple[int, int, int]]:
    """Wrap sites to bounding box.

    Each voxel index is reduced modulo the size of the bounding box
    along its own axis. `self.sites` itself is not modified.

    Returns
    -------
    list[tuple[int, int, int]]
        Voxel coordinates wrapped to bounding box.

    Raises
    ------
    ValueError
        If `self.sites` is None.
    AttributeError
        If `self.dims` is not set.
    """
    if self.sites is None:
        raise ValueError('Voxel coordinates of the path are required.')
    if not self.dims:
        raise AttributeError(
            f'Dimensions are needed for this method {self.dims=}')
    xdim, ydim, zdim = self.dims
    # Wrap every axis by its own dimension; using `xdim` for all three
    # axes is only correct when the grid happens to be cubic.
    return [(x % xdim, y % ydim, z % zdim) for x, y, z in self.sites]

def frac_sites(self, wrapped: bool = False) -> np.ndarray:
"""Return fractional sites.

Parameters
----------
wrapped : bool
If True, wrap coordinates to bounding box using
`.wrapped_sites()`.

stefsmeets marked this conversation as resolved.
Show resolved Hide resolved
X, Y, Z = dims
self.sites = [(x % X, y % Y, z % Z) for x, y, z in self.sites]
Returns
-------
np.ndarray
Fractional coordinates for sites
"""
if not self.dims:
raise AttributeError(
f'Dimensions are needed for this method {self.dims=}')
sites = self.wrapped_sites() if wrapped else self.sites
return (np.array(sites) + 0.5) / np.array(self.dims)

def path_over_structure(
self,
structure: Structure,
volume: Volume,
) -> tuple[list[str], list[np.ndarray]]:
) -> list[PeriodicSite]:
"""Find the nearest site of the structure to the path sites.

Parameters
----------
structure : Structure
Reference structure
volume : Volume
Volume object that contains the information about the nearest sites of the structure

Returns
-------
nearest_structure_label: list[str]
List of the label of the closest site of the reference structure
nearest_structure_coord: list[np.ndarray]
List of cartesian coordinates of the closest site of the reference structure
nearest_sites: list[PeriodicSite]
stefsmeets marked this conversation as resolved.
Show resolved Hide resolved
List of closest sites of the reference structure
"""
frac_sites = volume.voxel_to_frac_coords(np.array(self.sites))
frac_sites = np.array(self.frac_sites(wrapped=True))

nearest_structure_tree, nearest_structure_map = nearest_structure_reference(
structure)

Expand All @@ -88,40 +134,41 @@ def path_over_structure(
nearest_structure_tree.query(site)[1] for site in frac_sites
]
# and use it to get its label and coordinates
nearest_structure_label = [
structure.labels[nearest_structure_map[index]]
for index in nearest_structure_indices
]
nearest_structure_coord = [
structure.cart_coords[nearest_structure_map[index]]
nearest_sites = [
structure[nearest_structure_map[index]]
for index in nearest_structure_indices
]

return nearest_structure_label, nearest_structure_coord
return nearest_sites

@property
def start_site(self) -> tuple[int, int, int]:
"""Return first site."""
if self.sites is None:
raise ValueError('Voxel coordinates of the path are required.')
return self.sites[0]

@property
def stop_site(self) -> tuple[int, int, int]:
"""Return stop site."""
if self.sites is None:
raise ValueError('Voxel coordinates of the path are required.')
return self.sites[-1]

def plot_energy_along_path(self, **kwargs):
    """Shortcut for `plots.energy_along_path` with this pathway as `path`.

    All keyword arguments are forwarded unchanged.
    See [gemdat.plots.energy_along_path][] for more info.
    """
    from gemdat import plots

    plot_function = plots.energy_along_path
    return plot_function(path=self, **kwargs)

def free_energy_graph(F: np.ndarray | Volume,
def plot_path_on_grid(self, **kwargs):
    """Shortcut for `plots.path_on_grid` with this pathway as `path`.

    All keyword arguments are forwarded unchanged.
    See [gemdat.plots.path_on_grid][] for more info.
    """
    from gemdat import plots

    plot_function = plots.path_on_grid
    return plot_function(path=self, **kwargs)


def free_energy_graph(F: np.ndarray | FreeEnergyVolume,
max_energy_threshold: float = 1e20,
diagonal: bool = True) -> nx.Graph:
"""Compute the graph of the free energy for networkx functions.

Parameters
----------
F : np.ndarray | Volume
F : np.ndarray | FreeEnergyVolume
Free energy on the 3d grid
max_energy_threshold : float, optional
Maximum energy threshold for the path to be considered valid
Expand All @@ -148,7 +195,7 @@ def free_energy_graph(F: np.ndarray | Volume,

G = nx.Graph()

data = F.data if isinstance(F, Volume) else F
data = F.data if isinstance(F, FreeEnergyVolume) else F

for index, Fi in np.ndenumerate(data):
if 0 <= Fi < max_energy_threshold:
Expand Down Expand Up @@ -231,26 +278,26 @@ def _paths_too_similar(path: list, list_of_paths: list,
return False


def multiple_paths(
*,
def optimal_n_paths(
F_graph: nx.Graph,
start: tuple,
stop: tuple,
*,
start: Collection,
stop: Collection,
method: _PATHFINDING_METHODS = 'dijkstra',
n_paths: int = 3,
min_diff: float = 0.15,
) -> list[Pathway]:
""" Calculate the Np shortest paths between two sites on the graph.
"""Calculate the n_paths shortest paths between two sites on the graph.
This procedure is based on the algorithm by Jin Y. Yen (https://doi.org/10.1287/mnsc.17.11.712)
and its implementation in NetworkX. Only paths that are different by at least min_diff are considered.

Parameters
----------
F_graph : nx.Graph
Graph of the free energy
start : tuple
start : Collection
Coordinates of the starting point
stop: tuple
stop: Collection
Coordinates of the stopping point
method : str
Method used to calculate the shortest path. Options are:
Expand All @@ -269,9 +316,11 @@ def multiple_paths(
list_of_paths: list[Pathway]
List of the n_paths shortest paths between the start and stop sites
"""
start = tuple(start)
stop = tuple(stop)

# First compute the optimal path
best_path = optimal_path(F_graph, start, stop, method)
best_path = optimal_path(F_graph, start=start, stop=stop, method=method)

list_of_paths = [best_path]

Expand All @@ -297,8 +346,9 @@ def multiple_paths(

def optimal_path(
F_graph: nx.Graph,
start: tuple,
stop: tuple,
*,
start: Collection,
stop: Collection,
method: _PATHFINDING_METHODS = 'dijkstra',
) -> Pathway:
"""Calculate the shortest cost-effective path using the desired method.
Expand All @@ -307,9 +357,9 @@ def optimal_path(
----------
F_graph : nx.Graph
Graph of the free energy
start : tuple
start : Collection
Coordinates of the starting point
stop: tuple
stop: Collection
Coordinates of the stopping point
method : str
Method used to calculate the shortest path. Options are:
Expand All @@ -334,15 +384,20 @@ def optimal_path(
if method in ('dijkstra-exp', 'minmax-energy', 'simple'):
method = 'dijkstra'

start = tuple(start)
stop = tuple(stop)

optimal_path = nx.shortest_path(F_graph,
source=start,
target=stop,
weight=weight,
method=method)

if method == 'minmax-energy':
optimal_path = _optimal_path_minmax_energy(F_graph, start, stop,
optimal_path)
optimal_path = _optimal_path_minmax_energy(F_graph,
start=start,
stop=stop,
optimal_path=optimal_path)
elif method not in ('dijkstra', 'bellman-ford', 'dijkstra-exp'):
raise ValueError(f'Unknown method {method}')

Expand All @@ -353,6 +408,7 @@ def optimal_path(

def _optimal_path_minmax_energy(
F_graph: nx.Graph,
*,
start: tuple[int, int, int],
stop: tuple[int, int, int],
optimal_path: list,
Expand Down Expand Up @@ -402,34 +458,34 @@ def _optimal_path_minmax_energy(
return optimal_path


def find_best_perc_path(F: Volume,
peaks: np.ndarray,
percolate_x: bool = True,
percolate_y: bool = False,
percolate_z: bool = False) -> Pathway | None:
"""Calculate the best percolating path.
def optimal_percolating_path(
F: FreeEnergyVolume,
*,
peaks: np.ndarray,
percolate: str,
) -> Pathway | None:
"""Calculate the optimal percolating path.

Parameters
----------
F : Volume
F : FreeEnergyVolume
Energy grid that will be used to calculate the shortest path
peaks : np.ndarray
List of the peaks that correspond to high probability regions
percolate_x : bool
If True, consider paths that percolate along the x dimension
percolate_y : bool
If True, consider paths that percolate along the y dimension
percolate_z : bool
If True, consider paths that percolate along the z dimension
percolate : str
Directions to percolate, e.g. 'x' to consider paths that
percolate along the x dimension, 'yz' for the y/z dimension,
or any other combination of 'x', 'y', and 'z'.

Returns
-------
best_percolating_path: Pathway
Optimal path that percolates the graph in the specified directions
"""
xyz_real = F.dims
percolate_x = 'x' in percolate
percolate_y = 'y' in percolate
percolate_z = 'z' in percolate

# Find percolation using virtual images along the required dimensions
if not any([percolate_x, percolate_y, percolate_z]):
raise ValueError('percolation is not defined')

Expand All @@ -445,7 +501,7 @@ def find_best_perc_path(F: Volume,
# reaching the percolating image
image = tuple(
x * px
for x, px in zip(xyz_real, (percolate_x, percolate_y, percolate_z)))
for x, px in zip(F.dims, (percolate_x, percolate_y, percolate_z)))

# Find the lowest cost path that percolates along the x dimension
best_cost = float('inf')
Expand All @@ -460,8 +516,8 @@ def find_best_perc_path(F: Volume,
try:
path = optimal_path(
F_graph,
tuple(start_point),
tuple(stop_point),
start=start_point,
stop=stop_point,
)
except nx.NetworkXNoPath:
continue
Expand All @@ -473,7 +529,7 @@ def find_best_perc_path(F: Volume,
best_path = path

if best_path:
# Before returning, wrap the path in the original volume
best_path.wrap(F.dims)
stefsmeets marked this conversation as resolved.
Show resolved Hide resolved
# Before returning, set dimensions of original volume
best_path.dims = F.dims

return best_path
Loading
Loading