Skip to content

Commit

Permalink
Fix pbc conv args Final Draft (#306)
Browse files Browse the repository at this point in the history
* rebase to recent Justin changes

* revise implementation and adjust Lennard Jones to use pbc correctly

* Add data.pbc to all examples and dataset creations

* polish naming and adjust EGCL stack to use pbc in all operations

* Revise scf stack with vector encoding and pos update

* refactor accounting for parent _embedding() override and which models handle edge_attr

* polish formatting and add super()._embedding()

* revising shift zeros creation

* typos

* require data.pbc

* typo data.pbc

* typo

* fix assert typo

* change open catalyst example to not have data.pbs assignment

* slight comment change
  • Loading branch information
RylieWeaver authored Nov 9, 2024
1 parent c43a294 commit 98e8bc3
Show file tree
Hide file tree
Showing 14 changed files with 144 additions and 77 deletions.
12 changes: 6 additions & 6 deletions examples/LennardJones/LJ_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def __init__(self, dirpath, config, dist=False, sampling=None):
self.dataset.append(self.transform_input_to_data_object_base(filepath))

def transform_input_to_data_object_base(self, filepath):

# Using readline()
file = open(filepath, "r")

Expand Down Expand Up @@ -174,6 +173,11 @@ def transform_input_to_data_object_base(self, filepath):
.unsqueeze(0)
.to(torch.float32),
energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32),
pbc=[
True,
True,
True,
], # LJ example always has periodic boundary conditions
)

# Create pbc edges and lengths
Expand Down Expand Up @@ -205,7 +209,6 @@ def deterministic_graph_data(
unit_cell_z_range: list = [3, 4],
relative_maximum_atomic_displacement: float = 1e-1,
):

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
comm_rank = comm.Get_rank()
Expand Down Expand Up @@ -330,6 +333,7 @@ def create_configuration(
data.supercell_size = torch.diag(
torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z])
)
data.pbc = [True, True, True]

create_graph_connectivity_pbc = get_radius_graph_pbc(
radius_cutoff, max_num_neighbors
Expand Down Expand Up @@ -379,27 +383,23 @@ class AtomicStructureHandler:
def __init__(
self, list_atom_types, bravais_lattice_constants, radius_cutoff, formula
):

self.bravais_lattice_constants = bravais_lattice_constants
self.radius_cutoff = radius_cutoff
self.formula = formula

def compute(self, data):

assert data.pos.shape[0] == data.x.shape[0]

interatomic_potential = torch.zeros([data.pos.shape[0], 1])
interatomic_forces = torch.zeros([data.pos.shape[0], 3])

for node_id in range(data.pos.shape[0]):

neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[
0
].tolist()
neighbor_list = data.edge_index[1, neighbor_list_indices]

for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices):

neighbor_pos = data.pos[neighbor_id, :]
distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :]

Expand Down
24 changes: 18 additions & 6 deletions hydragnn/models/DIMEStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch_geometric.utils import scatter

from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class DIMEStack(Base):
Expand Down Expand Up @@ -144,23 +145,34 @@ def get_conv(self, input_dim, output_dim):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "DimeNet requires node positions (data.pos) to be set."

# Calculate triplet indices
i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
data.edge_index, num_nodes=data.x.size(0)
)
dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()

# Calculate angles.
pos_i = data.pos[idx_i]
pos_ji, pos_ki = data.pos[idx_j] - pos_i, data.pos[idx_k] - pos_i
# Calculate edge_vec and edge_dist
edge_vec, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts
)

# Calculate angles
pos_ji = edge_vec[idx_ji]
pos_kj = edge_vec[idx_kj]
pos_ki = (
pos_kj + pos_ji
) # It's important to calculate the vectors separately and then add in case of periodic boundary conditions
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
sbf = self.sbf(dist, angle, idx_kj)
rbf = self.rbf(edge_dist.squeeze())
sbf = self.sbf(edge_dist.squeeze(), angle, idx_kj)

