Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docstrings, type-hints, documentation fixes #78

Merged
merged 6 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions autoplex/benchmark/phonons/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
"""Utility functions for benchmarking jobs."""

from __future__ import annotations

from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine

if TYPE_CHECKING:
from matplotlib.figure import Figure
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
from pymatgen.phonon.plotter import PhononBSPlotter


def get_rmse(
ml_bs: PhononBandStructureSymmLine,
dft_bs: PhononBandStructureSymmLine,
q_dependent_rmse: bool = False,
):
) -> float | list[float]:
"""
Compute root mean squared error (rmse) between DFT and ML phonon band-structure.

Expand Down Expand Up @@ -44,7 +51,7 @@ def rmse_qdep_plot(
which_q_path=1,
file_name="rms.pdf",
img_format="pdf",
):
) -> plt:
"""
Save q dependent root mean squared error plot between DFT and ML phonon band-structure.

Expand Down Expand Up @@ -94,7 +101,7 @@ def compare_plot(
ml_bs: PhononBandStructureSymmLine,
dft_bs: PhononBandStructureSymmLine,
file_name: str = "band_comparison.pdf",
):
) -> Figure:
"""
Save DFT and ML phonon band-structure overlay plot for visual comparison.

Expand Down
35 changes: 19 additions & 16 deletions autoplex/data/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def scale_cell(
volume_scale_factor_range: list[float] | None = None,
n_structures: int = 10,
volume_custom_scale_factors: list[float] | None = None,
):
) -> list[Structure]:
"""
Take in a pymatgen Structure object and generates stretched or compressed structures.

Expand Down Expand Up @@ -154,9 +154,9 @@ def scale_cell(
return distorted_cells


def check_distances(structure: Structure, min_distance: float = 1.5):
def check_distances(structure: Structure, min_distance: float = 1.5) -> bool:
"""
Take in a pymatgen Structure object and checks distances between atoms using minimum image convention.
Take in a pymatgen Structure object and check minimum distances between atoms using minimum image convention.

Useful after distorting cell angles and rattling to check atoms aren't too close.

Expand Down Expand Up @@ -191,7 +191,7 @@ def random_vary_angle(
w_angle: list[float] | None = None,
n_structures: int = 8,
angle_max_attempts: int = 1000,
):
) -> list[Structure]:
"""
Take in a pymatgen Structure object and generates angle-distorted structures.

Expand Down Expand Up @@ -237,10 +237,10 @@ def random_vary_angle(
volume_custom_scale_factors=[1.03],
)

distorted_cells = AseAtomsAdaptor.get_atoms(distorted_cells[0])
distorted_supercells: Atoms = AseAtomsAdaptor.get_atoms(distorted_cells[0])

# getting stretched cell out of array
newcell = distorted_cells.cell.cellpar()
# getting stretched supercell out of array
newcell = distorted_supercells.cell.cellpar()

# current angles
alpha = atoms_copy.cell.cellpar()[3]
Expand Down Expand Up @@ -287,7 +287,7 @@ def std_rattle(
n_structures: int = 5,
rattle_std: float = 0.01,
rattle_seed: int = 42,
):
) -> list[Structure]:
"""
Take in a pymatgen Structure object and generates rattled structures.

Expand Down Expand Up @@ -331,7 +331,7 @@ def mc_rattle(
min_distance: float = 1.5,
rattle_seed: int = 42,
rattle_mc_n_iter: int = 10,
):
) -> list[Structure]:
"""
Take in a pymatgen Structure object and generates rattled structures.

Expand Down Expand Up @@ -375,7 +375,7 @@ def mc_rattle(
return [AseAtomsAdaptor.get_structure(xtal) for xtal in mc_rattle]


def extract_base_name(filename, is_out=False):
def extract_base_name(filename, is_out=False) -> str:
"""
Extract the base of a file name to easier manipulate other file names.

Expand All @@ -401,7 +401,7 @@ def extract_base_name(filename, is_out=False):
return "A problem with the files occurred."


def filter_outlier_energy(in_file, out_file, criteria: float = 0.0005):
def filter_outlier_energy(in_file, out_file, criteria: float = 0.0005) -> None:
"""
Filter data outliers per energy criteria and write them into files.

Expand Down Expand Up @@ -457,7 +457,9 @@ def filter_outlier_energy(in_file, out_file, criteria: float = 0.0005):
)


def filter_outlier_forces(in_file, out_file, symbol="Si", criteria: float = 0.1):
def filter_outlier_forces(
in_file, out_file, symbol="Si", criteria: float = 0.1
) -> None:
"""
Filter data outliers per force criteria and write them into files.

Expand Down Expand Up @@ -526,13 +528,14 @@ def filter_outlier_forces(in_file, out_file, symbol="Si", criteria: float = 0.1)
)


# copied from libatoms GAP tutorial page and adjusted
def energy_plot(
in_file, out_file, ax, title: str = "Plot of energy", label: str = "energy"
):
) -> None:
"""
Plot the distribution of energy per atom on the output vs the input.

Adapted and adjusted from libatoms GAP tutorial page https://libatoms.github.io/GAP/gap_fitting_tutorial.html.

