Incorporating Equivariance #194

Merged
merged 3 commits on Oct 31, 2023
12 changes: 9 additions & 3 deletions hydragnn/models/Base.py
@@ -28,6 +28,7 @@ def __init__(
output_type: list,
config_heads: dict,
loss_function_type: str,
equivariance: bool,
ilossweights_hyperp: int = 1,  # if =1, consider weighted losses for the different tasks and treat the weights as hyperparameters
loss_weights: list = [1.0, 1.0, 1.0],  # weights for the losses of the different tasks
ilossweights_nll: int = 0,  # if =1, use the scalar uncertainty as weights, as in the paper https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf
@@ -58,6 +59,7 @@ def __init__(
self.batch_norms_node_hidden = ModuleList()
self.convs_node_output = ModuleList()
self.batch_norms_node_output = ModuleList()
self.equivariance = equivariance

self.loss_function = loss_function_selection(loss_function_type)
self.ilossweights_nll = ilossweights_nll
@@ -109,8 +111,11 @@ def _init_conv(self):
self.feature_layers.append(BatchNorm(self.hidden_dim))

def _conv_args(self, data):
conv_args = {"edge_index": data.edge_index}
if (data.edge_attr is not None) and (self.use_edge_attr):
conv_args = {"edge_index": data.edge_index.to(torch.long)}
if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})
return conv_args

@@ -243,11 +248,12 @@ def _multihead(self):

def forward(self, data):
x = data.x
pos = data.pos

### encoder part ####
conv_args = self._conv_args(data)
for conv, feat_layer in zip(self.graph_convs, self.feature_layers):
c = conv(x=x, **conv_args)
c, pos = conv(x=x, pos=pos, **conv_args)
x = F.relu(feat_layer(c))

#### multi-head decoder part####
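The Base.py hunks above change the encoder loop so that each stacked convolution is called with, and returns, both node features and positions. A minimal sketch of that calling pattern, assuming each element of `graph_convs` is one of the `Sequential` wrappers built in the files below (`encoder_forward` is an illustrative name, not a function in the repo):

```python
import torch.nn.functional as F

def encoder_forward(graph_convs, feature_layers, x, pos, conv_args):
    # conv_args holds edge_index (cast to torch.long) and, when use_edge_attr is set, edge_attr
    for conv, feat_layer in zip(graph_convs, feature_layers):
        c, pos = conv(x=x, pos=pos, **conv_args)  # every wrapped conv now returns (features, positions)
        x = F.relu(feat_layer(c))
    return x, pos
```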
19 changes: 17 additions & 2 deletions hydragnn/models/CGCNNStack.py
@@ -12,7 +12,7 @@
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_geometric.nn import CGConv, BatchNorm, global_mean_pool
from torch_geometric.nn import CGConv, BatchNorm, global_mean_pool, Sequential
from .Base import Base


@@ -40,14 +40,29 @@ def __init__(
)

def get_conv(self, input_dim, _):
return CGConv(
cgcnn = CGConv(
channels=input_dim,
dim=self.edge_dim,
aggr="add",
batch_norm=False,
bias=True,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
[
(cgcnn, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
],
)

def _init_node_conv(self):
"""It overwrites _init_node_conv() in Base since purely convolutional layers in _init_node_conv() is not implemented yet.
Here it serves as a temporary placeholder. Purely cgcnn conv is not feasible for node feature predictions with
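The wrapper returned by the new get_conv can be exercised in isolation. The snippet below is a hedged usage sketch (the tensor sizes and toy graph are made up) of how torch_geometric.nn.Sequential routes the named inputs: the CGConv only sees x and edge_index, while the trailing lambda forwards pos untouched so the caller can unpack (x, pos).

```python
import torch
from torch_geometric.nn import CGConv, Sequential

cgcnn = CGConv(channels=16, dim=0, aggr="add", batch_norm=False, bias=True)

conv = Sequential(
    "x, pos, edge_index",
    [
        (cgcnn, "x, edge_index -> x"),                   # message passing on node features only
        (lambda x, pos: [x, pos], "x, pos -> x, pos"),   # positions pass through untouched
    ],
)

x = torch.randn(4, 16)                                   # 4 nodes, 16 features
pos = torch.randn(4, 3)                                  # 3D coordinates
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
x_out, pos_out = conv(x, pos, edge_index)                # pos_out is the same tensor as pos
```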
3 changes: 2 additions & 1 deletion hydragnn/models/DIMEStack.py
@@ -104,12 +104,13 @@ def get_conv(self, input_dim, output_dim):
act=SiLU(),
)
return Sequential(
"x, rbf, sbf, i, j, idx_kj, idx_ji",
"x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
[
(lin, "x -> x"),
(emb, "x, rbf, i, j -> x1"),
(inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
(dec, "x2, rbf, i -> c"),
(lambda x, pos: [x, pos], "c, pos -> c, pos"),
],
)

63 changes: 44 additions & 19 deletions hydragnn/models/EGCLStack.py
@@ -31,31 +31,52 @@ def __init__(
super().__init__(*args, **kwargs)
pass

def get_conv(self, input_dim, output_dim):
def _init_conv(self):
last_layer = 1 == self.num_conv_layers
self.graph_convs.append(
self.get_conv(self.input_dim, self.hidden_dim, last_layer)
)
self.feature_layers.append(nn.Identity())
for i in range(self.num_conv_layers - 1):
last_layer = i == self.num_conv_layers - 2
conv = self.get_conv(self.hidden_dim, self.hidden_dim, last_layer)
self.graph_convs.append(conv)
self.feature_layers.append(nn.Identity())

def get_conv(self, input_dim, output_dim, last_layer=False):
egcl = E_GCL(
input_channels=input_dim,
output_channels=output_dim,
hidden_channels=self.hidden_dim,
edge_attr_dim=self.edge_dim,
equivariant=self.equivariance and not last_layer,
)
return Sequential(
"x, edge_index, coord, edge_attr",
[
(egcl, "x, edge_index, coord, edge_attr -> x"),
],
)

if self.equivariance and not last_layer:
return Sequential(
"x, pos, edge_index, edge_attr",
[
(egcl, "x, pos, edge_index, edge_attr -> x, pos"),
],
)
else:
return Sequential(
"x, pos, edge_index, edge_attr",
[
(egcl, "x, pos, edge_index, edge_attr -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
],
)

def _conv_args(self, data):
if self.edge_dim > 0:
conv_args = {
"edge_index": data.edge_index,
"coord": data.pos,
"edge_attr": data.edge_attr,
}
else:
conv_args = {
"edge_index": data.edge_index,
"coord": data.pos,
"edge_attr": None,
}

@@ -105,7 +126,7 @@ def __init__(
clamp=False,
norm_diff=True,
tanh=True,
coord_mlp=False,
equivariant=False,
) -> None:
super(E_GCL, self).__init__()
input_edge = input_channels * 2
@@ -115,7 +136,7 @@ def __init__(
self.norm_diff = norm_diff
self.tanh = tanh
edge_coords_nf = 1
self.coord_mlp = coord_mlp
self.equivariant = equivariant
self.edge_attr_dim = edge_attr_dim

self.edge_mlp = nn.Sequential(
@@ -133,12 +154,13 @@ def __init__(
nn.Linear(hidden_channels, output_channels),
)

layer = nn.Linear(hidden_channels, 1, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

self.clamp = clamp

if self.coord_mlp:
if self.equivariant:

layer = nn.Linear(hidden_channels, 1, bias=False)
torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

coord_mlp = []
coord_mlp.append(nn.Linear(hidden_channels, hidden_channels))
coord_mlp.append(act_fn)
@@ -184,7 +206,7 @@ def coord_model(self, coord, edge_index, coord_diff, edge_feat):
trans, min=-100, max=100
) # This is never activated, but just in case it explodes it may save the training
agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
coord += agg * self.coords_weight
coord = coord + agg * self.coords_weight
return coord

def coord2radial(self, edge_index, coord):
@@ -198,15 +220,18 @@ def coord2radial(self, edge_index, coord):

return radial, coord_diff

def forward(self, x, edge_index, coord, edge_attr, node_attr=None):
def forward(self, x, coord, edge_index, edge_attr, node_attr=None):
row, col = edge_index
radial, coord_diff = self.coord2radial(edge_index, coord)
# Message Passing
edge_feat = self.edge_model(x[row], x[col], radial, edge_attr)
if self.coord_mlp:
if self.equivariant:
coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
x, agg = self.node_model(x, edge_index, edge_feat, node_attr)
return x # , coord, edge_attr
if self.equivariant:
return x, coord
else:
return x


def unsorted_segment_sum(data, segment_ids, num_segments):
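For clarity on what `equivariant=True` adds, here is a self-contained sketch of the coordinate update performed by `E_GCL.coord_model`: each edge proposes a translation along its (normalized) difference vector, the translations are mean-aggregated per source node, and the coordinates are updated out of place, matching the `coord = coord + agg * self.coords_weight` fix above. The segment-mean helper is re-implemented inline with `index_add_` for illustration, and `coord_mlp` / `coords_weight` are passed explicitly rather than read from the module; this is an assumption-laden illustration, not the module itself.

```python
import torch

def coord_model_sketch(coord, edge_index, coord_diff, edge_feat, coord_mlp, coords_weight=1.0):
    row, _ = edge_index
    # per-edge translation along the (normalized) coordinate difference
    trans = coord_diff * coord_mlp(edge_feat)
    trans = torch.clamp(trans, min=-100, max=100)  # safety clamp, as in the original coord_model
    # mean-aggregate the edge translations onto their source nodes
    agg = torch.zeros_like(coord).index_add_(0, row, trans)
    count = torch.zeros(coord.size(0), 1, device=coord.device).index_add_(
        0, row, torch.ones(row.size(0), 1, device=coord.device)
    )
    agg = agg / count.clamp(min=1)
    # out-of-place update keeps autograd and input tensors intact
    return coord + agg * coords_weight
```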
21 changes: 18 additions & 3 deletions hydragnn/models/GATStack.py
@@ -12,8 +12,8 @@
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from torch.nn import Sequential, ReLU, Linear
from torch_geometric.nn import GATv2Conv, BatchNorm
from torch.nn import ReLU, Linear
from torch_geometric.nn import GATv2Conv, BatchNorm, Sequential

from .Base import Base

@@ -89,7 +89,7 @@ def _init_node_conv(self):
self.batch_norms_node_output.append(BatchNorm(self.head_dims[ihead]))

def get_conv(self, input_dim, output_dim, concat):
return GATv2Conv(
gat = GATv2Conv(
in_channels=input_dim,
out_channels=output_dim,
heads=self.heads,
@@ -99,5 +99,20 @@ def get_conv(self, input_dim, output_dim, concat):
concat=concat,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
[
(gat, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
],
)

def __str__(self):
return "GATStack"
15 changes: 13 additions & 2 deletions hydragnn/models/GINStack.py
@@ -13,7 +13,7 @@
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import ModuleList
from torch_geometric.nn import GINConv, BatchNorm
from torch_geometric.nn import GINConv, BatchNorm, Sequential

from .Base import Base

@@ -23,7 +23,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_conv(self, input_dim, output_dim):
return GINConv(
gin = GINConv(
nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.ReLU(),
@@ -33,5 +33,16 @@ def get_conv(self, input_dim, output_dim):
train_eps=True,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

return Sequential(
input_args,
[
(gin, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
],
)

def __str__(self):
return "GINStack"
17 changes: 14 additions & 3 deletions hydragnn/models/MFCStack.py
@@ -12,8 +12,8 @@
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from torch.nn import Sequential, ReLU, Linear
from torch_geometric.nn import MFConv, BatchNorm, global_mean_pool
from torch.nn import ReLU, Linear
from torch_geometric.nn import MFConv, BatchNorm, global_mean_pool, Sequential

from .Base import Base

@@ -30,11 +30,22 @@ def __init__(
super().__init__(*args, **kwargs)

def get_conv(self, input_dim, output_dim):
return MFConv(
mfc = MFConv(
in_channels=input_dim,
out_channels=output_dim,
max_degree=self.max_degree,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

return Sequential(
input_args,
[
(mfc, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
],
)

def __str__(self):
return "MFCStack"
19 changes: 17 additions & 2 deletions hydragnn/models/PNAStack.py
@@ -12,7 +12,7 @@
import torch
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_geometric.nn import PNAConv, BatchNorm, global_mean_pool
from torch_geometric.nn import PNAConv, BatchNorm, global_mean_pool, Sequential
from .Base import Base


@@ -38,7 +38,7 @@ def __init__(
super().__init__(*args, **kwargs)

def get_conv(self, input_dim, output_dim):
return PNAConv(
pna = PNAConv(
in_channels=input_dim,
out_channels=output_dim,
aggregators=self.aggregators,
@@ -50,5 +50,20 @@ def get_conv(self, input_dim, output_dim):
divide_input=False,
)

input_args = "x, pos, edge_index"
conv_args = "x, edge_index"

if self.use_edge_attr:
input_args += ", edge_attr"
conv_args += ", edge_attr"

return Sequential(
input_args,
[
(pna, conv_args + " -> x"),
(lambda x, pos: [x, pos], "x, pos -> x, pos"),
],
)

def __str__(self):
return "PNAStack"