Only collect hidden edge encodings when necessary (#319)
klieret authored Apr 25, 2023
1 parent 76f0e50 commit 0b57382
Showing 2 changed files with 23 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/gnn_tracking/models/edge_classifier.py
@@ -92,6 +92,7 @@ def __init__(
         super().__init__()
         if residual_kwargs is None:
             residual_kwargs = {}
+        residual_kwargs["collect_hidden_edge_embeds"] = use_intermediate_edge_embeddings
         self.relu = nn.ReLU()
 
         self.ec_node_encoder = MLP(
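Note: the added line simply threads the classifier's existing use_intermediate_edge_embeddings switch through to the residual network. A minimal sketch of that forwarding pattern (the helper name below is hypothetical, not part of the commit):

import json

# Hedged sketch: the classifier decides whether the residual network
# should collect hidden edge embeddings at all.
def build_residual_kwargs(
    use_intermediate_edge_embeddings: bool,
    residual_kwargs: dict | None = None,
) -> dict:
    if residual_kwargs is None:
        residual_kwargs = {}
    residual_kwargs["collect_hidden_edge_embeds"] = use_intermediate_edge_embeddings
    return residual_kwargs

print(json.dumps(build_residual_kwargs(False)))  # {"collect_hidden_edge_embeds": false}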
36 changes: 22 additions & 14 deletions src/gnn_tracking/models/resin.py
@@ -21,7 +21,7 @@ def _convex_combination(
     residue: T,
     alpha_residue: float,
 ) -> T:
-    """Helper function for JIT compilation"""
+    """Helper function for JIT compilation. Use `convex_combination` instead."""
     assert 0 <= alpha_residue <= 1
     return alpha_residue * residue + (1 - alpha_residue) * delta
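For intuition, the helper interpolates between the residual stream and the new layer output. A minimal JIT-free sketch of the same arithmetic (shapes and values are illustrative only):

import torch

def convex_combination(delta: torch.Tensor, residue: torch.Tensor, alpha_residue: float) -> torch.Tensor:
    # alpha_residue weights the old features, (1 - alpha_residue) the new ones
    assert 0 <= alpha_residue <= 1
    return alpha_residue * residue + (1 - alpha_residue) * delta

x_old = torch.zeros(5, 8)  # residual stream (e.g. node embeddings)
x_new = torch.ones(5, 8)   # output of the current layer
x = convex_combination(delta=x_new, residue=x_old, alpha_residue=0.5)
print(x[0, 0])  # tensor(0.5000): halfway between old and new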

@@ -48,6 +48,7 @@ def __init__(
         layers: list[nn.Module],
         *,
         alpha: float = 0.5,
+        collect_hidden_edge_embeds: bool = False,
     ):
         """Apply a list of layers in sequence with residual connections for the nodes.
         This is an abstract base class that does not contain code for the type of
@@ -59,10 +60,13 @@ def __init__(
         Args:
             layers: List of layers
             alpha: Strength of the node embedding residual connection
+            collect_hidden_edge_embeds: Whether to collect the edge embeddings from all
+                layers (can be set to false to save memory)
         """
         super().__init__()
         self.layers = nn.ModuleList(layers)
         self._alpha = alpha
+        self._collect_hidden_edge_embeds = collect_hidden_edge_embeds
 
     def forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
         """Forward pass
@@ -74,12 +78,12 @@ def forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
         Returns:
             node embedding, edge_embedding, concatenated edge embeddings from all
-            levels (including ``edge_attr``)
+            levels (including ``edge_attr``, unless ``collect_hidden_edge_embeds`` is False)
         """
         return self._forward(x, edge_index, edge_attr)
 
     @abstractmethod
-    def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
+    def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T] | None]:
         pass
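The flag changes what the third element of the returned tuple is: the list of per-level edge embeddings, or None when collection is disabled. A self-contained sketch of the mechanism (TinyResidualNet and DummyLayer are stand-ins, not classes from this repository):

import torch
from torch import nn
from torch.nn.functional import relu

class DummyLayer(nn.Module):
    # Stand-in interaction layer: returns (delta_x, new_edge_attr)
    def forward(self, x, edge_index, edge_attr):
        return x + 1.0, edge_attr * 2.0

class TinyResidualNet(nn.Module):
    # Minimal sketch of the collect_hidden_edge_embeds mechanism
    def __init__(self, layers, alpha=0.5, collect_hidden_edge_embeds=False):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self._alpha = alpha
        self._collect = collect_hidden_edge_embeds

    def forward(self, x, edge_index, edge_attr):
        # Keep the running list only when the caller asked for it
        edge_attrs = [edge_attr] if self._collect else None
        for layer in self.layers:
            delta_x, edge_attr = layer(x, edge_index, edge_attr)
            x = self._alpha * x + (1 - self._alpha) * relu(delta_x)
            if self._collect:
                edge_attrs.append(edge_attr)
        return x, edge_attr, edge_attrs

x = torch.zeros(4, 3)
edge_index = torch.zeros(2, 6, dtype=torch.long)
edge_attr = torch.ones(6, 2)

_, _, hidden = TinyResidualNet([DummyLayer()] * 2, collect_hidden_edge_embeds=True)(x, edge_index, edge_attr)
print(len(hidden))  # 3: the input edge_attr plus one embedding per layer

_, _, hidden = TinyResidualNet([DummyLayer()] * 2)(x, edge_index, edge_attr)
print(hidden)  # None: nothing retained, which is what saves memory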


@@ -90,16 +94,17 @@ def __init__(self, *args, **kwargs):
         """
         super().__init__(*args, **kwargs)
 
-    def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
-        edge_attrs = [edge_attr]
+    def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T] | None]:
+        edge_attrs = [edge_attr] if self._collect_hidden_edge_embeds else None
         for layer in self.layers:
             delta_x, edge_attr = layer(x, edge_index, edge_attr)
             x = convex_combination(
                 delta=relu(delta_x),
                 residue=x,
                 alpha_residue=self._alpha,
             )
-            edge_attrs.append(edge_attr)
+            if self._collect_hidden_edge_embeds:
+                edge_attrs.append(edge_attr)
         return x, edge_attr, edge_attrs


@@ -123,8 +128,9 @@ def __init__(
             add_bn: Add batch norms
             **kwargs: Arguments to `ResidualNetwork`
         """
-        if not len(layers) % 2 == 0:
-            raise ValueError("Only even number of layers allowed at the moment")
+        if len(layers) % 2 != 0:
+            msg = "Only even number of layers allowed at the moment"
+            raise ValueError(msg)
         super().__init__(layers=layers, **kwargs)
         _node_batch_norms = []
         _edge_batch_norms = []
@@ -138,8 +144,8 @@ def __init__(
         self._node_batch_norms = nn.ModuleList(_node_batch_norms)
         self._edge_batch_norms = nn.ModuleList(_edge_batch_norms)
 
-    def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
-        edge_attrs = [edge_attr]
+    def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T] | None]:
+        edge_attrs = [edge_attr] if self._collect_hidden_edge_embeds else None
         for i_layer_pair in range(len(self.layers) // 2):
             i0 = 2 * i_layer_pair
             hidden_x, hidden_edge_attr = self.layers[i0](
@@ -154,7 +160,8 @@ def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
                 relu(self._edge_batch_norms[i1](hidden_edge_attr)),
             )
             x = convex_combination(delta=delta_x, residue=x, alpha_residue=self._alpha)
-            edge_attrs.append(edge_attr)
+            if self._collect_hidden_edge_embeds:
+                edge_attrs.append(edge_attr)
         return x, edge_attr, edge_attrs


@@ -178,7 +185,7 @@ def __init__(
         self._residual_layer = connect_to
 
     def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
-        edge_attrs = [edge_attr]
+        edge_attrs = [edge_attr] if self._collect_hidden_edge_embeds else None
         x_residue = None
         for i_layer in range(len(self.layers)):
             if i_layer == self._residual_layer:
@@ -187,7 +194,8 @@ def _forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T]]:
             x = convex_combination(
                 delta=relu(delta_x), residue=x_residue, alpha_residue=self._alpha
             )
-            edge_attrs.append(edge_attr)
+            if self._collect_hidden_edge_embeds:
+                edge_attrs.append(edge_attr)
         return x, edge_attr, edge_attrs


@@ -263,5 +271,5 @@ def concat_edge_embeddings_length(self) -> int:
             return self.edge_dim * (len(self.network.layers) // 2 + 1)
         return self.edge_dim * (len(self.network.layers) + 1)
 
-    def forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T], list[T]]:
+    def forward(self, x, edge_index, edge_attr) -> tuple[T, T, list[T], list[T] | None]:
         return self.network.forward(x, edge_index, edge_attr)
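The two context lines in concat_edge_embeddings_length show how the expected concatenated width is derived: one edge embedding per layer (or per layer pair in the batch-norm variant), plus the input edge_attr. A quick arithmetic check with made-up numbers:

edge_dim, n_layers = 8, 4
print(edge_dim * (n_layers // 2 + 1))  # 24: one embedding per layer pair, plus the input
print(edge_dim * (n_layers + 1))       # 40: one embedding per layer, plus the input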
