-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathegnn.py
347 lines (286 loc) · 12.3 KB
/
egnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import jax
from typing import Callable, List
import flax.linen as nn
import jax.numpy as jnp
import jraph
from jraph._src import utils
import e3nn_jax as e3nn
from models.mlp import MLP
def get_edge_mlp_updates(
    d_hidden,
    n_layers,
    activation,
    position_only=False,
    tanh_out=False,
    soft_edges=False,
    apply_pbc=None,
    n_radial_basis=32,
    r_max=0.3,
) -> Callable:
    """Build the EGNN edge (message) update function for ``jraph.GraphNetwork``.

    Args:
        d_hidden: Width of the hidden layers and of the message ``m_ij``.
        n_layers: Number of layers of the message MLP ``phi_e``; the position
            MLP ``phi_x`` has ``n_layers - 1`` layers followed by a scalar
            ``nn.Dense`` output layer.
        activation: Activation function passed to the MLPs.
        position_only: If True, node attributes are laid out as
            ``[x (3), h (...)]``; otherwise as ``[x (3), v (3), h (...)]``.
        tanh_out: If True, squash the per-edge position coefficient with tanh.
        soft_edges: If True, gate messages with ``sigmoid(phi_a(m_ij))``.
        apply_pbc: Optional callable applied to relative position vectors.
        n_radial_basis: Number of Bessel basis functions used to expand the
            edge distance; if 0, the raw distance is used.
        r_max: Cutoff passed to ``e3nn.bessel``.

    Returns:
        Callable ``update_fn(edges, senders, receivers, globals)`` returning
        the tuple ``(x_ij_trans, m_ij)``.
    """

    def update_fn(
        edges: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ):
        """Compute per-edge coordinate translations and invariant messages.

        Args:
            edges: Edge attributes (unused here).
            senders: Attributes of each edge's sender node.
            receivers: Attributes of each edge's receiver node.
            globals: Global features broadcast to edges, or None.

        Returns:
            Tuple ``(x_ij_trans, m_ij)``: the translation ``r_ij * phi_x(m_ij)``
            clipped to [-100, 100] (same shape as ``r_ij``), and the message
            ``m_ij`` produced by ``phi_e`` (optionally gated).
        """
        # Node layout: position first, then velocity (unless position_only),
        # then optional scalar features. Velocities do not enter the message.
        n_vector = 3 if position_only else 6
        x_i, x_j = senders[:, :3], receivers[:, :3]
        if senders.shape[-1] == n_vector:  # No scalar node attributes
            concats = globals
        else:
            h_i, h_j = senders[:, n_vector:], receivers[:, n_vector:]
            concats = jnp.concatenate([h_i, h_j], -1)
            if globals is not None:
                concats = jnp.concatenate([concats, globals], -1)

        # NOTE: Flax auto-names submodules by construction order; keep the
        # order phi_e, phi_x, phi_x_last_layer, phi_a so parameter names stay
        # stable across versions of this code.
        # Messages from Eqs. (3) and (4)/(7)
        phi_e = MLP(
            [d_hidden] * (n_layers - 1) + [d_hidden],
            activation=activation,
            activate_final=True,
        )
        # Super special init for last layer of position MLP (from https://github.com/lucidrains/egnn-pytorch)
        phi_x = MLP(
            [d_hidden] * (n_layers - 1), activation=activation, activate_final=True
        )
        phi_x_last_layer = nn.Dense(
            1,
            use_bias=False,
            kernel_init=jax.nn.initializers.variance_scaling(
                scale=1e-2, mode="fan_in", distribution="uniform"
            ),
        )

        # Relative positions and distances, optionally expanded in a Bessel basis.
        EPS = 1e-7
        r_ij = x_i - x_j
        if apply_pbc:
            r_ij = apply_pbc(r_ij)
        # The stabiliser goes under the square root instead of being added to
        # r_ij: a constant vector offset in r_ij would leak into x_ij_trans and
        # break exact rotation/reflection equivariance. EPS**2 keeps the norm's
        # gradient finite for coincident nodes (e.g. self-loops).
        d_ij = jnp.sqrt(jnp.sum(r_ij ** 2, axis=1) + EPS ** 2)
        if n_radial_basis > 0:
            d_ij = e3nn.bessel(d_ij, n_radial_basis, r_max)
        d_ij = d_ij.reshape(d_ij.shape[0], -1)

        # Invariant message inputs: distance features plus optional scalars/globals.
        message_scalars = d_ij
        if concats is not None:
            message_scalars = jnp.concatenate([message_scalars, concats], axis=-1)
        m_ij = phi_e(message_scalars)

        # Optional soft-edge gating. phi_a is constructed unconditionally so the
        # auto-generated names of later submodules do not depend on soft_edges.
        phi_a = MLP([d_hidden], activation=activation)
        if soft_edges:
            m_ij = m_ij * nn.sigmoid(phi_a(m_ij))

        trans = phi_x_last_layer(phi_x(m_ij))
        # Optionally apply tanh to the pre-factor to stabilise training.
        if tanh_out:
            trans = jax.nn.tanh(trans)
        x_ij_trans = jnp.clip(r_ij * trans, -100.0, 100.0)  # From original code
        return x_ij_trans, m_ij

    return update_fn
