"""
DirSNN and DirSNNClassifier implementation
"""
import torch
from layers import DirSNNLayer
class DirSNNClassifier(torch.nn.Module):
    """Readout classifier on top of DirSNN.

    Parameters
    ----------
    n_classes : int
        Number of classes for the classification task.
    edge_channels : int
        Dimension of the input edge features.
    n_layers : int
        Number of DirSNN layers.
    n_hid_conv : int
        Dimension of the hidden edge features.
    n_hid_mlp : int
        Hidden dimension of the readout MLP.
    conv_order : int
        Order of the convolutions; the same order is used for all convolutions.
    n_adjs : int
        Number of adjacency matrices used.
    aggr_norm : bool
        Whether to normalize the aggregation.
    update_func : str
        Update function for the simplicial complex convolution.
    """

    def __init__(self, n_classes, edge_channels, n_layers, n_hid_conv=32,
                 n_hid_mlp=32, conv_order=1, n_adjs=1, aggr_norm=False,
                 update_func="relu"):
        super().__init__()
        self.n_classes = n_classes
        self.n_hid_conv = n_hid_conv
        self.n_hid_mlp = n_hid_mlp
        self.scconv = DirSNN(edge_channels, n_layers=n_layers, n_hid=n_hid_conv,
                             conv_order=conv_order, n_adjs=n_adjs,
                             aggr_norm=aggr_norm, update_func=update_func)
        self.readout = torch.nn.Sequential(
            torch.nn.Linear(self.n_hid_conv, self.n_hid_mlp),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(self.n_hid_mlp, self.n_classes),
        )
        self.readout.apply(self.weights_init)
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def weights_init(self, m):
        """Initialize linear layers with Xavier-uniform weights and zero biases."""
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            torch.nn.init.zeros_(m.bias)

    def forward(self, x_1, adjs):
        """Compute class log-likelihoods from edge features.

        Parameters
        ----------
        x_1 : torch.Tensor, shape = (batch_size, n_edges, edge_channels)
            Input edge features.
        adjs : tuple of torch.Tensor
            Adjacency tensors, each of shape (n_edges, n_edges).

        Returns
        -------
        torch.Tensor, shape = (batch_size, n_classes)
            Class log-likelihoods.
        """
        x_1 = self.scconv(x_1, adjs)
        # Max-pool over the edge dimension, then apply the readout MLP
        # and a log-softmax to obtain class log-likelihoods.
        x = self.readout(x_1.max(1)[0])
        return self.log_softmax(x)

class DirSNN(torch.nn.Module):
    """DirSNN implementation [1].

    Parameters
    ----------
    edge_channels : int
        Dimension of the input edge features.
    n_layers : int
        Number of layers.
    n_hid : int
        Dimension of the hidden edge features.
    conv_order : int
        Order of the convolutions; the same order is used for all convolutions.
    n_adjs : int
        Number of adjacency matrices used.
    aggr_norm : bool
        Whether to normalize the aggregation.
    update_func : str
        Update function for the simplicial complex convolution.
    """

    def __init__(
        self,
        edge_channels,
        n_layers=2,
        n_hid=1,
        conv_order=1,
        n_adjs=1,
        aggr_norm=False,
        update_func=None,
    ):
        super().__init__()
        # Input layer: a linear map that brings the edge features
        # to the hidden dimension shared by all DirSNN layers.
        self.in_linear_1 = torch.nn.Linear(edge_channels, n_hid)
        self.layers = torch.nn.ModuleList(
            DirSNNLayer(
                in_channels_1=n_hid,
                out_channels_1=n_hid,
                conv_order=conv_order,
                n_adjs=n_adjs,
                aggr_norm=aggr_norm,
                update_func=update_func,
            )
            for _ in range(n_layers)
        )
    def forward(self, x_1, adjs):
        """Forward computation.

        Parameters
        ----------
        x_1 : torch.Tensor, shape = (batch_size, n_edges, n_features)
            Input edge features.
        adjs : tuple of torch.Tensor
            Adjacency tensors, each of shape (n_edges, n_edges).

        Returns
        -------
        x_1 : torch.Tensor, shape = (batch_size, n_edges, output_size)
            Output edge representations.
        """
        # Map input features to the hidden dimension, then propagate
        # through the stack of DirSNN layers.
        x_1 = self.in_linear_1(x_1)
        for layer in self.layers:
            x_1 = layer(x_1, adjs)
        return x_1
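

# Minimal smoke test: a hedged sketch of how DirSNNClassifier might be
# exercised, not part of the original module. It assumes DirSNNLayer
# (imported from layers.py, not shown here) accepts edge features of
# shape (batch_size, n_edges, n_hid) together with a tuple of
# (n_edges, n_edges) adjacency tensors, matching the docstrings above.
# The sizes below (8 edges, 16 channels, 3 classes) are illustrative only.
if __name__ == "__main__":
    batch_size, n_edges, edge_channels, n_classes = 4, 8, 16, 3

    x_1 = torch.randn(batch_size, n_edges, edge_channels)
    # A single dense adjacency tensor; with n_adjs > 1, pass one per relation.
    adjs = (torch.rand(n_edges, n_edges),)

    model = DirSNNClassifier(n_classes=n_classes, edge_channels=edge_channels,
                             n_layers=2, n_hid_conv=32, n_hid_mlp=32,
                             conv_order=1, n_adjs=1)
    log_probs = model(x_1, adjs)
    print(log_probs.shape)  # expected: torch.Size([4, 3])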