"""
Layer of DirSNN.
Adapted from https://github.com/pyt-team/TopoModelX/blob/main/topomodelx/nn/simplicial/sccnn_layer.py
"""
import torch
from torch.nn.parameter import Parameter
class DirSNNLayer(torch.nn.Module):
r"""Layer of a Directed Simplicial Neural Network.
Parameters
----------
in_channels_1 : int
Dimensions of input features on edges
out_channels_1 : int
Dimensions of output features on edges
conv_order : int
Convolution order of the simplicial filter
n_adjs : int
Number of adjacency matrixes used
aggr_norm : bool = False
Whether to normalize the aggregated message by the neighborhood size.
update_func : str, default = None
Activation function used in aggregation layers.
initialization : str, default = "xavier_normal"
Weight initialization method.
"""

    def __init__(
        self,
        in_channels_1,
        out_channels_1,
        conv_order,
        n_adjs=1,
        aggr_norm: bool = False,
        update_func=None,
        initialization: str = "xavier_normal",
    ) -> None:
        super().__init__()

        self.in_channels_1 = in_channels_1
        self.out_channels_1 = out_channels_1
        self.conv_order = conv_order
        self.aggr_norm = aggr_norm
        self.update_func = update_func
        self.initialization = initialization
        self.n_adjs = n_adjs

        assert initialization in ["xavier_uniform", "xavier_normal"]
        assert self.conv_order > 0

        # One weight slice per hop of each adjacency operator, plus one for
        # the identity (skip) term.
        self.weight_1 = Parameter(
            torch.Tensor(
                self.in_channels_1,
                self.out_channels_1,
                conv_order * n_adjs + 1,
            )
        )

        self.reset_parameters()

    def reset_parameters(self, gain: float = 1.414):
        r"""Reset learnable parameters.

        Parameters
        ----------
        gain : float
            Gain for the weight initialization.

        Notes
        -----
        This function will be called by subclasses of
        MessagePassing that have trainable weights.
        """
        if self.initialization == "xavier_uniform":
            torch.nn.init.xavier_uniform_(self.weight_1, gain=gain)
        elif self.initialization == "xavier_normal":
            torch.nn.init.xavier_normal_(self.weight_1, gain=gain)
        else:
            raise RuntimeError(
                "Initialization method not recognized. "
                "Should be either xavier_uniform or xavier_normal."
            )

    def aggr_norm_func(self, conv_operator, x):
        r"""Perform aggregation normalization."""
        # Row-wise neighborhood sizes of the (possibly sparse) operator.
        neighborhood_size = torch.sum(conv_operator.to_dense(), dim=1)
        neighborhood_size_inv = 1 / neighborhood_size
        neighborhood_size_inv[~(torch.isfinite(neighborhood_size_inv))] = 0

        # Divide the features of each simplex by its neighborhood size;
        # x is batched with shape (batch_size, n_simplices, channels).
        x = torch.einsum("n,bnc->bnc", neighborhood_size_inv, x)
        x[~torch.isfinite(x)] = 0
        return x

    def update(self, x):
        """Update embeddings on each cell (step 4).

        Parameters
        ----------
        x : torch.Tensor, shape = (batch_size, n_target_cells, out_channels)
            Feature tensor.

        Returns
        -------
        torch.Tensor, shape = (batch_size, n_target_cells, out_channels)
            Updated output features on target cells.
        """
        if self.update_func == "sigmoid":
            return torch.sigmoid(x)
        if self.update_func == "relu":
            return torch.nn.functional.relu(x)
        if self.update_func == "leaky_relu":
            return torch.nn.functional.leaky_relu(x)
        return None

    def chebyshev_conv(self, conv_operator, conv_order, x):
        r"""Perform Chebyshev convolution.

        Parameters
        ----------
        conv_operator : torch.sparse, shape = (n_simplices, n_simplices)
            Convolution operator, e.g., the adjacency matrix or a Hodge Laplacian.
        conv_order : int
            The order of the convolution.
        x : torch.Tensor, shape = (batch_size, n_simplices, num_channels)
            Feature tensor.

        Returns
        -------
        torch.Tensor, shape = (batch_size, n_simplices, num_channels, conv_order)
            Output tensor, where X[:, :, :, k] = conv_operator^(k+1) @ x.
        """
        batch_size, num_simplices, num_channels = x.shape
        # Allocate the output on the same device and dtype as the input features.
        X = torch.empty(
            size=(batch_size, num_simplices, num_channels, conv_order),
            device=x.device,
            dtype=x.dtype,
        )
        if self.aggr_norm:
            X[:, :, :, 0] = torch.matmul(conv_operator, x)
            X[:, :, :, 0] = self.aggr_norm_func(conv_operator, X[:, :, :, 0])
            for k in range(1, conv_order):
                X[:, :, :, k] = torch.matmul(conv_operator, X[:, :, :, k - 1])
                X[:, :, :, k] = self.aggr_norm_func(conv_operator, X[:, :, :, k])
        else:
            X[:, :, :, 0] = torch.matmul(conv_operator, x)
            for k in range(1, conv_order):
                X[:, :, :, k] = torch.matmul(conv_operator, X[:, :, :, k - 1])
        return X

    def forward(self, x_1, laplacian_all):
        r"""Forward computation.

        Parameters
        ----------
        x_1 : torch.Tensor, shape = (batch_size, n_edges, in_channels_1)
            Edge features.
        laplacian_all : tuple of torch.Tensor, len = n_adjs
            Tuple of adjacency tensors, each of shape (n_edges, n_edges).

        Returns
        -------
        y_1 : torch.Tensor, shape = (batch_size, n_edges, out_channels_1)
            Output features on edges.
        """
        num_edges = x_1.shape[1]
        identity_1 = torch.eye(num_edges, device=x_1.device, dtype=x_1.dtype)

        # Convolution in the edge space: the identity (skip) term followed by
        # conv_order powers of each adjacency operator.
        x_1_all = []
        x_1_all.append(torch.unsqueeze(identity_1 @ x_1, 3))
        for adj in laplacian_all:
            x_1_all.append(self.chebyshev_conv(adj, self.conv_order, x_1))
        assert len(x_1_all) == self.n_adjs + 1

        x_1_all = torch.cat(x_1_all, dim=3)
        y_1 = torch.einsum("bnik,iok->bno", x_1_all, self.weight_1)

        if self.update_func is None:
            return y_1
        return self.update(y_1)
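

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original layer):
# it assumes two dense (n_edges, n_edges) adjacency operators and random edge
# features, and checks that the forward pass yields the expected output shape.
# The sizes and operator names below are arbitrary placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, n_edges, in_channels, out_channels = 2, 10, 8, 16

    layer = DirSNNLayer(
        in_channels_1=in_channels,
        out_channels_1=out_channels,
        conv_order=2,
        n_adjs=2,
        update_func="relu",
    )

    # Random edge features and two dense adjacency operators.
    x_1 = torch.randn(batch_size, n_edges, in_channels)
    adj_up = torch.rand(n_edges, n_edges)
    adj_down = torch.rand(n_edges, n_edges)

    y_1 = layer(x_1, (adj_up, adj_down))
    print(y_1.shape)  # expected: torch.Size([2, 10, 16])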