def get_node_mlp_updates(
    d_hidden,
    n_layers,
    activation,
    position_only=False,
    decouple_pos_vel_updates=False,
    apply_pbc=None
) -> Callable:
    """Build the EGNN node update function for ``jraph.GraphNetwork``.

    Args:
        d_hidden: Width of the hidden layers of the node MLPs.
        n_layers: Number of layers of each node MLP.
        activation: Activation function passed to the MLPs.
        position_only: If True, node attributes are laid out as
            ``[x (3), h (...)]``; otherwise as ``[x (3), v (3), h (...)]``.
        decouple_pos_vel_updates: Position + velocity mode only. If True, the
            new position is ``phi_xx(.) * x + v_new`` instead of ``x + v_new``.
        apply_pbc: Optional callable applied to the updated positions.

    Returns:
        Callable ``update_fn(nodes, senders, receivers, globals)`` returning
        the updated node attribute array.
    """

    def update_fn(
        nodes: jnp.ndarray,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        globals: jnp.ndarray,
    ) -> jnp.ndarray:
        """Update node positions and scalar features from aggregated edges.

        Args:
            nodes: Node attributes (layout set by ``position_only``).
            senders: Edge outputs aggregated per sender node (unused here).
            receivers: Edge outputs aggregated per receiver node: the pair
                ``(sum_x_ij, m_i)`` of aggregated coordinate translations and
                aggregated messages.
            globals: Global features broadcast to nodes, or None.

        Returns:
            Updated node attributes, concatenated along the last axis.
        """
        # Aggregated translations and messages. The aggregation is whatever
        # the caller's GraphNetwork uses (sum/mean/max), not necessarily a sum.
        sum_x_ij, m_i = receivers
        if position_only:  # Positions only; no velocities
            if nodes.shape[-1] == 3:  # No scalar attributes
                x_i_p = nodes + sum_x_ij
                if apply_pbc:
                    x_i_p = apply_pbc(x_i_p)
                return x_i_p
            # Additional scalar attributes
            x_i, h_i = nodes[..., :3], nodes[..., 3:]
            x_i_p = x_i + sum_x_ij
            if apply_pbc:
                x_i_p = apply_pbc(x_i_p)
            phi_h = MLP(
                [d_hidden] * (n_layers - 1) + [h_i.shape[-1]], activation=activation
            )
            concats = jnp.concatenate([h_i, m_i], -1)
            if globals is not None:
                concats = jnp.concatenate([concats, globals], -1)
            h_i_p = h_i + phi_h(concats)  # Skip connection, Eq. (6)
            return jnp.concatenate([x_i_p, h_i_p], -1)

        # Positions and velocities
        if nodes.shape[-1] == 6:  # No scalar attributes
            x_i, v_i = nodes[:, :3], nodes[:, 3:6]
            # From Eqs. (6) and (7)
            phi_v = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)
            # With no scalar features, the aggregated messages (+ globals) act
            # as h_i and are appended to the output so later rounds have
            # scalar attributes. Necessary for stability.
            concats = m_i
            if globals is not None:
                concats = jnp.concatenate([concats, globals], -1)
            v_i_p = phi_v(concats) * v_i + sum_x_ij
            if decouple_pos_vel_updates:
                # Decouple position and velocity updates
                phi_xx = MLP(
                    [d_hidden] * (n_layers - 1) + [1], activation=activation
                )
                x_i_p = phi_xx(concats) * x_i + v_i_p
            else:
                # Assumes dynamical system with coupled position and velocity updates, as in original paper!
                x_i_p = x_i + v_i_p
            if apply_pbc:
                x_i_p = apply_pbc(x_i_p)
            # NOTE: the input velocity v_i (not v_i_p) is passed to the next round.
            return jnp.concatenate([x_i_p, v_i, concats], -1)

        # Positions, velocities and additional scalar attributes
        x_i, v_i, h_i = nodes[:, :3], nodes[:, 3:6], nodes[:, 6:]
        # From Eqs. (6) and (7). Construction order phi_v, phi_h, (phi_xx) is
        # kept fixed because Flax auto-names submodules by construction order.
        phi_v = MLP([d_hidden] * (n_layers - 1) + [1], activation=activation)
        phi_h = MLP(
            [d_hidden] * (n_layers - 1) + [h_i.shape[-1]], activation=activation
        )
        # Velocity/position scalings depend on h_i (+ globals), Eq. (7).
        concats = h_i
        if globals is not None:
            concats = jnp.concatenate([concats, globals], -1)
        # The feature update must see the aggregated messages m_i (Eq. 6), as
        # in the position-only branch; otherwise h never receives neighbour
        # information. (This changes phi_h's input width vs. older versions.)
        h_inputs = jnp.concatenate([h_i, m_i], -1)
        if globals is not None:
            h_inputs = jnp.concatenate([h_inputs, globals], -1)
        v_i_p = phi_v(concats) * v_i + sum_x_ij
        if decouple_pos_vel_updates:
            # Decouple position and velocity updates
            phi_xx = MLP(
                [d_hidden] * (n_layers - 1) + [1], activation=activation
            )
            x_i_p = phi_xx(concats) * x_i + v_i_p
        else:
            # Assumes dynamical system with coupled position and velocity updates, as in original paper!
            x_i_p = x_i + v_i_p
        if apply_pbc:
            x_i_p = apply_pbc(x_i_p)
        h_i_p = h_i + phi_h(h_inputs)  # Skip connection, as in original paper
        # NOTE: the input velocity v_i (not v_i_p) is passed to the next round.
        return jnp.concatenate([x_i_p, v_i, h_i_p], -1)

    return update_fn
