Make models scriptable (#30)
Co-authored-by: Alexander Goscinski <alex.goscinski@posteo.de>
frostedoyster and agoscinski authored Sep 7, 2023
1 parent d40d33c commit 728ffab
Showing 20 changed files with 457 additions and 316 deletions.
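
For orientation, the point of this commit is to make the models compatible with torch.jit.script. Below is a minimal sketch of what scripting requires of a module; all names are illustrative and not part of this repository:

import torch
from typing import Dict

class ToyModel(torch.nn.Module):  # illustrative, not part of this commit
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    # TorchScript needs explicit type annotations on forward(): untyped arguments
    # default to Tensor, and containers such as dicts must be fully typed.
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self.linear(batch["features"]).squeeze(-1)

scripted = torch.jit.script(ToyModel())  # compilation fails here if anything is not scriptable
print(scripted({"features": torch.randn(3, 4)}))

The changes below follow this pattern: typed forward signatures, torch-backed metatensor objects, and plain Python containers instead of numpy objects.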
2 changes: 1 addition & 1 deletion benchmarks/benchmark.py
@@ -7,7 +7,7 @@
import rascaline
rascaline._c_lib._get_library()
from torch_spex.le import Jn_zeros
from equistore import Labels
from metatensor.torch import Labels
from torch_spex.spherical_expansions import SphericalExpansion
from torch_spex.structures import InMemoryDataset, TransformerNeighborList, TransformerProperty, collate_nl

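The equistore import is replaced by metatensor.torch so that Labels (and the other metatensor classes used below) carry torch tensors and can be used from scripted code. A minimal sketch, with made-up values:

import torch
from metatensor.torch import Labels

# Labels in metatensor.torch store their values as torch tensors, which is what
# makes them usable from TorchScript-compiled code.
samples = Labels(names=["structure", "center"], values=torch.tensor([[0, 0], [0, 1]]))
properties = Labels.range("property", 4)  # single column "property" running from 0 to 3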
190 changes: 110 additions & 80 deletions examples/alchemical_model.py
@@ -7,7 +7,9 @@
from torch_spex.atomic_composition import AtomicComposition
from power_spectrum import PowerSpectrum
from torch_spex.normalize import get_average_number_of_neighbors, normalize_true, normalize_false
import equistore

from typing import Dict
from metatensor.torch import TensorMap

# Conversions

@@ -99,6 +101,7 @@ def get_sse(first, second):
print("Average number of atoms per structure:", average_number_of_atoms)

all_species = np.sort(np.unique(np.concatenate([train_structure.numbers for train_structure in train_structures] + [test_structure.numbers for test_structure in test_structures])))
all_species = [int(species) for species in all_species] # convert to Python ints for tracer
print(f"All species: {all_species}")


@@ -132,26 +135,63 @@ def __init__(self, hypers, all_species, do_forces) -> None:
self.comp_calculator = AtomicComposition(all_species)
self.composition_coefficients = None # Needs to be set from outside
self.do_forces = do_forces
self.normalize = normalize
self.average_number_of_atoms = average_number_of_atoms

def forward(self, structure_batch, is_training=True):
def forward(self, structure_batch: Dict[str, torch.Tensor], is_training: bool = True):

n_structures = len(structure_batch["positions"])
energies = torch.zeros((n_structures,), device=device, dtype=torch.get_default_dtype())
n_structures = structure_batch["cells"].shape[0]
energies = torch.zeros(
(n_structures,),
dtype=structure_batch["positions"].dtype,
device=structure_batch["positions"].device,
)

if self.do_forces:
for structure_positions in structure_batch["positions"]:
structure_positions.requires_grad = True
structure_batch["positions"].requires_grad_(True)

# print("Calculating spherical expansion")
spherical_expansion = self.spherical_expansion_calculator(**structure_batch)
spherical_expansion = self.spherical_expansion_calculator(
positions = structure_batch["positions"],
cells = structure_batch["cells"],
species = structure_batch["species"],
cell_shifts = structure_batch["cell_shifts"],
centers = structure_batch["centers"],
pairs = structure_batch["pairs"],
structure_centers = structure_batch["structure_centers"],
structure_pairs = structure_batch["structure_pairs"],
structure_offsets = structure_batch["structure_offsets"]
)
ps = self.ps_calculator(spherical_expansion)
if normalize: ps = equistore.divide(ps, 10.0)  # BUG?

# print("Calculating energies")
self._apply_layer(energies, ps, self.nu2_model)
if normalize: energies = energies / np.sqrt(average_number_of_atoms)

comp = self.comp_calculator.compute(**structure_batch)
atomic_energies = []
structure_indices = []
for ai, layer_ai in self.nu2_model.items():
block = ps.block({"a_i": int(ai)})
# print(block.values)
features = block.values.squeeze(dim=1)
structure_indices.append(block.samples.column("structure"))
atomic_energies.append(
layer_ai(features).squeeze(dim=-1)
)
atomic_energies = torch.concat(atomic_energies)
structure_indices = torch.concatenate(structure_indices)
# print("Before aggregation", torch.mean(atomic_energies), get_2_mom(atomic_energies))
energies.index_add_(dim=0, index=structure_indices, source=atomic_energies)
if self.normalize: energies = energies * self.average_number_of_atoms**(-0.5)

comp = self.comp_calculator(
positions = structure_batch["positions"],
cells = structure_batch["cells"],
species = structure_batch["species"],
cell_shifts = structure_batch["cell_shifts"],
centers = structure_batch["centers"],
pairs = structure_batch["pairs"],
structure_centers = structure_batch["structure_centers"],
structure_pairs = structure_batch["structure_pairs"],
structure_offsets = structure_batch["structure_offsets"]
)
energies += comp @ self.composition_coefficients

# print("Computing forces by backpropagation")
@@ -162,93 +202,67 @@ def forward(self, structure_batch, is_training=True):

return energies, forces
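
The per-species loop in the new forward() does what the removed _apply_layer helper used to do: each species has its own head, and per-atom energies are accumulated into per-structure energies with index_add_. A minimal sketch of that aggregation pattern, with made-up shapes:

import torch

heads = torch.nn.ModuleDict({"1": torch.nn.Linear(8, 1), "6": torch.nn.Linear(8, 1)})
features = {"1": torch.randn(5, 8), "6": torch.randn(3, 8)}                            # per-species atomic features
structure_index = {"1": torch.tensor([0, 0, 1, 1, 1]), "6": torch.tensor([0, 1, 1])}  # parent structure of each atom

energies = torch.zeros(2)
for species, head in heads.items():
    atomic_energies = head(features[species]).squeeze(-1)
    # index_add_ adds each atom's energy into the slot of its parent structure
    energies.index_add_(0, structure_index[species], atomic_energies)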

def predict_epoch(self, data_loader):

predicted_energies = []
predicted_forces = []
for batch in data_loader:
batch.pop("energies")
batch.pop("forces")
predicted_energies_batch, predicted_forces_batch = model(batch, is_training=False)
predicted_energies.append(predicted_energies_batch)
predicted_forces.extend(predicted_forces_batch) # the predicted forces for the batch are themselves a list

predicted_energies = torch.concatenate(predicted_energies, dim=0)
predicted_forces = torch.concatenate(predicted_forces, dim=0)
return predicted_energies, predicted_forces
def predict_epoch(model, data_loader):

predicted_energies = []
predicted_forces = []
for batch in data_loader:
batch.pop("energies")
batch.pop("forces")
predicted_energies_batch, predicted_forces_batch = model(batch, is_training=False)
predicted_energies.append(predicted_energies_batch)
predicted_forces.append(predicted_forces_batch)

predicted_energies = torch.concatenate(predicted_energies, dim=0)
predicted_forces = torch.concatenate(predicted_forces, dim=0)
return predicted_energies, predicted_forces


def train_epoch(self, data_loader, force_weight):

if optimizer_name == "Adam":
def train_epoch(model, data_loader, force_weight):

if optimizer_name == "Adam":
total_loss = 0.0
for batch in data_loader:
energies = batch.pop("energies")
forces = batch.pop("forces")
optimizer.zero_grad()
predicted_energies, predicted_forces = model(batch)

loss = get_sse(predicted_energies, energies)
if do_forces:
forces = forces.to(device)
loss += force_weight * get_sse(predicted_forces, forces)
loss.backward()
optimizer.step()
total_loss += loss.item()
else:
def closure():
optimizer.zero_grad()
total_loss = 0.0
for batch in data_loader:
energies = batch.pop("energies")
forces = batch.pop("forces")
optimizer.zero_grad()
predicted_energies, predicted_forces = model(batch)

loss = get_sse(predicted_energies, energies)
if do_forces:
forces = forces.to(device)
predicted_forces = torch.concatenate(predicted_forces)
loss += force_weight * get_sse(predicted_forces, forces)
loss.backward()
optimizer.step()
total_loss += loss.item()
else:
def closure():
optimizer.zero_grad()
total_loss = 0.0
for batch in data_loader:
energies = batch.pop("energies")
forces = batch.pop("forces")
predicted_energies, predicted_forces = model(batch)

loss = get_sse(predicted_energies, energies)
if do_forces:
forces = forces.to(device)
predicted_forces = torch.concatenate(predicted_forces)
loss += force_weight * get_sse(predicted_forces, forces)
loss.backward()
total_loss += loss.item()
print(total_loss)
return total_loss

total_loss = optimizer.step(closure)
return total_loss

def _apply_layer(self, energies, tmap, layer):
atomic_energies = []
structure_indices = []
for a_i in self.all_species:
block = tmap.block(a_i=a_i)
# print(block.values)
features = block.values.squeeze(dim=1)
structure_indices.append(block.samples["structure"])
atomic_energies.append(
layer[str(a_i)](features).squeeze(dim=-1)
)
atomic_energies = torch.concat(atomic_energies)
structure_indices = torch.LongTensor(np.concatenate(structure_indices))
# print("Before aggregation", torch.mean(atomic_energies), get_2_mom(atomic_energies))

energies.index_add_(dim=0, index=structure_indices.to(device), source=atomic_energies)
# THIS IN-PLACE MODIFICATION HAS TO CHANGE!
print(total_loss)
return total_loss

# def print_state()... Would print loss, train errors, validation errors, test errors, ...
total_loss = optimizer.step(closure)
return total_loss
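
Unlike Adam, LBFGS re-evaluates the objective several times per step, which is why the loss and gradients are wrapped in a closure above. A self-contained sketch of that pattern, with made-up data:

import torch

x = torch.randn(16, 2)
y = x.sum(dim=1, keepdim=True)
net = torch.nn.Linear(2, 1)
optimizer = torch.optim.LBFGS(net.parameters(), line_search_fn="strong_wolfe")

def closure():
    # LBFGS calls this as often as the line search needs a fresh loss and gradient
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(net(x), y)
    loss.backward()
    return loss

for _ in range(5):
    loss = optimizer.step(closure)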

model = Model(hypers, all_species, do_forces=do_forces).to(device)
# print(model)

if optimizer_name == "Adam":
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 8 # Batch for training speed
else:
optimizer = torch.optim.LBFGS(model.parameters(), line_search_fn="strong_wolfe", history_size=128)
batch_size = 128 # Batch for memory


print("Precomputing neighborlists")

transformers = [
@@ -287,12 +301,28 @@ def _apply_layer(self, energies, tmap, layer):
batch.pop("energies")
batch.pop("forces")
train_comp.append(
comp_calculator.compute(**batch)
comp_calculator(**batch)
)
train_comp = torch.concatenate(train_comp)
c_comp = torch.linalg.solve(train_comp.T @ train_comp, train_comp.T @ train_energies)

model = Model(hypers, all_species, do_forces=do_forces).to(device)
model.composition_coefficients = c_comp
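
The composition coefficients are an ordinary least-squares fit of per-species one-body energies: train_comp counts atoms of each species per structure, and the normal equations (train_comp.T @ train_comp) c = train_comp.T @ train_energies are solved for c. A small numeric sketch with made-up values:

import torch

X = torch.tensor([[2.0, 1.0], [1.0, 3.0], [4.0, 0.0]])   # atoms of each species per structure
y = torch.tensor([[10.0], [11.0], [16.0]])                # total energies
c_normal = torch.linalg.solve(X.T @ X, X.T @ y)           # normal equations, as above
c_lstsq = torch.linalg.lstsq(X, y).solution               # equivalent, and more robust for ill-conditioned X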

# Deactivate kernel fusion which slows down the model.
# With kernel fusion, our model would be recompiled at every call
# due to the varying shapes of the involved tensors (neighborlists
# can vary between different structures and batches)
# Perhaps [("DYNAMIC", 1)] can offer better performance
torch.jit.set_fusion_strategy([("DYNAMIC", 0)])
model = torch.jit.script(model)

if optimizer_name == "Adam":
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
else:
optimizer = torch.optim.LBFGS(model.parameters(), line_search_fn="strong_wolfe", history_size=128)
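
Note that the optimizer is now constructed after torch.jit.script, from the scripted module's parameters. Once scripted, the model can also be serialized and reloaded without its Python class definitions; a minimal sketch with an illustrative file name:

scripted_path = "alchemical_model_scripted.pt"  # illustrative path, not used in this commit
model.save(scripted_path)                       # ScriptModule.save
reloaded = torch.jit.load(scripted_path)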


print("Finished linear fit for one-body energies")


@@ -323,15 +353,15 @@ def _apply_layer(self, energies, tmap, layer):

# print(torch.cuda.max_memory_allocated())

predicted_train_energies, predicted_train_forces = model.predict_epoch(predict_train_data_loader)
predicted_test_energies, predicted_test_forces = model.predict_epoch(predict_test_data_loader)
predicted_train_energies, predicted_train_forces = predict_epoch(model, predict_train_data_loader)
predicted_test_energies, predicted_test_forces = predict_epoch(model, predict_test_data_loader)

print()
print(f"Epoch number {epoch}, Total loss: {get_sse(predicted_train_energies, train_energies)+force_weight*get_sse(predicted_train_forces, train_forces)}")
print(f"Energy errors: Train RMSE: {get_rmse(predicted_train_energies, train_energies)}, Train MAE: {get_mae(predicted_train_energies, train_energies)}, Test RMSE: {get_rmse(predicted_test_energies, test_energies)}, Test MAE: {get_mae(predicted_test_energies, test_energies)}")
if do_forces:
print(f"Force errors: Train RMSE: {get_rmse(predicted_train_forces, train_forces)}, Train MAE: {get_mae(predicted_train_forces, train_forces)}, Test RMSE: {get_rmse(predicted_test_forces, test_forces)}, Test MAE: {get_mae(predicted_test_forces, test_forces)}")

_ = model.train_epoch(train_data_loader, force_weight)
_ = train_epoch(model, train_data_loader, force_weight)

#print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
21 changes: 10 additions & 11 deletions examples/power_spectrum.py
@@ -1,7 +1,7 @@
import torch
import numpy as np
from typing import List

from equistore import TensorMap, Labels, TensorBlock
from metatensor.torch import TensorMap, Labels, TensorBlock

class PowerSpectrum(torch.nn.Module):

@@ -10,31 +10,30 @@ def __init__(self, l_max, all_species):

self.l_max = l_max
self.all_species = all_species


def forward(self, spex):
def forward(self, spex: TensorMap):

keys = []
blocks = []
keys : List[List[int]] = []
blocks : List[TensorBlock] = []
for a_i in self.all_species:
ps_values_ai = []
for l in range(self.l_max+1):
cg = 1.0/np.sqrt(2*l+1)
block_ai_l = spex.block(lam=l, a_i=a_i)
cg = (2*l+1)**(-0.5)
block_ai_l = spex.block({"lam": l, "a_i": a_i})
c_ai_l = block_ai_l.values

# same as this:
# ps_ai_l = cg*torch.einsum("ima, imb -> iab", c_ai_l, c_ai_l)
# but faster:
ps_ai_l = cg*torch.sum(c_ai_l.unsqueeze(2)*c_ai_l.unsqueeze(3), dim=1)

ps_ai_l = ps_ai_l.reshape(c_ai_l.shape[0], c_ai_l.shape[2]**2)
ps_ai_l = ps_ai_l.reshape(c_ai_l.shape[0], c_ai_l.shape[2]*c_ai_l.shape[2])
ps_values_ai.append(ps_ai_l)
ps_values_ai = torch.concatenate(ps_values_ai, dim=-1)

block = TensorBlock(
values=ps_values_ai,
samples=block_ai_l.samples,
samples=spex.block({"lam": 0, "a_i": a_i}).samples,
components=[],
properties=Labels.range("property", ps_values_ai.shape[-1])
)
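
The comment above notes that the broadcasted product-and-sum is equivalent to the einsum; a quick sketch checking that equivalence on random data:

import torch

c = torch.randn(7, 5, 3)   # (centers, spherical components, properties), made-up shape
ps_einsum = torch.einsum("ima, imb -> iab", c, c)
ps_broadcast = torch.sum(c.unsqueeze(2) * c.unsqueeze(3), dim=1)
assert torch.allclose(ps_einsum, ps_broadcast)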
@@ -44,7 +43,7 @@ def forward(self, spex):
power_spectrum = TensorMap(
keys = Labels(
names = ("a_i",),
values = np.array(keys), # .reshape((-1, 2)),
values = torch.tensor(keys), # .reshape((-1, 2)),
),
blocks = blocks
)
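
With metatensor.torch, the TensorMap keys are built from a torch tensor rather than a numpy array, which keeps the whole construction scriptable. A minimal sketch of assembling a one-block TensorMap this way, with made-up values:

import torch
from metatensor.torch import TensorMap, TensorBlock, Labels

block = TensorBlock(
    values=torch.randn(2, 4),
    samples=Labels(names=["structure"], values=torch.tensor([[0], [1]])),
    components=[],
    properties=Labels.range("property", 4),
)
tm = TensorMap(keys=Labels(names=["a_i"], values=torch.tensor([[1]])), blocks=[block])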