conv_args = {
"rbf": rbf,
Expand Down
25 changes: 13 additions & 12 deletions hydragnn/models/EGCLStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .Base import Base

from hydragnn.utils.model import unsorted_segment_mean
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class EGCLStack(Base):
Expand Down Expand Up @@ -89,6 +90,12 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)

def _embedding(self, data):
super()._embedding(data)

data.edge_shifts = torch.zeros(
(data.edge_index.size(1), 3), device=data.edge_index.device
) # Override. pbc edge shifts are currently not supported in positional update models

if self.edge_dim > 0:
conv_args = {
"edge_index": data.edge_index,
Expand Down Expand Up @@ -229,20 +236,14 @@ def coord_model(self, coord, edge_index, coord_diff, edge_feat):
coord = coord + agg * self.coords_weight
return coord

def coord2radial(self, edge_index, coord):
row, col = edge_index
coord_diff = coord[row] - coord[col]
radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)

if self.norm_diff:
norm = torch.sqrt(radial) + 1
coord_diff = coord_diff / (norm)

return radial, coord_diff

def forward(self, x, coord, edge_index, edge_attr, node_attr=None):
row, col = edge_index
radial, coord_diff = self.coord2radial(edge_index, coord)
edge_shifts = torch.zeros(
(edge_index.size(1), 3), device=edge_index.device
) # pbc edge shifts are currently not supported in positional update models
coord_diff, radial = get_edge_vectors_and_lengths(
coord, edge_index, edge_shifts, normalize=self.norm_diff, eps=1.0
)
# Message Passing
edge_feat = self.edge_model(x[row], x[col], radial, edge_attr)
if self.equivariant:
Expand Down
1 change: 1 addition & 0 deletions hydragnn/models/MACEStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
NonLinearMultiheadDecoderBlock,
LinearMultiheadDecoderBlock,
)
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths

# Etc
import numpy as np
Expand Down
33 changes: 16 additions & 17 deletions hydragnn/models/PAINNStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.utils.checkpoint import checkpoint

from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class PAINNStack(Base):
Expand Down Expand Up @@ -125,24 +126,25 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "PAINNNet requires node positions (data.pos) to be set."
), "PAINN requires node positions (data.pos) to be set."

# Calculate relative vectors and distances
i, j = data.edge_index[0], data.edge_index[1]
diff = data.pos[i] - data.pos[j]
dist = diff.pow(2).sum(dim=-1).sqrt()
norm_diff = diff / dist.unsqueeze(-1)
# Get normalized edge vectors and lengths
norm_edge_vec, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts, normalize=True
)

# Instantiate tensor to hold equivariant traits
v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
data.v = v

conv_args = {
"edge_index": data.edge_index.t().to(torch.long),
"diff": norm_diff,
"dist": dist,
"diff": norm_edge_vec,
"dist": edge_dist,
}

return data.x, data.v, conv_args
Expand Down Expand Up @@ -171,9 +173,8 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
filter_weight = self.filter_layer(
sinc_expansion(edge_dist, self.edge_size, self.cutoff)
)
filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze(
-1
)
filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff)

scalar_out = self.scalar_message_mlp(node_scalar)
filter_out = filter_weight * scalar_out[edge[:, 1]]

Expand All @@ -185,9 +186,9 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):

# num_pairs * 3 * node_size, num_pairs * node_size
message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1)
edge_vector = gate_edge_vector.unsqueeze(1) * (
edge_diff / edge_dist.unsqueeze(-1)
).unsqueeze(-1)
edge_vector = gate_edge_vector.unsqueeze(1) * (edge_diff / edge_dist).unsqueeze(
-1
)
message_vector = message_vector + edge_vector

