Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Descriptors and some other stuff WIP #67

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 114 additions & 31 deletions openqdc/utils/descriptors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, List

import datamol as dm
import numpy as np
from ase.atoms import Atoms
from numpy import ndarray
Expand All @@ -10,25 +11,96 @@


class Descriptor(ABC):
    """
    Base class for all descriptors.
    Descriptors are used to transform 3D atomic structures into feature vectors.

    Subclasses implement ``instantiate_model`` (build the underlying model)
    and ``calculate`` (evaluate it on a single structure).
    """

    _model: Any

    def __init__(self, *, species: List[str], **kwargs) -> None:
        """
        Parameters
        ----------
        species : List[str]
            List of chemical species for the descriptor embedding.
        kwargs : dict
            Additional keyword arguments to be passed to the descriptor model.
        """
        # chemical_species must be set first: instantiate_model may read it.
        self.chemical_species = species
        self._model = self.instantiate_model(**kwargs)

    @property
    def model(self) -> Any:
        """Simple property that returns the model."""
        return self._model

    @abstractmethod
    def instantiate_model(self, **kwargs) -> Any:
        """
        Instantiate the descriptor model with the provided kwargs parameters
        and return it. The model will be stored in the _model attribute.
        If a package is required to instantiate the model, it should be checked
        using the requires_package decorator or in the method itself.

        Parameters
        ----------
        kwargs : dict
            Additional keyword arguments to be passed to the descriptor model.

        Returns
        -------
        Any
            The instantiated descriptor model.
        """
        raise NotImplementedError

    @abstractmethod
    def calculate(self, atoms: Atoms, **kwargs) -> ndarray:
        """
        Calculate the descriptor for a single given Atoms object.

        Parameters
        ----------
        atoms : Atoms
            ASE Atoms object to calculate the descriptor for.
        kwargs : dict
            Additional keyword arguments for the descriptor calculation.

        Returns
        -------
        ndarray
            ndarray containing the descriptor values
        """
        raise NotImplementedError

    def fit_transform(self, atoms: List[Atoms], **kwargs) -> List[ndarray]:
        """
        Parallelized version of the calculate method.

        Parameters
        ----------
        atoms : List[Atoms]
            List of ASE Atoms objects to calculate the descriptor for.
        kwargs : dict
            Additional keyword arguments to be passed to the datamol
            parallelized function. ``scheduler`` defaults to ``"threads"``
            when it is not given.

        Returns
        -------
        List[ndarray]
            List of ndarray containing the descriptor values
        """
        # Apply the scheduler as a default rather than a fixed keyword so that
        # a caller-supplied ``scheduler=...`` does not raise a duplicate
        # keyword TypeError.
        kwargs.setdefault("scheduler", "threads")
        return dm.parallelized(self.calculate, atoms, **kwargs)

    def from_xyz(self, positions: np.ndarray, atomic_numbers: np.ndarray, **kwargs) -> ndarray:
        """
        Calculate the descriptor from positions and atomic numbers of a single structure.

        Parameters
        ----------
        positions : np.ndarray (n_atoms, 3)
            Positions of the chemical structure.
        atomic_numbers : np.ndarray (n_atoms,)
            Atomic numbers of the chemical structure.
        kwargs : dict
            Additional keyword arguments forwarded to ``calculate``.

        Returns
        -------
        ndarray
            ndarray containing the descriptor values
        """
        atoms = to_atoms(positions, atomic_numbers)
        return self.calculate(atoms, **kwargs)

Expand Down Expand Up @@ -61,31 +133,12 @@ def instantiate_model(self, **kwargs):
compression=compression,
)

def calculate(self, atoms: Atoms) -> ndarray:
return self.model.create(atoms, centers=atoms.positions)


class MBTR(SOAP):
@requires_package("dscribe")
def instantiate_model(self, **kwargs):
from dscribe.descriptors import MBTR as MBTRModel

r_cut = kwargs.pop("r_cut", 5.0)
geometry = kwargs.pop("geometry", {"function": "inverse_distance"})
grid = kwargs.pop("grid", {"min": 0, "max": 1, "n": 100, "sigma": 0.1})
weighting = kwargs.pop("weighting", {"function": "exp", "scale": 0.5, "threshold": 1e-3})
normalization = kwargs.pop("normalization", "l2")
periodic = kwargs.pop("periodic", False)

return MBTRModel(
species=self.chemical_species,
periodic=periodic,
r_cut=r_cut,
geometry=geometry,
grid=grid,
weighting=weighting,
normalization=normalization,
)
def calculate(self, atoms: Atoms, **kwargs) -> ndarray:
    """
    Calculate the descriptor for a single Atoms object.

    Parameters
    ----------
    atoms : Atoms
        Structure to calculate the descriptor for.
    kwargs : dict
        Keyword arguments forwarded to the model's ``create`` method.
        When ``centers`` is not given, one center is placed on every atom.

    Returns
    -------
    ndarray
        Whatever the model's ``create`` call returns.
    """
    if "centers" in kwargs:
        # Caller chose the centers explicitly; forward everything untouched.
        return self.model.create(atoms, **kwargs)
    # add a center to every atom
    every_atom = list(range(len(atoms.positions)))
    return self.model.create(atoms, centers=every_atom, **kwargs)


