Merge pull request #273 from pyt-team/cell_models_input_logic
Update the Input Logic of Cell Complex Models
ninamiolane authored Apr 8, 2024
2 parents 8ea75a7 + db96096 commit c349941
Showing 10 changed files with 234 additions and 176 deletions.
1 change: 1 addition & 0 deletions test/nn/cell/test_can.py
@@ -20,6 +20,7 @@ def test_forward(self):
heads=1,
n_layers=2,
att_lift=False,
pooling=True,
).to(device)

x_0 = torch.rand(2, 2)
46 changes: 28 additions & 18 deletions topomodelx/nn/cell/can.py
@@ -31,8 +31,12 @@ class CAN(torch.nn.Module):
Number of CAN layers.
att_lift : bool, default=True
Whether to lift the signal from node-level to edge-level input.
pooling : bool, default=False
Whether to apply a pooling operation.
k_pool : float, default=0.5
The pooling ratio, i.e., the fraction of r-cells to keep after the pooling operation.
**kwargs : optional
Additional arguments for CANLayer.
References
----------
@@ -54,7 +58,9 @@ def __init__(
att_activation=None,
n_layers=2,
att_lift=True,
pooling=False,
k_pool=0.5,
**kwargs,
):
super().__init__()

@@ -81,6 +87,7 @@ def __init__(
att_activation=att_activation,
aggr_func="sum",
update_func="relu",
**kwargs,
)
)

@@ -96,23 +103,23 @@ def __init__(
att_activation=att_activation,
aggr_func="sum",
update_func="relu",
**kwargs,
)
)

layers.append(
PoolLayer(
k_pool=k_pool,
in_channels_0=out_channels * heads,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
if pooling:
layers.append(
PoolLayer(
k_pool=k_pool,
in_channels_0=out_channels * heads,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
**kwargs,
)
)
)

self.layers = torch.nn.ModuleList(layers)

def forward(
self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood
):
def forward(self, x_0, x_1, adjacency_0, down_laplacian_1, up_laplacian_1):
"""Forward pass.
Parameters
@@ -121,28 +128,31 @@ def forward(
Input features on the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Input features on the edges (1-cells).
neighborhood_0_to_0 : torch.Tensor, shape = (n_nodes, n_nodes)
adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes)
Neighborhood matrix from nodes to nodes.
lower_neighborhood : torch.Tensor, shape = (-, -)
down_laplacian_1 : torch.Tensor, shape = (-, -)
Lower neighborhood matrix.
upper_neighborhood : torch.Tensor, shape = (-, -)
up_laplacian_1 : torch.Tensor, shape = (-, -)
Upper neighborhood matrix.
Returns
-------
torch.Tensor, shape = (num_pooled_edges, heads * out_channels)
Final hidden representations of pooled edges.
"""
adjacency_0 = adjacency_0.coalesce()
down_laplacian_1 = down_laplacian_1.coalesce()
up_laplacian_1 = up_laplacian_1.coalesce()
if hasattr(self, "lift_layer"):
x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1)
x_1 = self.lift_layer(x_0, adjacency_0.coalesce(), x_1)

for layer in self.layers:
if isinstance(layer, PoolLayer):
x_1, lower_neighborhood, upper_neighborhood = layer(
x_1, lower_neighborhood, upper_neighborhood
x_1, down_laplacian_1, up_laplacian_1 = layer(
x_1, down_laplacian_1, up_laplacian_1
)
else:
x_1 = layer(x_1, lower_neighborhood, upper_neighborhood)
x_1 = layer(x_1, down_laplacian_1, up_laplacian_1)
x_1 = F.dropout(x_1, p=0.5, training=self.training)

return x_1
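As a quick orientation for the renamed inputs, here is a minimal usage sketch of the updated CAN interface. The constructor names not visible in this diff (in_channels_0, in_channels_1, out_channels) and the identity-based neighborhood tensors are assumptions for illustration only, not confirmed by the diff.

import torch
from topomodelx.nn.cell.can import CAN

# Hypothetical toy complex: 2 nodes with 2 features, 3 edges with 2 features.
x_0 = torch.rand(2, 2)
x_1 = torch.rand(3, 2)
# Sparse neighborhood structures, named as in the new forward signature.
adjacency_0 = torch.eye(2).to_sparse()
down_laplacian_1 = torch.eye(3).to_sparse()
up_laplacian_1 = torch.eye(3).to_sparse()

# in_channels_*/out_channels are assumed constructor names; att_activation is
# passed explicitly because its default in this diff is None.
model = CAN(
    in_channels_0=2,
    in_channels_1=2,
    out_channels=4,
    heads=1,
    att_activation=torch.nn.LeakyReLU(),
    n_layers=2,
    att_lift=False,
    pooling=True,  # new flag from this PR: the PoolLayer is appended only when True
    k_pool=0.5,
)
x_1_out = model(x_0, x_1, adjacency_0, down_laplacian_1, up_laplacian_1)

Note that pooling now defaults to False, so callers that relied on the previously unconditional PoolLayer must pass pooling=True explicitly, as the updated test above does.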
53 changes: 26 additions & 27 deletions topomodelx/nn/cell/can_layer.py
Expand Up @@ -148,14 +148,14 @@ def message(self, x_source, x_target=None):
) # (num_edges, heads)
return self.signal_lift_activation(edge_signal)

def forward(self, x_0, neighborhood_0_to_0) -> torch.Tensor: # type: ignore[override]
def forward(self, x_0, adjacency_0) -> torch.Tensor: # type: ignore[override]
"""Forward pass.
Parameters
----------
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
Node signal.
neighborhood_0_to_0 : torch.Tensor, shape = (num_nodes, num_nodes)
adjacency_0 : torch.Tensor, shape = (num_nodes, num_nodes)
Sparse neighborhood matrix.
Returns
@@ -164,7 +164,7 @@ def forward(self, x_0, neighborhood_0_to_0) -> torch.Tensor: # type: ignore[override]
Edge signal.
"""
# Extract source and target nodes from the graph's edge index
source, target = neighborhood_0_to_0.indices() # (num_edges,)
source, target = adjacency_0.indices() # (num_edges,)

# Extract the node signal of the source and target nodes
x_source = x_0[source] # (num_edges, in_channels_0)
@@ -228,14 +228,14 @@ def reset_parameters(self) -> None:
"""Reinitialize learnable parameters using Xavier uniform initialization."""
self.lifts.reset_parameters()

def forward(self, x_0, neighborhood_0_to_0, x_1=None) -> torch.Tensor:
def forward(self, x_0, adjacency_0, x_1=None) -> torch.Tensor:
r"""Forward pass.
Parameters
----------
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
Node signal.
neighborhood_0_to_0 : torch.Tensor, shape = (2, num_edges)
adjacency_0 : torch.Tensor, shape = (2, num_edges)
Edge index.
x_1 : torch.Tensor, shape = (num_edges, in_channels_1), optional
Edge signal.
@@ -256,7 +256,7 @@ def forward(self, x_0, neighborhood_0_to_0, x_1=None) -> torch.Tensor:
\end{align*}
"""
# Lift the node signal for each attention head
attention_heads_x_1 = self.lifts(x_0, neighborhood_0_to_0)
attention_heads_x_1 = self.lifts(x_0, adjacency_0)

# Combine the output edge signals using the specified readout strategy
readout_methods = {
@@ -323,17 +323,17 @@ def reset_parameters(self) -> None:
init.xavier_uniform_(self.att_pool.data, gain=gain)

def forward( # type: ignore[override]
self, x, lower_neighborhood, upper_neighborhood
self, x, down_laplacian_1, up_laplacian_1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Forward pass.
Parameters
----------
x : torch.Tensor, shape = (n_r_cells, in_channels_r)
Input r-cell signal.
lower_neighborhood : torch.Tensor
down_laplacian_1 : torch.Tensor
Lower neighborhood matrix.
upper_neighborhood : torch.Tensor
up_laplacian_1 : torch.Tensor
Upper neighborhood matrix.
Returns
@@ -364,23 +364,19 @@ def forward( # type: ignore[override]
out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices]

# Update lower and upper neighborhood matrices with the top-k pooled r-cells
lower_neighborhood_modified = torch.index_select(
lower_neighborhood, 0, top_indices
down_laplacian_1_modified = torch.index_select(down_laplacian_1, 0, top_indices)
down_laplacian_1_modified = torch.index_select(
down_laplacian_1_modified, 1, top_indices
)
lower_neighborhood_modified = torch.index_select(
lower_neighborhood_modified, 1, top_indices
)
upper_neighborhood_modified = torch.index_select(
upper_neighborhood, 0, top_indices
)
upper_neighborhood_modified = torch.index_select(
upper_neighborhood_modified, 1, top_indices
up_laplacian_1_modified = torch.index_select(up_laplacian_1, 0, top_indices)
up_laplacian_1_modified = torch.index_select(
up_laplacian_1_modified, 1, top_indices
)
# return sparse matrices of neighborhood
return (
out,
lower_neighborhood_modified.to_sparse().float().coalesce(),
upper_neighborhood_modified.to_sparse().float().coalesce(),
down_laplacian_1_modified.to_sparse().float().coalesce(),
up_laplacian_1_modified.to_sparse().float().coalesce(),
)
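The two-step index_select above restricts each neighborhood matrix to the kept r-cells. A self-contained sketch of that pattern, using a dense stand-in matrix and a hypothetical top_indices rather than the PoolLayer internals:

import torch

# Hypothetical example: keep r-cells 0 and 2 of an original 4 after top-k pooling.
top_indices = torch.tensor([0, 2])
down_laplacian_1 = torch.rand(4, 4)

# Select the kept rows, then the kept columns.
restricted = torch.index_select(down_laplacian_1, 0, top_indices)
restricted = torch.index_select(restricted, 1, top_indices)

# Returned as a sparse, coalesced float tensor, as in PoolLayer.forward.
restricted = restricted.to_sparse().float().coalesce()
print(restricted.shape)  # torch.Size([2, 2])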


@@ -805,6 +801,8 @@ class CANLayer(torch.nn.Module):
Version of the layer, by default "v1", which is the same as the original CAN layer, while "v2" uses the same attention mechanism as the GATv2 layer.
share_weights : bool, default=False
This option is valid only for "v2". If True, the weights of the linear transformation applied to the source and target features are shared.
**kwargs : optional
Additional arguments of the CAN layer.
Notes
-----
Expand All @@ -823,11 +821,12 @@ def __init__(
concat: bool = True,
skip_connection: bool = True,
att_activation: torch.nn.Module | None = None,
add_self_loops: bool = False,
add_self_loops: bool = True,
aggr_func: Literal["mean", "sum"] = "sum",
update_func: Literal["relu", "sigmoid", "tanh"] | None = "relu",
version: Literal["v1", "v2"] = "v1",
share_weights: bool = False,
**kwargs,
) -> None:
super().__init__()

@@ -910,16 +909,16 @@ def reset_parameters(self) -> None:
if hasattr(self, "lin"):
self.lin.reset_parameters()

def forward(self, x, lower_neighborhood, upper_neighborhood) -> torch.Tensor:
def forward(self, x, down_laplacian_1, up_laplacian_1) -> torch.Tensor:
r"""Forward pass.
Parameters
----------
x : torch.Tensor, shape = (n_k_cells, channels)
Input features on the r-cell of the cell complex.
lower_neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells)
down_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells)
Lower neighborhood matrix mapping r-cells to r-cells (A_k_low).
upper_neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells)
up_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells)
Upper neighborhood matrix mapping r-cells to r-cells (A_k_up).
Returns
@@ -945,8 +944,8 @@ def forward(self, x, lower_neighborhood, upper_neighborhood) -> torch.Tensor:
\end{align*}
"""
# message and within-neighborhood aggregation
lower_x = self.lower_att(x, lower_neighborhood)
upper_x = self.upper_att(x, upper_neighborhood)
lower_x = self.lower_att(x, down_laplacian_1)
upper_x = self.upper_att(x, up_laplacian_1)

# skip connection
if hasattr(self, "lin"):
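For completeness, a direct call to a single CANLayer under the renamed arguments. The in_channels, out_channels, and heads parameter names are assumed (they sit before the options listed in this diff), and the identity-based laplacians are placeholders:

import torch
from topomodelx.nn.cell.can_layer import CANLayer

n_edges, channels = 5, 8
x_1 = torch.rand(n_edges, channels)
# Sparse r-cell-to-r-cell neighborhoods under the new names.
down_laplacian_1 = torch.eye(n_edges).to_sparse()
up_laplacian_1 = torch.eye(n_edges).to_sparse()

# in_channels/out_channels/heads are assumed parameter names; att_activation is
# passed explicitly since its default in this diff is None.
layer = CANLayer(
    in_channels=channels,
    out_channels=channels,
    heads=2,
    att_activation=torch.nn.LeakyReLU(),
)
out = layer(x_1, down_laplacian_1, up_laplacian_1)  # (n_edges, heads * out_channels) when concat=True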
12 changes: 8 additions & 4 deletions topomodelx/nn/cell/ccxn.py
@@ -20,6 +20,8 @@ class CCXN(torch.nn.Module):
Number of CCXN layers.
att : bool
Whether to use attention.
**kwargs : optional
Additional arguments for CCXNLayer.
References
----------
@@ -36,6 +38,7 @@ def __init__(
in_channels_2,
n_layers=2,
att=False,
**kwargs,
):
super().__init__()

@@ -45,11 +48,12 @@ def __init__(
in_channels_1=in_channels_1,
in_channels_2=in_channels_2,
att=att,
**kwargs,
)
for _ in range(n_layers)
)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
def forward(self, x_0, x_1, adjacency_0, incidence_2_t):
"""Forward computation through layers.
Parameters
@@ -58,9 +62,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
Input features on the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Input features on the edges (1-cells).
neighborhood_0_to_0 : torch.Tensor, shape = (n_nodes, n_nodes)
adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes)
Adjacency matrix of rank 0 (up).
neighborhood_1_to_2 : torch.Tensor, shape = (n_faces, n_edges)
incidence_2_t : torch.Tensor, shape = (n_faces, n_edges)
Transpose of boundary matrix of rank 2.
Returns
@@ -73,5 +77,5 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
Final hidden states of the faces (2-cells).
"""
for layer in self.layers:
x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
x_0, x_1, x_2 = layer(x_0, x_1, adjacency_0, incidence_2_t)
return (x_0, x_1, x_2)
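A hedged usage sketch of CCXN with the renamed inputs; the toy shapes and the random incidence matrix are illustrative placeholders only, while the constructor arguments are taken from the diff above.

import torch
from topomodelx.nn.cell.ccxn import CCXN

n_nodes, n_edges, n_faces = 4, 5, 2
x_0 = torch.rand(n_nodes, 2)
x_1 = torch.rand(n_edges, 2)

# adjacency_0: node-to-node adjacency (A_0_up); incidence_2_t: transpose of the
# rank-2 boundary matrix, shape (n_faces, n_edges). Sparse stand-ins here.
adjacency_0 = torch.eye(n_nodes).to_sparse()
incidence_2_t = torch.rand(n_faces, n_edges).to_sparse()

model = CCXN(
    in_channels_0=2,
    in_channels_1=2,
    in_channels_2=2,
    n_layers=2,
    att=False,
)
x_0_out, x_1_out, x_2_out = model(x_0, x_1, adjacency_0, incidence_2_t)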
18 changes: 8 additions & 10 deletions topomodelx/nn/cell/ccxn_layer.py
@@ -25,10 +25,8 @@ class CCXNLayer(torch.nn.Module):
Dimension of input features on faces (2-cells).
att : bool, default=False
Whether to use attention.
Notes
-----
This is the architecture proposed for entire complex classification.
**kwargs : optional
Additional arguments for the modules of the CCXN layer.
References
----------
@@ -45,7 +43,7 @@ class CCXNLayer(torch.nn.Module):
"""

def __init__(
self, in_channels_0, in_channels_1, in_channels_2, att: bool = False
self, in_channels_0, in_channels_1, in_channels_2, att: bool = False, **kwargs
) -> None:
super().__init__()
self.conv_0_to_0 = Conv(
@@ -55,7 +53,7 @@ def __init__(
in_channels=in_channels_1, out_channels=in_channels_2, att=att
)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None):
def forward(self, x_0, x_1, adjacency_0, incidence_2_t, x_2=None):
r"""Forward pass.
The forward pass was initially proposed in [1]_.
@@ -97,9 +95,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None):
Input features on the nodes of the cell complex.
x_1 : torch.Tensor, shape = (n_1_cells, channels)
Input features on the edges of the cell complex.
neighborhood_0_to_0 : torch.sparse, shape = (n_0_cells, n_0_cells)
adjacency_0 : torch.sparse, shape = (n_0_cells, n_0_cells)
Neighborhood matrix mapping nodes to nodes (A_0_up).
neighborhood_1_to_2 : torch.sparse, shape = (n_2_cells, n_1_cells)
incidence_2_t : torch.sparse, shape = (n_2_cells, n_1_cells)
Neighborhood matrix mapping edges to faces (B_2^T).
x_2 : torch.Tensor, shape = (n_2_cells, channels)
Input features on the faces of the cell complex.
@@ -113,10 +111,10 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None):
x_0 = torch.nn.functional.relu(x_0)
x_1 = torch.nn.functional.relu(x_1)

x_0 = self.conv_0_to_0(x_0, neighborhood_0_to_0)
x_0 = self.conv_0_to_0(x_0, adjacency_0)
x_0 = torch.nn.functional.relu(x_0)

x_2 = self.conv_1_to_2(x_1, neighborhood_1_to_2, x_2)
x_2 = self.conv_1_to_2(x_1, incidence_2_t, x_2)
x_2 = torch.nn.functional.relu(x_2)

return x_0, x_1, x_2
