Skip to content

Commit

Permalink
Add edge_attr into PAINN (#314)
Browse files Browse the repository at this point in the history
* Add edge_attr into PAINN

* fix typo
  • Loading branch information
RylieWeaver authored Dec 4, 2024
1 parent 93e9ad7 commit 18251d0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 15 deletions.
52 changes: 40 additions & 12 deletions hydragnn/models/PAINNStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch import nn
from torch_geometric import nn as geom_nn
from torch.utils.checkpoint import checkpoint
from torch_geometric.typing import OptTensor

from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths
Expand All @@ -31,15 +32,15 @@ class PAINNStack(Base):

def __init__(
self,
# edge_dim: int, # To-Do: Add edge_features
input_args,
conv_args,
edge_dim: int,
num_radial: int,
radius: float,
*args,
**kwargs
):
# self.edge_dim = edge_dim
self.edge_dim = edge_dim
self.num_radial = num_radial
self.radius = radius

Expand All @@ -61,7 +62,10 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
hidden_dim > 1
), "PainnNet requires more than one hidden dimension between input_dim and output_dim."
self_inter = PainnMessage(
node_size=input_dim, edge_size=self.num_radial, cutoff=self.radius
node_size=input_dim,
num_radial=self.num_radial,
cutoff=self.radius,
edge_dim=self.edge_dim,
)
cross_inter = PainnUpdate(node_size=input_dim, last_layer=last_layer)
"""
Expand All @@ -82,7 +86,7 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
[
(
self_inter,
"inv_node_feat, equiv_node_feat, edge_index, diff, dist -> inv_node_feat, equiv_node_feat",
self.conv_args + " -> inv_node_feat, equiv_node_feat",
),
(
cross_inter,
Expand All @@ -105,7 +109,7 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
[
(
self_inter,
"inv_node_feat, equiv_node_feat, edge_index, diff, dist -> inv_node_feat, equiv_node_feat",
self.conv_args + " -> inv_node_feat, equiv_node_feat",
),
(
cross_inter,
Expand Down Expand Up @@ -147,33 +151,57 @@ def _embedding(self, data):
"dist": edge_dist,
}

if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return data.x, data.v, conv_args


class PainnMessage(nn.Module):
"""Message function"""

def __init__(self, node_size: int, edge_size: int, cutoff: float):
def __init__(self, node_size: int, num_radial: int, cutoff: float, edge_dim: int):
    """Build the PAINN message block.

    Args:
        node_size: Width of the invariant (scalar) node features.
        num_radial: Number of sinc radial basis functions used to expand
            edge distances.
        cutoff: Radial cutoff distance for the cosine envelope.
        edge_dim: Dimensionality of the optional edge attributes; ``None``
            disables the edge-attribute filter branch.
    """
    super().__init__()

    self.node_size = node_size
    self.num_radial = num_radial
    self.cutoff = cutoff
    self.edge_dim = edge_dim

    # Produces the three scalar message channels (split downstream in forward).
    self.scalar_message_mlp = nn.Sequential(
        nn.Linear(node_size, node_size),
        nn.SiLU(),
        nn.Linear(node_size, node_size * 3),
    )

    # Maps the radial basis expansion of edge distances to the three
    # filter channels that gate the scalar messages.
    self.filter_layer = nn.Linear(num_radial, node_size * 3)

    # Optional learned filter over edge attributes, multiplied into the
    # radial filter when edge_attr is supplied to forward().
    if self.edge_dim is not None:
        self.edge_filter = nn.Sequential(
            nn.Linear(self.edge_dim, node_size),
            nn.SiLU(),
            nn.Linear(node_size, node_size * 3),
        )

def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
def forward(
self,
node_scalar,
node_vector,
edge,
edge_diff,
edge_dist,
edge_attr: OptTensor = None,
):
# remember to use v_j, s_j but not v_i, s_i
filter_weight = self.filter_layer(
sinc_expansion(edge_dist, self.edge_size, self.cutoff)
sinc_expansion(edge_dist, self.num_radial, self.cutoff)
)
filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff)
if edge_attr is not None:
filter_weight = filter_weight * self.edge_filter(edge_attr)

scalar_out = self.scalar_message_mlp(node_scalar)
filter_out = filter_weight * scalar_out[edge[:, 1]]
Expand Down Expand Up @@ -260,13 +288,13 @@ def forward(self, node_scalar, node_vector):
return node_scalar + delta_s


def sinc_expansion(edge_dist: torch.Tensor, num_radial: int, cutoff: float):
    """Expand edge distances in a sinc-like radial basis.

    Computes ``sin(n * pi * d / cutoff) / d`` for n = 1..num_radial,
    broadcasting the basis index against ``edge_dist``.

    Args:
        edge_dist: Tensor of edge lengths; a trailing singleton dimension
            lets the basis index broadcast (e.g. shape ``(E, 1)``).
        num_radial: Number of basis functions.
        cutoff: Radial cutoff distance ``d_cut``.

    Returns:
        Tensor whose trailing dimension has size ``num_radial``.
    """
    # n = 1..num_radial on the same device as the input distances.
    n = torch.arange(num_radial, device=edge_dist.device) + 1
    return torch.sin(edge_dist * n * torch.pi / cutoff) / edge_dist


Expand Down
4 changes: 2 additions & 2 deletions hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ def create_model(

elif model_type == "PAINN":
model = PAINNStack(
# edge_dim, # To-do add edge_features
"inv_node_feat, equiv_node_feat, edge_index, diff, dist",
"",
"inv_node_feat, equiv_node_feat, edge_index, diff, dist",
edge_dim,
num_radial,
radius,
input_dim,
Expand Down
3 changes: 2 additions & 1 deletion hydragnn/utils/input_config_parsing/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def update_config_edge_dim(config):
"GAT",
"PNA",
"PNAPlus",
"PAINN",
"PNAEq",
"CGCNN",
"SchNet",
Expand All @@ -162,7 +163,7 @@ def update_config_edge_dim(config):
if "edge_features" in config and config["edge_features"]:
assert (
config["model_type"] in edge_models
), "Edge features can only be used with GAT, PNA, PNAPlus, PNAEq, CGCNN, SchNet, EGNN, DimeNet, MACE"
), "Edge features can only be used with GAT, PNA, PNAPlus, PAINN, PNAEq, CGCNN, SchNet, EGNN, DimeNet, MACE"
config["edge_dim"] = len(config["edge_features"])
elif config["model_type"] == "CGCNN":
# CG always needs an integer edge_dim
Expand Down

0 comments on commit 18251d0

Please sign in to comment.