diff --git a/CHANGELOG.md b/CHANGELOG.md index 064e09c4..99159037 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,3 +8,5 @@ `ConfigurationSelection` Node - add `memory = zntrack.params(1000)` to `ConfigurationComparison` - add `threshold: float = zntrack.params(None)` to `KernelSelection` +- add `reduction_axis = zntrack.params` and + `dim_reduction: str = zntrack.params` to `ThresholdSelection` diff --git a/ipsuite/configuration_selection/threshold.py b/ipsuite/configuration_selection/threshold.py index 8fd3d488..fc745ecb 100644 --- a/ipsuite/configuration_selection/threshold.py +++ b/ipsuite/configuration_selection/threshold.py @@ -10,6 +10,29 @@ from ipsuite.configuration_selection import ConfigurationSelection +def mean_reduction(values, axis): + return np.mean(values, axis=axis) + + +def max_reduction(values, axis): + return np.max(values, axis=axis) + + +def check_dimension(values): + if values.ndim > 1: + raise ValueError( + f"Value dimension is {values.ndim} != 1. " + "Reduce the dimension by defining dim_reduction, " + "use mean or max to get (n_structures,) shape." + ) + + +REDUCTIONS = { + "mean": mean_reduction, + "max": max_reduction, +} + + class ThresholdSelection(ConfigurationSelection): """Select atoms based on a given threshold. @@ -19,7 +42,7 @@ class ThresholdSelection(ConfigurationSelection): Attributes ---------- key: str - the key in 'calc.results' to select from + The key in 'calc.results' to select from threshold: float, optional All values above (or below if negative) this threshold will be selected. If n_configurations is given, 'self.threshold' will be prioritized, @@ -28,16 +51,25 @@ class ThresholdSelection(ConfigurationSelection): For visualizing the selection a reference value can be given. For 'energy_uncertainty' this would typically be 'energy'. n_configurations: int, optional - number of configurations to select. + Number of configurations to select. min_distance: int, optional - minimum distance between selected configurations. + Minimum distance between selected configurations. + dim_reduction: str, optional + Reduces the dimensionality of the chosen uncertainty along the specified axis + by calculating either the maximum or mean value. + + Choose from ["max", "mean"] + reduction_axis: tuple(int), optional + Specifies the axis along which the reduction occurs. """ - key = zntrack.zn.params("energy_uncertainty") - reference = zntrack.zn.params("energy") - threshold = zntrack.zn.params(None) - n_configurations = zntrack.zn.params(None) - min_distance: int = zntrack.zn.params(1) + key = zntrack.params("energy_uncertainty") + reference = zntrack.params("energy") + threshold = zntrack.params(None) + n_configurations = zntrack.params(None) + min_distance: int = zntrack.params(1) + dim_reduction: str = zntrack.params(None) + reduction_axis = zntrack.params((1, 2)) def _post_init_(self): if self.threshold is None and self.n_configurations is None: @@ -45,7 +77,9 @@ def _post_init_(self): return super()._post_init_() - def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]: + def select_atoms( + self, atoms_lst: typing.List[ase.Atoms], save_fig: bool = True + ) -> typing.List[int]: """Take every nth (step) object of a given atoms list. Parameters @@ -58,7 +92,16 @@ def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]: typing.List[int]: list containing the taken indices """ + + self.reduction_axis = tuple(self.reduction_axis) values = np.array([atoms.calc.results[self.key] for atoms in atoms_lst]) + + if self.dim_reduction is not None: + reduction_fn = REDUCTIONS[self.dim_reduction] + values = reduction_fn(values, self.reduction_axis) + + check_dimension(values) + if self.threshold is not None: if self.threshold < 0: indices = np.where(values < self.threshold)[0] @@ -74,6 +117,11 @@ def select_atoms(self, atoms_lst: typing.List[ase.Atoms]) -> typing.List[int]: else: indices = np.argsort(values) + selection = self.get_selection(indices) + + return selection + + def get_selection(self, indices): selected = [] for val in indices: # If the value is close to any of the already selected values, skip it. @@ -92,6 +140,9 @@ def _get_plot(self, atoms_lst: typing.List[ase.Atoms], indices: typing.List[int] reference = np.array( [atoms.calc.results[self.reference] for atoms in atoms_lst] ) + if reference.ndim > 1: + reference = np.max(reference, axis=self.reduction_axis) + fig, ax, _ = plot_with_uncertainty( {"std": values, "mean": reference}, ylabel=self.key, diff --git a/tests/conftest.py b/tests/conftest.py index 68b86293..4ef1bca1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,8 @@ def atoms_list() -> typing.List[ase.Atoms]: atoms=atom, energy=idx / 21, forces=np.random.randn(2, 3), + energy_uncertainty=idx + 2, + forces_uncertainty=np.full((2, 3), 2.0) + idx, ) return atoms diff --git a/tests/unit_tests/configuration_selection/test_threshold.py b/tests/unit_tests/configuration_selection/test_threshold.py new file mode 100644 index 00000000..87efe45d --- /dev/null +++ b/tests/unit_tests/configuration_selection/test_threshold.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest + +from ipsuite.configuration_selection import ThresholdSelection + + +@pytest.mark.parametrize( + "key, reference, dim_reduction, reduction_axis", + [ + ("energy_uncertainty", "energy", None, (1, 2)), + ("forces_uncertainty", "forces", "max", (1, 2)), + ("forces_uncertainty", "forces", "mean", (1, 2)), + ("forces_uncertainty", "forces", None, (1, 2)), + ], +) +def test_get_selected_atoms(atoms_list, key, reference, dim_reduction, reduction_axis): + threshold = ThresholdSelection( + key=key, + reference=reference, + dim_reduction=dim_reduction, + reduction_axis=reduction_axis, + data=None, + threshold=1.0, + n_configurations=5, + min_distance=5, + ) + + if "forces_uncertainty" in key and dim_reduction is None: + with pytest.raises(ValueError): + selected_atoms = threshold.select_atoms(atoms_list, save_fig=False) + else: + selected_atoms = threshold.select_atoms(atoms_list, save_fig=False) + test_selection = np.linspace(20, 0, 5, dtype=int).tolist() + assert len(set(selected_atoms)) == 5 + assert isinstance(selected_atoms, list) + assert selected_atoms == test_selection