Parameters
----------
in_file:
Expand Down Expand Up @@ -610,7 +613,7 @@ def force_plot(
symbol: str = "Si",
title: str = "Plot of force",
label: str = "force for ",
):
) -> float:
"""
Plot the distribution of force components per atom on the output vs the input.

Expand Down Expand Up @@ -700,7 +703,7 @@ def plot_energy_forces(
species_list: list | None = None,
train_name: str = "train.extxyz",
test_name: str = "test.extxyz",
):
) -> None:
"""
Plot energy and forces of the data.

Expand Down
6 changes: 5 additions & 1 deletion autoplex/data/phonons/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ def ml_phonon_maker_preparation(
bulk_relax_maker: ForceFieldRelaxMaker,
phonon_displacement_maker: ForceFieldStaticMaker,
static_energy_maker: ForceFieldStaticMaker,
):
) -> tuple[
ForceFieldRelaxMaker | None,
ForceFieldStaticMaker | None,
ForceFieldStaticMaker | None,
]:
"""
Prepare the MLPhononMaker for the respective MLIP model.

Expand Down
2 changes: 1 addition & 1 deletion autoplex/fitting/common/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def machine_learning_fit(
**kwargs,
):
"""
Maker for fitting potential(s).
Job for fitting potential(s).

Parameters
----------
Expand Down
27 changes: 16 additions & 11 deletions autoplex/fitting/common/regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

import traceback
from contextlib import suppress
from typing import TYPE_CHECKING, Any

import numpy as np
from scipy.spatial import ConvexHull, Delaunay

if TYPE_CHECKING:
from ase import Atoms


def set_sigma(
atoms,
Expand All @@ -21,7 +25,7 @@ def set_sigma(
element_order=None,
max_energy=20.0,
config_type_override=None,
):
) -> list[Atoms]:
"""
Handle automatic regularisation based on distance to convex hull, amongst other things.

Expand Down Expand Up @@ -216,6 +220,7 @@ def set_sigma(


def get_convex_hull(atoms, energy_name="energy", **kwargs):
# CE I don't get what the function returns
QuantumChemist marked this conversation as resolved.
Show resolved Hide resolved
"""
Calculate simple linear (E,V) convex hull.

Expand Down Expand Up @@ -276,7 +281,7 @@ def get_convex_hull(atoms, energy_name="energy", **kwargs):
return lower_half_hull_points, p


def get_e_distance_to_hull(hull: np.array, at, energy_name="energy", **kwargs):
def get_e_distance_to_hull(hull: np.array, at, energy_name="energy", **kwargs) -> float:
"""
Calculate the distance of a structure to the linear convex hull in energy.

Expand Down Expand Up @@ -315,7 +320,7 @@ def get_e_distance_to_hull(hull: np.array, at, energy_name="energy", **kwargs):
)


def get_intersect(a1, a2, b1, b2):
def get_intersect(a1, a2, b1, b2) -> tuple[float, float] | tuple:
"""
Return the point of intersection of the lines passing through a2,a1 and b2,b1.

Expand All @@ -339,7 +344,7 @@ def get_intersect(a1, a2, b1, b2):
return x / z, y / z


def get_x(at, element_order=None):
def get_x(at, element_order=None) -> float | int:
"""
Calculate the mole-fraction of a structure.

Expand Down Expand Up @@ -379,7 +384,7 @@ def get_x(at, element_order=None):

def label_stoichiometry_volume(
ats, isolated_atoms_energies, e_name, element_order=None
):
): # CE I don't get what the function returns
"""
Calculate the stoichiometry, energy, and volume coordinates for forming the convex hull.

Expand Down Expand Up @@ -412,7 +417,7 @@ def label_stoichiometry_volume(
return p.T[:, np.argsort(p.T[0])].T


def point_in_triangle_2D(p1, p2, p3, pn):
def point_in_triangle_2D(p1, p2, p3, pn) -> bool:
"""
Check if a point is inside a triangle in 2D.

Expand Down Expand Up @@ -449,7 +454,7 @@ def point_in_triangle_2D(p1, p2, p3, pn):
)


def point_in_triangle_ND(pn, *preg):
def point_in_triangle_ND(pn, *preg) -> bool:
"""
Check if a point is inside a region of hyperplanes in N dimensions.

Expand All @@ -467,7 +472,7 @@ def point_in_triangle_ND(pn, *preg):
return hull.find_simplex(pn) >= 0


def calculate_hull_3D(p):
def calculate_hull_3D(p) -> ConvexHull:
"""
Calculate the convex hull in 3D.

Expand All @@ -492,7 +497,7 @@ def calculate_hull_3D(p):
return hull


def calculate_hull_ND(p):
def calculate_hull_ND(p) -> ConvexHull:
"""
Calculate the convex hull in ND (N>=3).

Expand Down Expand Up @@ -531,7 +536,7 @@ def calculate_hull_ND(p):

def get_e_distance_to_hull_3D(
hull, at, isolated_atoms_energies=None, energy_name="energy", element_order=None
):
) -> float:
"""
Calculate the energy distance to the convex hull in 3D.

Expand Down Expand Up @@ -579,7 +584,7 @@ def get_e_distance_to_hull_3D(
return 1e6


def piecewise_linear(x, vals):
def piecewise_linear(x, vals) -> Any:
"""
Piecewise linear.

Expand Down
Loading