class EGNN(nn.Module):
    """E(n) Equivariant Graph Neural Network (EGNN) following Satorras et al (2021; https://arxiv.org/abs/2102.09844).

    Node attributes are laid out as ``[x (3), v (3), h (...)]`` or, with
    ``positions_only=True``, as ``[x (3), h (...)]``. Globals, if present, are
    reshaped to ``(1, -1)``, i.e. the input is treated as a single graph.
    """

    # Attributes for all MLPs
    message_passing_steps: int = 3  # Number of EGCL message-passing rounds
    d_hidden: int = 128  # Hidden width of the edge/node MLPs
    n_layers: int = 3  # Layers per edge/node MLP
    activation: str = "gelu"  # Name of an activation function in flax.linen
    soft_edges: bool = True  # Scale edges by a learnable function
    positions_only: bool = False  # (pos, vel, scalars) vs (pos, scalars)
    tanh_out: bool = False  # Squash the per-edge position coefficient with tanh
    n_radial_basis: int = 16  # Number of radial basis functions for distance
    r_max: float = 0.3  # Maximum distance for radial basis functions
    decouple_pos_vel_updates: bool = False  # Use extra MLP to decouple position and velocity updates
    message_passing_agg: str = "sum"  # "sum", "mean", "max"
    readout_agg: str = "mean"  # "sum", "mean", "max"; node reduction for task="graph"
    # Readout MLP widths: the first entry multiplies the aggregated feature
    # size, the remaining entries multiply d_hidden.
    mlp_readout_widths: List[int] = (4, 2, 2)
    task: str = "graph"  # "graph" or "node"
    n_outputs: int = 1  # Number of outputs for graph-level readout
    apply_pbc: Callable = None  # Optional wrap applied to displacements/positions

    @nn.compact
    def __call__(self, graphs: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Apply equivariant graph convolutional layers and an optional readout.

        Args:
            graphs (jraph.GraphsTuple): Input graph.

        Returns:
            If ``task == "node"``: the updated ``jraph.GraphsTuple``.
            If ``task == "graph"``: the output of the readout MLP (last layer
            width ``n_outputs``) applied to the aggregated node features.

        Raises:
            ValueError: On an unknown ``message_passing_agg``, ``readout_agg``
                (graph task only) or ``task``.
        """
        processed_graphs = graphs
        if processed_graphs.globals is not None:
            processed_graphs = processed_graphs._replace(
                globals=processed_graphs.globals.reshape(1, -1)
            )

        activation = getattr(nn, self.activation)

        if self.message_passing_agg not in ["sum", "mean", "max"]:
            raise ValueError(
                f"Invalid message passing aggregation function {self.message_passing_agg}"
            )
        aggregate_edges_for_nodes_fn = getattr(
            utils, f"segment_{self.message_passing_agg}"
        )

        # Apply message-passing rounds; fresh update fns (and hence fresh
        # MLP parameters) are built for every round.
        for _ in range(self.message_passing_steps):
            update_node_fn = get_node_mlp_updates(
                self.d_hidden,
                self.n_layers,
                activation,
                position_only=self.positions_only,
                decouple_pos_vel_updates=self.decouple_pos_vel_updates,
                apply_pbc=self.apply_pbc
            )
            update_edge_fn = get_edge_mlp_updates(
                self.d_hidden,
                self.n_layers,
                activation,
                position_only=self.positions_only,
                tanh_out=self.tanh_out,
                soft_edges=self.soft_edges,
                apply_pbc=self.apply_pbc,
                n_radial_basis=self.n_radial_basis,
                r_max=self.r_max
            )
            # Instantiate graph network and apply EGCL
            graph_net = jraph.GraphNetwork(
                update_node_fn=update_node_fn,
                update_edge_fn=update_edge_fn,
                aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn,
            )
            processed_graphs = graph_net(processed_graphs)

        if self.task == "node":
            return processed_graphs
        if self.task != "graph":
            raise ValueError(f"Invalid task {self.task}")

        # Graph-level readout: reduce node features over axis 0 (all nodes).
        if self.readout_agg not in ["sum", "mean", "max"]:
            raise ValueError(
                f"Invalid global aggregation function {self.readout_agg}"
            )
        readout_agg_fn = getattr(jnp, self.readout_agg)
        agg_nodes = readout_agg_fn(processed_graphs.nodes, axis=0)
        if processed_graphs.globals is not None:
            # Globals were reshaped to (1, G) above; flatten them so they can
            # be concatenated with the aggregated node features (jnp.concatenate
            # requires equal ndim). NOTE(review): original comment "use tpcf".
            agg_nodes = jnp.concatenate(
                [agg_nodes, processed_graphs.globals.reshape(-1)]
            )
        agg_nodes = nn.LayerNorm()(agg_nodes)

        # Readout and return
        widths = (
            [self.mlp_readout_widths[0] * agg_nodes.shape[-1]]
            + [w * self.d_hidden for w in self.mlp_readout_widths[1:]]
            + [self.n_outputs]
        )
        return MLP(widths)(agg_nodes)