Skip to content

Commit

Permalink
Add example
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jan 6, 2025
1 parent ccd82ef commit d9d6f03
Show file tree
Hide file tree
Showing 13 changed files with 119 additions and 42 deletions.
3 changes: 3 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ furo # sphinx theme
# jax[cpu], because python -m pip install jax, which would be triggered
# by the main package's dependencies, does not install jaxlib
jax[cpu] >= 0.4.18

# metatensor and metatensor-torch for the metatensor API
metatensor-torch
1 change: 1 addition & 0 deletions docs/src/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ different languages and frameworks it supports.
python-api
pytorch-api
jax-api
metatensor-api

Although the Julia API is not fully documented yet, basic usage examples are available
`here <https://github.com/lab-cosmo/sphericart/blob/main/julia/README.md>`_.
1 change: 1 addition & 0 deletions docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
"python": ("https://docs.python.org/3", None),
"torch": ("https://pytorch.org/docs/stable/", None),
"e3nn": ("https://docs.e3nn.org/en/latest/", None),
"metatensor": ("https://docs.metatensor.org/latest/index.html", None),
}

html_theme = "furo"
Expand Down
1 change: 1 addition & 0 deletions docs/src/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ floating-point arithmetics, and they evaluate the mean relative error between th
pytorch-examples
jax-examples
spherical-complex
metatensor-examples

Although comprehensive Julia examples are not fully available yet, basic usage is illustrated
`here <https://github.com/lab-cosmo/sphericart/blob/main/julia/README.md>`_.
2 changes: 1 addition & 1 deletion docs/src/jax-api.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
JAX API
===========
=======

The `sphericart.jax` module aims to provide a functional-style and
`JAX`-friendly framework. As a result, it does not follow the same syntax as
Expand Down
28 changes: 28 additions & 0 deletions docs/src/metatensor-api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Metatensor API
==============

``sphericart`` can be used in conjunction with
`metatensor <https://docs.metatensor.org/latest/index.html>`_ in order to attach
metadata to inputs and outputs, as well as to naturally obtain spherical harmonics,
gradients and Hessians in a single object.

Here is the API reference for the ``sphericart.metatensor`` and
``sphericart.metatensor.torch`` modules.

sphericart.metatensor
---------------------

.. autoclass:: sphericart.metatensor.SphericalHarmonics
:members:

.. autoclass:: sphericart.metatensor.SolidHarmonics
:members:

sphericart.metatensor.torch
---------------------------

.. autoclass:: sphericart.metatensor.torch.SphericalHarmonics
:members:

.. autoclass:: sphericart.metatensor.torch.SolidHarmonics
:members:
13 changes: 13 additions & 0 deletions docs/src/metatensor-examples.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Using sphericart with metatensor
--------------------------------

``sphericart`` can be used in conjunction with
`metatensor <https://docs.metatensor.org/latest/index.html>`_ in order to attach
metadata to inputs and outputs, as well as to naturally obtain spherical harmonics,
gradients and Hessians in a single object.

This example shows how to use the ``sphericart.metatensor`` module to compute
spherical harmonics, their gradients and their Hessians.

.. literalinclude:: ../../examples/metatensor/example.py
:language: python
54 changes: 54 additions & 0 deletions examples/metatensor/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
from metatensor import Labels, TensorBlock, TensorMap

import sphericart
import sphericart.metatensor


l_max = 15
n_samples = 100

xyz = TensorMap(
keys=Labels.single(),
blocks=[
TensorBlock(
values=np.random.rand(n_samples, 3, 1),
samples=Labels(
names=["sample"],
values=np.arange(n_samples).reshape(-1, 1),
),
components=[
Labels(
names=["xyz"],
values=np.arange(3).reshape(-1, 1),
)
],
properties=Labels.single(),
)
],
)

calculator = sphericart.metatensor.SphericalHarmonics(l_max)

spherical_harmonics = calculator.compute(xyz)

for single_l in range(l_max + 1):
spherical_single_l = spherical_harmonics.block({"o3_lambda": single_l})

# check values against pure sphericart
assert np.allclose(
spherical_single_l.values.squeeze(-1),
sphericart.SphericalHarmonics(single_l).compute(
xyz.block().values.squeeze(-1)
)[:, single_l**2 : (single_l + 1) ** 2],
)

# further example: obtaining gradients of l = 2 spherical harmonics
spherical_harmonics = calculator.compute_with_gradients(xyz)
l_2_gradients = spherical_harmonics.block({"o3_lambda": 2}).gradient("positions")

# further example: obtaining Hessians of l = 2 spherical harmonics
spherical_harmonics = calculator.compute_with_hessians(xyz)
l_2_hessians = spherical_harmonics.block(
{"o3_lambda": 2}
).gradient("positions").gradient("positions")
1 change: 1 addition & 0 deletions python/src/sphericart/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .spherical_harmonics import SphericalHarmonics, SolidHarmonics # noqa
from . import metatensor # noqa
31 changes: 8 additions & 23 deletions python/src/sphericart/metatensor/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,11 @@ def __init__(self, l_max: int):
]
self.precomputed_xyz_components = Labels(
names=["xyz"],
values=np.arange(2).reshape(-1, 1),
)
self.precomputed_xyz_1_components = Labels(
names=["xyz_1"],
values=np.arange(2).reshape(-1, 1),
values=np.arange(3).reshape(-1, 1),
)
self.precomputed_xyz_2_components = Labels(
names=["xyz_2"],
values=np.arange(2).reshape(-1, 1),
values=np.arange(3).reshape(-1, 1),
)
self.precomputed_properties = Labels.single()

