diff --git a/examples/LennardJones/LJ_data.py b/examples/LennardJones/LJ_data.py index f8f088747..fd59a1c68 100644 --- a/examples/LennardJones/LJ_data.py +++ b/examples/LennardJones/LJ_data.py @@ -106,7 +106,6 @@ def __init__(self, dirpath, config, dist=False, sampling=None): self.dataset.append(self.transform_input_to_data_object_base(filepath)) def transform_input_to_data_object_base(self, filepath): - # Using readline() file = open(filepath, "r") @@ -174,6 +173,11 @@ def transform_input_to_data_object_base(self, filepath): .unsqueeze(0) .to(torch.float32), energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32), + pbc=[ + True, + True, + True, + ], # LJ example always has periodic boundary conditions ) # Create pbc edges and lengths @@ -205,7 +209,6 @@ def deterministic_graph_data( unit_cell_z_range: list = [3, 4], relative_maximum_atomic_displacement: float = 1e-1, ): - comm = MPI.COMM_WORLD comm_size = comm.Get_size() comm_rank = comm.Get_rank() @@ -330,6 +333,7 @@ def create_configuration( data.supercell_size = torch.diag( torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z]) ) + data.pbc = [True, True, True] create_graph_connectivity_pbc = get_radius_graph_pbc( radius_cutoff, max_num_neighbors @@ -379,27 +383,23 @@ class AtomicStructureHandler: def __init__( self, list_atom_types, bravais_lattice_constants, radius_cutoff, formula ): - self.bravais_lattice_constants = bravais_lattice_constants self.radius_cutoff = radius_cutoff self.formula = formula def compute(self, data): - assert data.pos.shape[0] == data.x.shape[0] interatomic_potential = torch.zeros([data.pos.shape[0], 1]) interatomic_forces = torch.zeros([data.pos.shape[0], 3]) for node_id in range(data.pos.shape[0]): - neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[ 0 ].tolist() neighbor_list = data.edge_index[1, neighbor_list_indices] for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices): - neighbor_pos = data.pos[neighbor_id, :] 
distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :] diff --git a/hydragnn/models/DIMEStack.py b/hydragnn/models/DIMEStack.py index 715dd5656..2670bc0cc 100644 --- a/hydragnn/models/DIMEStack.py +++ b/hydragnn/models/DIMEStack.py @@ -26,6 +26,7 @@ from torch_geometric.utils import scatter from .Base import Base +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths class DIMEStack(Base): @@ -144,23 +145,34 @@ def get_conv(self, input_dim, output_dim): ) def _embedding(self, data): + super()._embedding(data) + assert ( data.pos is not None ), "DimeNet requires node positions (data.pos) to be set." + + # Calculate triplet indices i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( data.edge_index, num_nodes=data.x.size(0) ) - dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt() - # Calculate angles. - pos_i = data.pos[idx_i] - pos_ji, pos_ki = data.pos[idx_j] - pos_i, data.pos[idx_k] - pos_i + # Calculate edge_vec and edge_dist + edge_vec, edge_dist = get_edge_vectors_and_lengths( + data.pos, data.edge_index, data.edge_shifts + ) + + # Calculate angles + pos_ji = edge_vec[idx_ji] + pos_kj = edge_vec[idx_kj] + pos_ki = ( + pos_kj + pos_ji + ) # It's important to calculate the vectors separately and then add in case of periodic boundary conditions a = (pos_ji * pos_ki).sum(dim=-1) b = torch.cross(pos_ji, pos_ki).norm(dim=-1) angle = torch.atan2(b, a) - rbf = self.rbf(dist) - sbf = self.sbf(dist, angle, idx_kj) + rbf = self.rbf(edge_dist.squeeze()) + sbf = self.sbf(edge_dist.squeeze(), angle, idx_kj) conv_args = { "rbf": rbf, diff --git a/hydragnn/models/EGCLStack.py b/hydragnn/models/EGCLStack.py index 3a9a74928..7e93061d6 100644 --- a/hydragnn/models/EGCLStack.py +++ b/hydragnn/models/EGCLStack.py @@ -16,6 +16,7 @@ from .Base import Base from hydragnn.utils.model import unsorted_segment_mean +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths class EGCLStack(Base): @@ -89,6 +90,12 @@ def get_conv(self, 
input_dim, output_dim, last_layer=False): ) def _embedding(self, data): + super()._embedding(data) + + data.edge_shifts = torch.zeros( + (data.edge_index.size(1), 3), device=data.edge_index.device + ) # Override. pbc edge shifts are currently not supported in positional update models + if self.edge_dim > 0: conv_args = { "edge_index": data.edge_index, @@ -229,20 +236,14 @@ def coord_model(self, coord, edge_index, coord_diff, edge_feat): coord = coord + agg * self.coords_weight return coord - def coord2radial(self, edge_index, coord): - row, col = edge_index - coord_diff = coord[row] - coord[col] - radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1) - - if self.norm_diff: - norm = torch.sqrt(radial) + 1 - coord_diff = coord_diff / (norm) - - return radial, coord_diff - def forward(self, x, coord, edge_index, edge_attr, node_attr=None): row, col = edge_index - radial, coord_diff = self.coord2radial(edge_index, coord) + edge_shifts = torch.zeros( + (edge_index.size(1), 3), device=edge_index.device + ) # pbc edge shifts are currently not supported in positional update models + coord_diff, radial = get_edge_vectors_and_lengths( + coord, edge_index, edge_shifts, normalize=self.norm_diff, eps=1.0 + ) # Message Passing edge_feat = self.edge_model(x[row], x[col], radial, edge_attr) if self.equivariant: diff --git a/hydragnn/models/MACEStack.py b/hydragnn/models/MACEStack.py index 8746bc47c..69b3f9440 100644 --- a/hydragnn/models/MACEStack.py +++ b/hydragnn/models/MACEStack.py @@ -64,6 +64,7 @@ NonLinearMultiheadDecoderBlock, LinearMultiheadDecoderBlock, ) +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths # Etc import numpy as np diff --git a/hydragnn/models/PAINNStack.py b/hydragnn/models/PAINNStack.py index 01cb43816..561c40a24 100644 --- a/hydragnn/models/PAINNStack.py +++ b/hydragnn/models/PAINNStack.py @@ -20,6 +20,7 @@ from torch.utils.checkpoint import checkpoint from .Base import Base +from hydragnn.utils.model.operations import 
get_edge_vectors_and_lengths class PAINNStack(Base): @@ -125,15 +126,16 @@ def get_conv(self, input_dim, output_dim, last_layer=False): ) def _embedding(self, data): + super()._embedding(data) + assert ( data.pos is not None - ), "PAINNNet requires node positions (data.pos) to be set." + ), "PAINN requires node positions (data.pos) to be set." - # Calculate relative vectors and distances - i, j = data.edge_index[0], data.edge_index[1] - diff = data.pos[i] - data.pos[j] - dist = diff.pow(2).sum(dim=-1).sqrt() - norm_diff = diff / dist.unsqueeze(-1) + # Get normalized edge vectors and lengths + norm_edge_vec, edge_dist = get_edge_vectors_and_lengths( + data.pos, data.edge_index, data.edge_shifts, normalize=True + ) # Instantiate tensor to hold equivariant traits v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device) @@ -141,8 +143,8 @@ def _embedding(self, data): conv_args = { "edge_index": data.edge_index.t().to(torch.long), - "diff": norm_diff, - "dist": dist, + "diff": norm_edge_vec, + "dist": edge_dist, } return data.x, data.v, conv_args @@ -171,9 +173,8 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist): filter_weight = self.filter_layer( sinc_expansion(edge_dist, self.edge_size, self.cutoff) ) - filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze( - -1 - ) + filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff) + scalar_out = self.scalar_message_mlp(node_scalar) filter_out = filter_weight * scalar_out[edge[:, 1]] @@ -185,9 +186,9 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist): # num_pairs * 3 * node_size, num_pairs * node_size message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1) - edge_vector = gate_edge_vector.unsqueeze(1) * ( - edge_diff / edge_dist.unsqueeze(-1) - ).unsqueeze(-1) + edge_vector = gate_edge_vector.unsqueeze(1) * (edge_diff / edge_dist).unsqueeze( + -1 + ) message_vector = message_vector + edge_vector # sum 
message @@ -266,9 +267,7 @@ def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float): sin(n * pi * d / d_cut) / d """ n = torch.arange(edge_size, device=edge_dist.device) + 1 - return torch.sin( - edge_dist.unsqueeze(-1) * n * torch.pi / cutoff - ) / edge_dist.unsqueeze(-1) + return torch.sin(edge_dist * n * torch.pi / cutoff) / edge_dist def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float): diff --git a/hydragnn/models/PNAEqStack.py b/hydragnn/models/PNAEqStack.py index 22946e140..c20bea646 100644 --- a/hydragnn/models/PNAEqStack.py +++ b/hydragnn/models/PNAEqStack.py @@ -32,7 +32,9 @@ from torch_geometric.nn.aggr.scaler import DegreeScalerAggregation from torch_geometric.typing import Adj, OptTensor +# HydraGNN from .Base import Base +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths class PNAEqStack(Base): @@ -156,16 +158,17 @@ def get_conv(self, input_dim, output_dim, last_layer=False): ) def _embedding(self, data): + super()._embedding(data) + assert ( data.pos is not None ), "PNAEq requires node positions (data.pos) to be set." - # Calculate relative vectors and distances - i, j = data.edge_index[0], data.edge_index[1] - diff = data.pos[i] - data.pos[j] - dist = diff.pow(2).sum(dim=-1).sqrt() - rbf = self.rbf(dist) - norm_diff = diff / dist.unsqueeze(-1) + # Edge vector and distance features + norm_edge_vec, edge_dist = get_edge_vectors_and_lengths( + data.pos, data.edge_index, data.edge_shifts, normalize=True + ) + rbf = self.rbf(edge_dist.squeeze()) # Instantiate tensor to hold equivariant traits v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device) @@ -174,9 +177,15 @@ def _embedding(self, data): conv_args = { "edge_index": data.edge_index.t().to(torch.long), "edge_rbf": rbf, - "edge_vec": norm_diff, + "edge_vec": norm_edge_vec, } + if self.use_edge_attr: + assert ( + data.edge_attr is not None + ), "Data must have edge attributes if use_edge_attributes is set." 
+ conv_args.update({"edge_attr": data.edge_attr}) + return data.x, data.v, conv_args diff --git a/hydragnn/models/PNAPlusStack.py b/hydragnn/models/PNAPlusStack.py index 0a104ba03..5fd2daec3 100644 --- a/hydragnn/models/PNAPlusStack.py +++ b/hydragnn/models/PNAPlusStack.py @@ -34,6 +34,7 @@ # HydraGNN from .Base import Base +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths class PNAPlusStack(Base): @@ -98,14 +99,18 @@ def get_conv(self, input_dim, output_dim): ) def _embedding(self, data): + super()._embedding(data) + assert ( data.pos is not None ), "PNA+ requires node positions (data.pos) to be set." - j, i = data.edge_index # j->i - dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt() - rbf = self.rbf(dist) - # rbf = dist.unsqueeze(-1) + # Radial embedding + _, edge_dist = get_edge_vectors_and_lengths( + data.pos, data.edge_index, data.edge_shifts + ) + rbf = self.rbf(edge_dist.squeeze()) + conv_args = {"edge_index": data.edge_index.to(torch.long), "rbf": rbf} if self.use_edge_attr: diff --git a/hydragnn/models/SCFStack.py b/hydragnn/models/SCFStack.py index c132c8ecd..ab1d347ff 100644 --- a/hydragnn/models/SCFStack.py +++ b/hydragnn/models/SCFStack.py @@ -18,7 +18,6 @@ from torch_geometric.nn import Sequential as PyGSeq from torch_geometric.nn import MessagePassing from torch_geometric.nn.models.schnet import ( - CFConv, GaussianSmearing, RadiusInteractionGraph, ShiftedSoftplus, @@ -27,6 +26,7 @@ from .Base import Base from hydragnn.utils.model import unsorted_segment_mean +from hydragnn.utils.model.operations import get_edge_vectors_and_lengths class SCFStack(Base): @@ -135,12 +135,17 @@ def get_conv(self, input_dim, output_dim, last_layer): ) def _embedding(self, data): + super()._embedding(data) + if (self.use_edge_attr) and (self.equivariance): raise Exception( "For SchNet if using edge attributes, then E(3)-equivariance cannot be ensured. Please disable equivariance or edge attributes." 
) elif self.use_edge_attr: edge_index = data.edge_index + data.edge_shifts = torch.zeros( + (data.edge_index.size(1), 3), device=data.edge_index.device + ) # Override. pbc edge shifts are currently not supported in positional update models edge_weight = data.edge_attr.norm(dim=-1) conv_args = { @@ -218,7 +223,12 @@ def forward( x = self.lin1(x) if self.equivariant: - radial, coord_diff = self.coord2radial(edge_index, pos) + edge_shifts = torch.zeros( + (edge_index.size(1), 3), device=edge_index.device + ) # pbc edge shifts are currently not supported in positional update models + coord_diff, radial = get_edge_vectors_and_lengths( + pos, edge_index, edge_shifts, normalize=True, eps=1.0 + ) pos = self.coord_model(pos, edge_index, coord_diff, W) x = self.propagate(edge_index, x=x, W=W) @@ -230,13 +240,3 @@ def forward( def message(self, x_j: Tensor, W: Tensor) -> Tensor: return x_j * W - - def coord2radial(self, edge_index, coord): - row, col = edge_index - coord_diff = coord[row] - coord[col] - radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1) - - norm = torch.sqrt(radial) + 1 - coord_diff = coord_diff / (norm) - - return radial, coord_diff diff --git a/hydragnn/preprocess/graph_samples_checks_and_updates.py b/hydragnn/preprocess/graph_samples_checks_and_updates.py index b4162d742..c0142f011 100644 --- a/hydragnn/preprocess/graph_samples_checks_and_updates.py +++ b/hydragnn/preprocess/graph_samples_checks_and_updates.py @@ -138,35 +138,59 @@ class RadiusGraphPBC(RadiusGraph): def __call__(self, data): data.edge_attr = None + data.edge_shifts = None assert ( "batch" not in data ), "Periodic boundary conditions not currently supported on batches." assert hasattr( data, "supercell_size" ), "The data must contain the size of the supercell to apply periodic boundary conditions." + assert hasattr( + data, "pbc" + ), "The data must contain data.pbc as a bool (True) or list of bools for the dimensions ([True, False, True]) to apply periodic boundary conditions." 
+ # NOTE Cutoff radius being less than half the smallest supercell dimension is a sufficient, but not necessary condition for no dupe connections. + # However, to prevent an issue from being unobserved until long into an experiment, we assert this condition. + assert ( + self.r < min(torch.diagonal(data.supercell_size)) / 2 + ), "Cutoff radius must be smaller than half the smallest supercell dimension." ase_atom_object = ase.Atoms( positions=data.pos, cell=data.supercell_size, - pbc=True, + pbc=data.pbc, ) - # ‘i’ : first atom index - # ‘j’ : second atom index + # 'i' : first atom index + # 'j' : second atom index + # 'd' : absolute distance + # 'S' : shift vector # https://wiki.fysik.dtu.dk/ase/ase/neighborlist.html#ase.neighborlist.neighbor_list - edge_src, edge_dst, edge_length = ase.neighborlist.neighbor_list( - "ijd", a=ase_atom_object, cutoff=self.r, self_interaction=self.loop + ( + edge_src, + edge_dst, + edge_length, + edge_cell_shifts, + ) = ase.neighborlist.neighbor_list( + "ijdS", a=ase_atom_object, cutoff=self.r, self_interaction=self.loop ) data.edge_index = torch.stack( - [torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0 + [torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], + dim=0, # Shape: [2, n_edges] ) # ensure no duplicate edges - num_edges = data.edge_index.size(1) - data.coalesce() - assert num_edges == data.edge_index.size( + unique_edge_index, unique_indices = torch.unique( + data.edge_index, dim=1, return_inverse=False # Shape: [n_edges] + ) + assert unique_edge_index.unsqueeze(0).size(1) == data.edge_index.size( 1 ), "Adding periodic boundary conditions would result in duplicate edges. Cutoff radius must be reduced or system size increased." - data.edge_attr = torch.tensor(edge_length, dtype=torch.float).unsqueeze(1) + data.edge_attr = torch.tensor(edge_length, dtype=torch.float).unsqueeze( + 1 + ) # Shape: [n_edges, 1] + # ASE returns whether the cell was shifted or not (-1,0,1). 
Multiply by the cell size to get the actual shift + data.edge_shifts = torch.matmul( + torch.tensor(edge_cell_shifts).float(), data.supercell_size.float() + ) # Shape: [n_edges, 3] return data diff --git a/hydragnn/preprocess/serialized_dataset_loader.py b/hydragnn/preprocess/serialized_dataset_loader.py index 3b385f936..f084f909f 100644 --- a/hydragnn/preprocess/serialized_dataset_loader.py +++ b/hydragnn/preprocess/serialized_dataset_loader.py @@ -125,6 +125,8 @@ def load_serialized_data(self, dataset_path: str): dataset[:] = [rotational_invariance(data) for data in dataset] if self.periodic_boundary_conditions: + for data in dataset: + data.pbc = [True, True, True] # edge lengths already added manually if using PBC, so no need to call Distance. compute_edges = get_radius_graph_pbc( radius=self.radius, diff --git a/hydragnn/utils/datasets/abstractrawdataset.py b/hydragnn/utils/datasets/abstractrawdataset.py index d81d4bfd5..e20b166bc 100644 --- a/hydragnn/utils/datasets/abstractrawdataset.py +++ b/hydragnn/utils/datasets/abstractrawdataset.py @@ -339,6 +339,8 @@ def __build_edge(self): self.dataset[:] = [rotational_invariance(data) for data in self.dataset] if self.periodic_boundary_conditions: + for data in self.dataset: + data.pbc = [True, True, True] # edge lengths already added manually if using PBC, so no need to call Distance. compute_edges = get_radius_graph_pbc( radius=self.radius, diff --git a/hydragnn/utils/input_config_parsing/config_utils.py b/hydragnn/utils/input_config_parsing/config_utils.py index a165ee7b7..57eaab8d8 100644 --- a/hydragnn/utils/input_config_parsing/config_utils.py +++ b/hydragnn/utils/input_config_parsing/config_utils.py @@ -140,7 +140,7 @@ def update_config_equivariance(config): if "equivariance" in config and config["equivariance"]: assert ( config["model_type"] in equivariant_models - ), "E(3) equivariance can only be ensured for EGNN, SchNet, and MACE." 
+ ), "E(3) equivariance can only be ensured for EGNN, SchNet, PNAEq, PAINN, and MACE." elif "equivariance" not in config: config["equivariance"] = False return config @@ -148,11 +148,20 @@ def update_config_equivariance(config): def update_config_edge_dim(config): config["edge_dim"] = None - edge_models = ["PNAPlus", "PNA", "CGCNN", "SchNet", "EGNN", "DimeNet", "MACE"] + edge_models = [ + "PNA", + "PNAPlus", + "PNAEq", + "CGCNN", + "SchNet", + "EGNN", + "DimeNet", + "MACE", + ] if "edge_features" in config and config["edge_features"]: assert ( config["model_type"] in edge_models - ), "Edge features can only be used with DimeNet, MACE, EGNN, SchNet, PNA, PNAPlus, and CGCNN." + ), "Edge features can only be used with DimeNet, MACE, EGNN, SchNet, PNA, PNAPlus, PNAEq, and CGCNN." config["edge_dim"] = len(config["edge_features"]) elif config["model_type"] == "CGCNN": # CG always needs an integer edge_dim diff --git a/hydragnn/utils/model/operations.py b/hydragnn/utils/model/operations.py index 5101f70ac..dc23d02c4 100644 --- a/hydragnn/utils/model/operations.py +++ b/hydragnn/utils/model/operations.py @@ -17,6 +17,7 @@ # Function for the computation of edge vectors and lengths (MIT License (see MIT.md)) # Authors: Ilyes Batatia, Gregor Simm and David Kovacs ########################################################################################### +# NOTE Shifts are applied to take into account periodic boundary conditions. If there are no PBCs, the shifts are 0. 
def get_edge_vectors_and_lengths( positions: torch.Tensor, # [n_nodes, 3] edge_index: torch.Tensor, # [2, n_edges] diff --git a/tests/test_periodic_boundary_conditions.py b/tests/test_periodic_boundary_conditions.py index a81e0b9f6..570736210 100644 --- a/tests/test_periodic_boundary_conditions.py +++ b/tests/test_periodic_boundary_conditions.py @@ -84,6 +84,7 @@ def pytest_periodic_h2(): data.supercell_size = torch.tensor( [[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]] ) + data.pbc = [True, True, True] data.atom_types = [1, 1] data.pos = torch.tensor([[1.0, 1.0, 1.0], [1.43, 1.43, 1.43]]) data.x = torch.tensor([[3, 5, 7], [9, 11, 13]]) @@ -108,6 +109,7 @@ def pytest_periodic_bcc_large(): # Convert to PyG data2 = Data() data2.supercell_size = torch.tensor(supercell.cell[:]) + data2.pbc = [True, True, True] data2.atom_types = np.ones(len(supercell)) * 27 data2.pos = torch.tensor(supercell.positions) data2.x = torch.randn(data2.pos.size(0), 1)