class ACSF(SOAP):
Expand All @@ -95,7 +148,7 @@ def instantiate_model(self, **kwargs):

r_cut = kwargs.pop("r_cut", 5.0)
g2_params = kwargs.pop("g2_params", [[1, 1], [1, 2], [1, 3]])
g3_params = kwargs.pop("g3_params", [[1], [1], [1], [2]])
g3_params = kwargs.pop("g3_params", [1, 1, 2, -1])
g4_params = kwargs.pop("g4_params", [[1, 1, 1], [1, 2, 1], [1, 1, -1], [1, 2, -1]])
g5_params = kwargs.pop("g5_params", [[1, 2, -1], [1, 1, 1], [-1, 1, 1], [1, 2, 1]])
periodic = kwargs.pop("periodic", False)
Expand All @@ -111,14 +164,44 @@ def instantiate_model(self, **kwargs):
)


class MBTR(SOAP):
    """Descriptor backed by the ``MBTR`` model from ``dscribe.descriptors``."""

    @requires_package("dscribe")
    def instantiate_model(self, **kwargs):
        """
        Build the dscribe ``MBTR`` model for ``self.chemical_species``.

        Parameters
        ----------
        kwargs : dict
            Optional overrides for ``geometry``, ``grid``, ``weighting``,
            ``normalization``, ``normalize_gaussians`` and ``periodic``.
            Keys other than these are not forwarded to the model.
        """
        from dscribe.descriptors import MBTR as MBTRModel

        # Fallback value for every option the caller did not override.
        fallbacks = {
            "geometry": {"function": "inverse_distance"},
            "grid": {"min": 0, "max": 1, "n": 100, "sigma": 0.1},
            "weighting": {"function": "exp", "r_cut": 5, "threshold": 1e-3},
            "normalization": "l2",
            "normalize_gaussians": True,
            "periodic": False,
        }
        settings = {option: kwargs.get(option, fallback) for option, fallback in fallbacks.items()}
        return MBTRModel(species=self.chemical_species, **settings)

    def calculate(self, atoms: Atoms, **kwargs) -> ndarray:
        """
        Calculate the descriptor for a single Atoms object.

        The keyword arguments are handed to the model's ``create`` method
        unchanged; no defaults are added here.
        """
        model = self.model
        return model.create(atoms, **kwargs)


# Registry built from this module's globals: lower-cased class name -> class,
# for every Descriptor subclass defined or imported above this statement.
AVAILABLE_DESCRIPTORS = {
    name.lower(): candidate
    for name, candidate in globals().items()
    if name != "Descriptor"  # Exclude the base class
    and isinstance(candidate, type)
    and issubclass(candidate, Descriptor)
}


def get_descriptor(name: str) -> Descriptor:
"""
Utility function that returns a descriptor class from its name.
"""
try:
return AVAILABLE_DESCRIPTORS[name.lower()]
except KeyError:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_descriptors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from openqdc import Dummy
from openqdc.utils.descriptors import ACSF, MBTR, SOAP, Descriptor


@pytest.fixture
def dummy():
    """Provide a fresh ``openqdc.Dummy`` instance to each test that requests it."""
    return Dummy()


@pytest.mark.parametrize("model", [SOAP, ACSF, MBTR])
def test_init(model):
    """Each descriptor class can be built with default options and is a Descriptor."""
    descriptor = model(species=["H"])
    assert isinstance(descriptor, Descriptor)


@pytest.mark.parametrize("model", [SOAP, ACSF, MBTR])
def test_descriptor(model, dummy):
    """``fit_transform`` returns one result per input structure."""
    descriptor = model(species=dummy.chemical_species)
    structures = [dummy.get_ase_atoms(idx) for idx in range(4)]
    outputs = descriptor.fit_transform(structures)
    assert len(outputs) == 4


@pytest.mark.parametrize("model", [SOAP, ACSF, MBTR])
def test_from_positions(model):
    """``from_xyz`` runs on raw positions and atomic numbers without raising."""
    descriptor = model(species=["H"])
    # Smoke test only: the returned value is intentionally not inspected.
    descriptor.from_xyz([[0, 0, 0], [1, 1, 1]], [1, 1])


@pytest.mark.parametrize(
    "model,override", [(SOAP, {"r_cut": 3.0}), (ACSF, {"r_cut": 3.0}), (MBTR, {"normalize_gaussians": False})]
)
def test_overwrite(model, override, dummy):
    """Each descriptor class accepts an override of one of its default options."""
    # Smoke test only: construction must not raise with the override applied.
    model(species=dummy.chemical_species, **override)
Loading