Expand All @@ -56,7 +52,6 @@ def compute(self, xyz: TensorMap) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
)
Expand All @@ -72,7 +67,6 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
Expand All @@ -88,10 +82,9 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap:
self.precomputed_keys,
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_properties,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
sh_hessians,
)
Expand All @@ -117,15 +110,11 @@ def __init__(self, l_max: int):
]
self.precomputed_xyz_components = Labels(
names=["xyz"],
values=np.arange(2).reshape(-1, 1),
)
self.precomputed_xyz_1_components = Labels(
names=["xyz_1"],
values=np.arange(2).reshape(-1, 1),
values=np.arange(3).reshape(-1, 1),
)
self.precomputed_xyz_2_components = Labels(
names=["xyz_2"],
values=np.arange(2).reshape(-1, 1),
values=np.arange(3).reshape(-1, 1),
)
self.precomputed_properties = Labels.single()

Expand All @@ -138,7 +127,6 @@ def compute(self, xyz: np.ndarray) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
)
Expand All @@ -154,7 +142,6 @@ def compute_with_gradients(self, xyz: np.ndarray) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
Expand All @@ -171,7 +158,6 @@ def compute_with_hessians(self, xyz: np.ndarray) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
Expand All @@ -198,7 +184,6 @@ def _wrap_into_tensor_map(
samples: Labels,
components: List[Labels],
xyz_components: Labels,
xyz_1_components: Labels,
xyz_2_components: Labels,
properties: Labels,
sh_gradients: Optional[np.ndarray] = None,
Expand All @@ -223,17 +208,17 @@ def _wrap_into_tensor_map(
sh_gradients_block = metatensor_module.TensorBlock(
values=sh_gradients[:, :, l_start:l_end, None],
samples=samples,
components=[components[l], xyz_components],
components=[xyz_components, components[l]],
properties=properties,
)
if sh_hessians is not None:
sh_hessians_block = metatensor_module.TensorBlock(
values=sh_hessians[:, :, :, l_start:l_end, None],
samples=samples,
components=[
components[l],
xyz_1_components,
xyz_2_components,
xyz_components,
components[l],
],
properties=properties,
)
Expand Down
2 changes: 2 additions & 0 deletions sphericart-torch/python/sphericart/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from ._build_torch_version import BUILD_TORCH_VERSION
import re

from . import metatensor # noqa


def parse_version_string(version_string):
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
Expand Down
22 changes: 4 additions & 18 deletions sphericart-torch/python/sphericart/torch/metatensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,11 @@ def __init__(self, l_max: int):
]
self.precomputed_xyz_components = Labels(
names=["xyz"],
values=torch.arange(2).reshape(-1, 1),
)
self.precomputed_xyz_1_components = Labels(
names=["xyz_1"],
values=torch.arange(2).reshape(-1, 1),
values=torch.arange(3).reshape(-1, 1),
)
self.precomputed_xyz_2_components = Labels(
names=["xyz_2"],
values=torch.arange(2).reshape(-1, 1),
values=torch.arange(3).reshape(-1, 1),
)
self.precomputed_properties = Labels.single()

Expand All @@ -59,7 +55,6 @@ def compute(self, xyz: TensorMap) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
metatensor_module=metatensor.torch,
Expand All @@ -76,7 +71,6 @@ def compute_with_gradients(self, xyz: TensorMap) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
Expand All @@ -95,7 +89,6 @@ def compute_with_hessians(self, xyz: TensorMap) -> TensorMap:
self.precomputed_mu_components,
self.precomputed_properties,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
sh_gradients,
sh_hessians,
Expand Down Expand Up @@ -123,15 +116,11 @@ def __init__(self, l_max: int):
]
self.precomputed_xyz_components = Labels(
names=["xyz"],
values=torch.arange(2).reshape(-1, 1),
)
self.precomputed_xyz_1_components = Labels(
names=["xyz_1"],
values=torch.arange(2).reshape(-1, 1),
values=torch.arange(3).reshape(-1, 1),
)
self.precomputed_xyz_2_components = Labels(
names=["xyz_2"],
values=torch.arange(2).reshape(-1, 1),
values=torch.arange(3).reshape(-1, 1),
)
self.precomputed_properties = Labels.single()

Expand All @@ -144,7 +133,6 @@ def compute(self, xyz: torch.Tensor) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
metatensor_module=metatensor.torch,
Expand All @@ -161,7 +149,6 @@ def compute_with_gradients(self, xyz: torch.Tensor) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
Expand All @@ -179,7 +166,6 @@ def compute_with_hessians(self, xyz: torch.Tensor) -> TensorMap:
xyz.block().samples,
self.precomputed_mu_components,
self.precomputed_xyz_components,
self.precomputed_xyz_1_components,
self.precomputed_xyz_2_components,
self.precomputed_properties,
sh_gradients,
Expand Down
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ deps =
numpy<2.0.0
torch
pytest
metatensor

passenv=
PIP_EXTRA_INDEX_URL
Expand All @@ -116,6 +117,7 @@ commands =
python examples/python/example.py
python examples/pytorch/example.py
python examples/jax/example.py
python examples/metatensor/example.py

python examples/python/spherical.py
python examples/python/complex.py
Expand Down

0 comments on commit d9d6f03

Please sign in to comment.