# sum message
Expand Down Expand Up @@ -266,9 +267,7 @@ def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float):
sin(n * pi * d / d_cut) / d
"""
n = torch.arange(edge_size, device=edge_dist.device) + 1
return torch.sin(
edge_dist.unsqueeze(-1) * n * torch.pi / cutoff
) / edge_dist.unsqueeze(-1)
return torch.sin(edge_dist * n * torch.pi / cutoff) / edge_dist


def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float):
Expand Down
23 changes: 16 additions & 7 deletions hydragnn/models/PNAEqStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from torch_geometric.nn.aggr.scaler import DegreeScalerAggregation
from torch_geometric.typing import Adj, OptTensor

# HydraGNN
from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class PNAEqStack(Base):
Expand Down Expand Up @@ -156,16 +158,17 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "PNAEq requires node positions (data.pos) to be set."

# Calculate relative vectors and distances
i, j = data.edge_index[0], data.edge_index[1]
diff = data.pos[i] - data.pos[j]
dist = diff.pow(2).sum(dim=-1).sqrt()
rbf = self.rbf(dist)
norm_diff = diff / dist.unsqueeze(-1)
# Edge vector and distance features
norm_edge_vec, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts, normalize=True
)
rbf = self.rbf(edge_dist.squeeze())

# Instantiate tensor to hold equivariant traits
v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
Expand All @@ -174,9 +177,15 @@ def _embedding(self, data):
conv_args = {
"edge_index": data.edge_index.t().to(torch.long),
"edge_rbf": rbf,
"edge_vec": norm_diff,
"edge_vec": norm_edge_vec,
}

if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return data.x, data.v, conv_args


Expand Down
13 changes: 9 additions & 4 deletions hydragnn/models/PNAPlusStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# HydraGNN
from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class PNAPlusStack(Base):
Expand Down Expand Up @@ -98,14 +99,18 @@ def get_conv(self, input_dim, output_dim):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "PNA+ requires node positions (data.pos) to be set."

j, i = data.edge_index # j->i
dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()
rbf = self.rbf(dist)
# rbf = dist.unsqueeze(-1)
# Radial embedding
_, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts
)
rbf = self.rbf(edge_dist.squeeze())

conv_args = {"edge_index": data.edge_index.to(torch.long), "rbf": rbf}

if self.use_edge_attr:
Expand Down
24 changes: 12 additions & 12 deletions hydragnn/models/SCFStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from torch_geometric.nn import Sequential as PyGSeq
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.models.schnet import (
CFConv,
GaussianSmearing,
RadiusInteractionGraph,
ShiftedSoftplus,
Expand All @@ -27,6 +26,7 @@
from .Base import Base

from hydragnn.utils.model import unsorted_segment_mean
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class SCFStack(Base):
Expand Down Expand Up @@ -135,12 +135,17 @@ def get_conv(self, input_dim, output_dim, last_layer):
)

def _embedding(self, data):
super()._embedding(data)

if (self.use_edge_attr) and (self.equivariance):
raise Exception(
"For SchNet if using edge attributes, then E(3)-equivariance cannot be ensured. Please disable equivariance or edge attributes."
)
elif self.use_edge_attr:
edge_index = data.edge_index
data.edge_shifts = torch.zeros(
(data.edge_index.size(1), 3), device=data.edge_index.device
) # Override. pbc edge shifts are currently not supported in positional update models
edge_weight = data.edge_attr.norm(dim=-1)

conv_args = {
Expand Down Expand Up @@ -218,7 +223,12 @@ def forward(
x = self.lin1(x)

if self.equivariant:
radial, coord_diff = self.coord2radial(edge_index, pos)
edge_shifts = torch.zeros(
(edge_index.size(1), 3), device=edge_index.device
) # pbc edge shifts are currently not supported in positional update models
coord_diff, radial = get_edge_vectors_and_lengths(
pos, edge_index, edge_shifts, normalize=True, eps=1.0
)
pos = self.coord_model(pos, edge_index, coord_diff, W)

x = self.propagate(edge_index, x=x, W=W)
Expand All @@ -230,13 +240,3 @@ def forward(

def message(self, x_j: Tensor, W: Tensor) -> Tensor:
return x_j * W

def coord2radial(self, edge_index, coord):
row, col = edge_index
coord_diff = coord[row] - coord[col]
radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)

norm = torch.sqrt(radial) + 1
coord_diff = coord_diff / (norm)

return radial, coord_diff
Loading

0 comments on commit 98e8bc3

Please sign in to comment.