Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 9, 2024
2 parents 5438575 + b7c7531 commit 183a3a5
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
5 changes: 4 additions & 1 deletion examples/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def forward(self, spex: TensorMap):
values=ps_values_ai,
samples=spex.block({"lam": 0, "a_i": a_i}).samples,
components=[],
properties=Labels.range("property", ps_values_ai.shape[-1])
properties=Labels(
"property",
torch.arange(ps_values_ai.shape[-1], device=ps_values_ai.device).reshape(-1, 1)
)
)
keys.append([a_i])
blocks.append(block)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TestEthanol1SphericalExpansion:

def test_vector_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/vector_expansion_coeffs-ethanol1_0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
# we need to sort both computed and reference pair expansion coeffs,
# because ase.neighborlist can get different neighborlist order for some reasons
tm_ref = metatensor.torch.sort(tm_ref)
Expand All @@ -51,7 +51,7 @@ def test_vector_expansion_coeffs(self):

def test_spherical_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = spherical_expansion_calculator.forward(**self.batch)
Expand All @@ -70,7 +70,7 @@ def test_spherical_expansion_coeffs_alchemical(self):
with open("tests/data/expansion_coeffs-ethanol1_0-alchemical-hypers.json", "r") as f:
hypers = json.load(f)
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-ethanol1_0-alchemical-seed0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
torch.manual_seed(0)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype)
# Because setting seed seems not be enough to get the same initial combination matrix
Expand Down Expand Up @@ -111,7 +111,7 @@ class TestArtificialSphericalExpansion:

def test_vector_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/vector_expansion_coeffs-artificial-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
tm_ref = metatensor.torch.sort(tm_ref)
vector_expansion = VectorExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
Expand All @@ -120,7 +120,7 @@ def test_vector_expansion_coeffs(self):

def test_spherical_expansion_coeffs(self):
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(self.hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
tm = spherical_expansion_calculator.forward(**self.batch)
Expand All @@ -132,7 +132,7 @@ def test_spherical_expansion_coeffs_artificial(self):
with open("tests/data/expansion_coeffs-artificial-alchemical-hypers.json", "r") as f:
hypers = json.load(f)
tm_ref = metatensor.torch.load("tests/data/spherical_expansion_coeffs-artificial-alchemical-seed0-data.npz")
tm_ref = metatensor.torch.to(tm_ref, device=self.device, dtype=self.dtype)
tm_ref = tm_ref.to(device=self.device, dtype=self.dtype)
spherical_expansion_calculator = SphericalExpansion(hypers, self.all_species).to(self.device, self.dtype)
with torch.no_grad():
spherical_expansion_calculator.vector_expansion_calculator.radial_basis_calculator.combination_matrix.weight.copy_(
Expand Down
5 changes: 4 additions & 1 deletion torch_spex/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def __init__(self, hypers, all_species) -> None:
self.is_alchemical = False
self.n_pseudo_species = 0 # dummy for torchscript
self.combination_matrix = torch.nn.Linear(1, 1) # dummy for torchscript
self.species_neighbor_labels = Labels.empty("dummy")
self.species_neighbor_labels = Labels(
names=["dummy"],
values=torch.empty((0, 1), dtype=torch.int)
)

self.apply_mlp = False
if hypers["mlp"]:
Expand Down
12 changes: 9 additions & 3 deletions torch_spex/spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SphericalExpansion(torch.nn.Module):
>>> loader = DataLoader(dataset, batch_size=1, collate_fn=collate_nl)
>>> batch = next(iter(loader))
>>> spherical_expansion = SphericalExpansion(hypers, [1, 8])
>>> expansion = spherical_expansion.forward(**batch)
>>> expansion = spherical_expansion(**batch)
>>> print(expansion.keys)
Labels(
a_i lam sigma
Expand Down Expand Up @@ -363,7 +363,10 @@ def forward(self,
)
)
else:
properties = Labels.range("n", n_max_l)
properties = Labels(
names=["n"],
values = torch.arange(n_max_l, device=vector_expansion_l.device).reshape(n_max_l, 1)
)
vector_expansion_blocks.append(
TensorBlock(
values = vector_expansion_l.reshape(vector_expansion_l.shape[0], 2*l+1, -1),
Expand Down Expand Up @@ -427,7 +430,10 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs
values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1))
)
],
properties = Labels.single().to(direction_vectors.device)
properties = Labels(
names=["_"],
values=torch.zeros((1, 1), dtype=torch.int, device=direction_vectors.device)
)
)

return block
4 changes: 4 additions & 0 deletions torch_spex/splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def __init__(
self.spline_values = concatenated_values[sort_indices]
self.spline_derivatives = concatenated_derivatives[sort_indices]

self.spline_positions = self.spline_positions.to(torch.get_default_dtype())
self.spline_values = self.spline_values.to(torch.get_default_dtype())
self.spline_derivatives = self.spline_derivatives.to(torch.get_default_dtype())

def compute(self, positions):
x = positions
delta_x = self.spline_positions[1] - self.spline_positions[0]
Expand Down

0 comments on commit 183a3a5

Please sign in to comment.