From aa37d6faf00ff0cec8b94e845096a8608bdeb102 Mon Sep 17 00:00:00 2001 From: gbg141 Date: Thu, 4 Apr 2024 19:54:38 +0200 Subject: [PATCH 1/4] CAN updated --- topomodelx/nn/cell/can.py | 46 ++++++---- topomodelx/nn/cell/can_layer.py | 53 ++++++------ tutorials/cell/can_train.ipynb | 144 +++++++++++++++++--------------- 3 files changed, 131 insertions(+), 112 deletions(-) diff --git a/topomodelx/nn/cell/can.py b/topomodelx/nn/cell/can.py index 582b2aab..a9e80a3e 100644 --- a/topomodelx/nn/cell/can.py +++ b/topomodelx/nn/cell/can.py @@ -31,8 +31,12 @@ class CAN(torch.nn.Module): Number of CAN layers. att_lift : bool, default=True Whether to apply a lift the signal from node-level to edge-level input. + pooling : bool, default=False + Whether to apply pooling operation. k_pool : float, default=0.5 The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation. + **kwargs : optional + Additional arguments CANLayer. References ---------- @@ -54,7 +58,9 @@ def __init__( att_activation=None, n_layers=2, att_lift=True, + pooling=False, k_pool=0.5, + **kwargs, ): super().__init__() @@ -81,6 +87,7 @@ def __init__( att_activation=att_activation, aggr_func="sum", update_func="relu", + **kwargs, ) ) @@ -96,23 +103,23 @@ def __init__( att_activation=att_activation, aggr_func="sum", update_func="relu", + **kwargs, ) ) - - layers.append( - PoolLayer( - k_pool=k_pool, - in_channels_0=out_channels * heads, - signal_pool_activation=torch.nn.Sigmoid(), - readout=True, + if pooling: + layers.append( + PoolLayer( + k_pool=k_pool, + in_channels_0=out_channels * heads, + signal_pool_activation=torch.nn.Sigmoid(), + readout=True, + **kwargs, + ) ) - ) self.layers = torch.nn.ModuleList(layers) - def forward( - self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood - ): + def forward(self, x_0, x_1, adjacency_0, down_laplacian_1, up_laplacian_1): """Forward pass. Parameters @@ -121,11 +128,11 @@ def forward( Input features on the nodes (0-cells). x_1 : torch.Tensor, shape = (n_edges, in_channels_1) Input features on the edges (1-cells). - neighborhood_0_to_0 : torch.Tensor, shape = (n_nodes, n_nodes) + adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes) Neighborhood matrix from nodes to nodes. - lower_neighborhood : torch.Tensor, shape = (-, -) + down_laplacian_1 : torch.Tensor, shape = (-, -) Lower Neighbourhood matrix. - upper_neighborhood : torch.Tensor, shape = (-, -) + up_laplacian_1 : torch.Tensor, shape = (-, -) Upper neighbourhood matrix. Returns @@ -133,16 +140,19 @@ def forward( torch.Tensor, shape = (num_pooled_edges, heads * out_channels) Final hidden representations of pooled edges. """ + adjacency_0 = adjacency_0.coalesce() + down_laplacian_1 = down_laplacian_1.coalesce() + up_laplacian_1 = up_laplacian_1.coalesce() if hasattr(self, "lift_layer"): - x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1) + x_1 = self.lift_layer(x_0, adjacency_0.coalesce(), x_1) for layer in self.layers: if isinstance(layer, PoolLayer): - x_1, lower_neighborhood, upper_neighborhood = layer( - x_1, lower_neighborhood, upper_neighborhood + x_1, down_laplacian_1, up_laplacian_1 = layer( + x_1, down_laplacian_1, up_laplacian_1 ) else: - x_1 = layer(x_1, lower_neighborhood, upper_neighborhood) + x_1 = layer(x_1, down_laplacian_1, up_laplacian_1) x_1 = F.dropout(x_1, p=0.5, training=self.training) return x_1 diff --git a/topomodelx/nn/cell/can_layer.py b/topomodelx/nn/cell/can_layer.py index ab3f59d6..912a0298 100644 --- a/topomodelx/nn/cell/can_layer.py +++ b/topomodelx/nn/cell/can_layer.py @@ -148,14 +148,14 @@ def message(self, x_source, x_target=None): ) # (num_edges, heads) return self.signal_lift_activation(edge_signal) - def forward(self, x_0, neighborhood_0_to_0) -> torch.Tensor: # type: ignore[override] + def forward(self, x_0, adjacency_0) -> torch.Tensor: # type: ignore[override] """Forward pass. Parameters ---------- x_0 : torch.Tensor, shape = (num_nodes, in_channels_0) Node signal. - neighborhood_0_to_0 : torch.Tensor, shape = (num_nodes, num_nodes) + adjacency_0 : torch.Tensor, shape = (num_nodes, num_nodes) Sparse neighborhood matrix. Returns @@ -164,7 +164,7 @@ def forward(self, x_0, neighborhood_0_to_0) -> torch.Tensor: # type: ignore[ove Edge signal. """ # Extract source and target nodes from the graph's edge index - source, target = neighborhood_0_to_0.indices() # (num_edges,) + source, target = adjacency_0.indices() # (num_edges,) # Extract the node signal of the source and target nodes x_source = x_0[source] # (num_edges, in_channels_0) @@ -228,14 +228,14 @@ def reset_parameters(self) -> None: """Reinitialize learnable parameters using Xavier uniform initialization.""" self.lifts.reset_parameters() - def forward(self, x_0, neighborhood_0_to_0, x_1=None) -> torch.Tensor: + def forward(self, x_0, adjacency_0, x_1=None) -> torch.Tensor: r"""Forward pass. Parameters ---------- x_0 : torch.Tensor, shape = (num_nodes, in_channels_0) Node signal. - neighborhood_0_to_0 : torch.Tensor, shape = (2, num_edges) + adjacency_0 : torch.Tensor, shape = (2, num_edges) Edge index. x_1 : torch.Tensor, shape = (num_edges, in_channels_1), optional Edge signal. @@ -256,7 +256,7 @@ def forward(self, x_0, neighborhood_0_to_0, x_1=None) -> torch.Tensor: \end{align*} """ # Lift the node signal for each attention head - attention_heads_x_1 = self.lifts(x_0, neighborhood_0_to_0) + attention_heads_x_1 = self.lifts(x_0, adjacency_0) # Combine the output edge signals using the specified readout strategy readout_methods = { @@ -323,7 +323,7 @@ def reset_parameters(self) -> None: init.xavier_uniform_(self.att_pool.data, gain=gain) def forward( # type: ignore[override] - self, x, lower_neighborhood, upper_neighborhood + self, x, down_laplacian_1, up_laplacian_1 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r"""Forward pass. @@ -331,9 +331,9 @@ def forward( # type: ignore[override] ---------- x : torch.Tensor, shape = (n_r_cells, in_channels_r) Input r-cell signal. - lower_neighborhood : torch.Tensor + down_laplacian_1 : torch.Tensor Lower neighborhood matrix. - upper_neighborhood : torch.Tensor + up_laplacian_1 : torch.Tensor Upper neighbourhood matrix. Returns @@ -364,23 +364,19 @@ def forward( # type: ignore[override] out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices] # Update lower and upper neighborhood matrices with the top-k pooled r-cells - lower_neighborhood_modified = torch.index_select( - lower_neighborhood, 0, top_indices + down_laplacian_1_modified = torch.index_select(down_laplacian_1, 0, top_indices) + down_laplacian_1_modified = torch.index_select( + down_laplacian_1_modified, 1, top_indices ) - lower_neighborhood_modified = torch.index_select( - lower_neighborhood_modified, 1, top_indices - ) - upper_neighborhood_modified = torch.index_select( - upper_neighborhood, 0, top_indices - ) - upper_neighborhood_modified = torch.index_select( - upper_neighborhood_modified, 1, top_indices + up_laplacian_1_modified = torch.index_select(up_laplacian_1, 0, top_indices) + up_laplacian_1_modified = torch.index_select( + up_laplacian_1_modified, 1, top_indices ) # return sparse matrices of neighborhood return ( out, - lower_neighborhood_modified.to_sparse().float().coalesce(), - upper_neighborhood_modified.to_sparse().float().coalesce(), + down_laplacian_1_modified.to_sparse().float().coalesce(), + up_laplacian_1_modified.to_sparse().float().coalesce(), ) @@ -805,6 +801,8 @@ class CANLayer(torch.nn.Module): Version of the layer, by default "v1" which is the same as the original CAN layer. While "v2" has the same attetion mechanism as the GATv2 layer. share_weights : bool, default=False This option is valid only for "v2". If True, the weights of the linear transformation applied to the source and target features are shared, by default False. + **kwargs : optional + Additional arguments of CAN layer. Notes ----- @@ -823,11 +821,12 @@ def __init__( concat: bool = True, skip_connection: bool = True, att_activation: torch.nn.Module | None = None, - add_self_loops: bool = False, + add_self_loops: bool = True, aggr_func: Literal["mean", "sum"] = "sum", update_func: Literal["relu", "sigmoid", "tanh"] | None = "relu", version: Literal["v1", "v2"] = "v1", share_weights: bool = False, + **kwargs, ) -> None: super().__init__() @@ -910,16 +909,16 @@ def reset_parameters(self) -> None: if hasattr(self, "lin"): self.lin.reset_parameters() - def forward(self, x, lower_neighborhood, upper_neighborhood) -> torch.Tensor: + def forward(self, x, down_laplacian_1, up_laplacian_1) -> torch.Tensor: r"""Forward pass. Parameters ---------- x : torch.Tensor, shape = (n_k_cells, channels) Input features on the r-cell of the cell complex. - lower_neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells) + down_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells) Lower neighborhood matrix mapping r-cells to r-cells (A_k_low). - upper_neighborhood : torch.sparse, shape = (n_k_cells, n_k_cells) + up_laplacian_1 : torch.sparse, shape = (n_k_cells, n_k_cells) Upper neighborhood matrix mapping r-cells to r-cells (A_k_up). Returns @@ -945,8 +944,8 @@ def forward(self, x, lower_neighborhood, upper_neighborhood) -> torch.Tensor: \end{align*} """ # message and within-neighborhood aggregation - lower_x = self.lower_att(x, lower_neighborhood) - upper_x = self.upper_att(x, upper_neighborhood) + lower_x = self.lower_att(x, down_laplacian_1) + upper_x = self.upper_att(x, up_laplacian_1) # skip connection if hasattr(self, "lin"): diff --git a/tutorials/cell/can_train.ipynb b/tutorials/cell/can_train.ipynb index b50658bc..122dd622 100644 --- a/tutorials/cell/can_train.ipynb +++ b/tutorials/cell/can_train.ipynb @@ -104,14 +104,25 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:06:36.009880829Z", "start_time": "2023-05-31T09:06:34.285257706Z" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import numpy as np\n", "import torch\n", @@ -135,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:53.006542411Z", @@ -175,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.279147916Z", @@ -183,6 +194,15 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip\n", + "Processing...\n", + "Done!\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -233,7 +253,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.832585216Z", @@ -242,8 +262,8 @@ }, "outputs": [], "source": [ - "lower_neighborhood_list = []\n", - "upper_neighborhood_list = []\n", + "down_laplacian_list = []\n", + "up_laplacian_list = []\n", "adjacency_0_list = []\n", "\n", "for cell_complex in cc_list:\n", @@ -251,20 +271,20 @@ " adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()\n", " adjacency_0_list.append(adjacency_0)\n", "\n", - " lower_neighborhood_t = cell_complex.down_laplacian_matrix(rank=1)\n", - " lower_neighborhood_t = from_sparse(lower_neighborhood_t)\n", - " lower_neighborhood_list.append(lower_neighborhood_t)\n", + " down_laplacian_t = cell_complex.down_laplacian_matrix(rank=1)\n", + " down_laplacian_t = from_sparse(down_laplacian_t)\n", + " down_laplacian_list.append(down_laplacian_t)\n", "\n", " try:\n", - " upper_neighborhood_t = cell_complex.up_laplacian_matrix(rank=1)\n", - " upper_neighborhood_t = from_sparse(upper_neighborhood_t)\n", + " up_laplacian_t = cell_complex.up_laplacian_matrix(rank=1)\n", + " up_laplacian_t = from_sparse(up_laplacian_t)\n", " except ValueError:\n", - " upper_neighborhood_t = np.zeros(\n", - " (lower_neighborhood_t.shape[0], lower_neighborhood_t.shape[0])\n", + " up_laplacian_t = np.zeros(\n", + " (down_laplacian_t.shape[0], down_laplacian_t.shape[0])\n", " )\n", - " upper_neighborhood_t = torch.from_numpy(upper_neighborhood_t).to_sparse()\n", + " up_laplacian_t = torch.from_numpy(up_laplacian_t).to_sparse()\n", "\n", - " upper_neighborhood_list.append(upper_neighborhood_t)" + " up_laplacian_list.append(up_laplacian_t)" ] }, { @@ -278,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -307,8 +327,8 @@ " self.lin_0 = torch.nn.Linear(heads * out_channels, 128)\n", " self.lin_1 = torch.nn.Linear(128, num_classes)\n", "\n", - " def forward(self, x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood):\n", - " x = self.base_model(x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood)\n", + " def forward(self, x_0, x_1, adjacency, down_laplacian, up_laplacian):\n", + " x = self.base_model(x_0, x_1, adjacency, down_laplacian, up_laplacian)\n", " # max pooling over edges in each graph\n", " x = x.max(dim=0)[0]\n", " # Feed-Foward Neural Network to predict the graph label\n", @@ -318,7 +338,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:56.672913561Z", @@ -358,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:40.411845803Z", @@ -399,9 +419,6 @@ " (lin): Linear(in_features=64, out_features=64, bias=False)\n", " (aggregation): Aggregation()\n", " )\n", - " (2): PoolLayer(\n", - " (signal_pool_activation): Sigmoid()\n", - " )\n", " )\n", " )\n", " (lin_0): Linear(in_features=64, out_features=128, bias=True)\n", @@ -409,7 +426,7 @@ ")" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -429,7 +446,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:41.150933630Z", @@ -441,11 +458,11 @@ "test_size = 0.3\n", "x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)\n", "x_0_train, x_0_test = train_test_split(x_0_list, test_size=test_size, shuffle=False)\n", - "lower_neighborhood_train, lower_neighborhood_test = train_test_split(\n", - " lower_neighborhood_list, test_size=test_size, shuffle=False\n", + "down_laplacian_train, down_laplacian_test = train_test_split(\n", + " down_laplacian_list, test_size=test_size, shuffle=False\n", ")\n", - "upper_neighborhood_train, upper_neighborhood_test = train_test_split(\n", - " upper_neighborhood_list, test_size=test_size, shuffle=False\n", + "up_laplacian_train, up_laplacian_test = train_test_split(\n", + " up_laplacian_list, test_size=test_size, shuffle=False\n", ")\n", "adjacency_0_train, adjacency_0_test = train_test_split(\n", " adjacency_0_list, test_size=test_size, shuffle=False\n", @@ -462,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:42.918836083Z", @@ -474,26 +491,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 0.6200 Train_acc: 0.6947\n", + "Epoch: 1 loss: 0.6159 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 2 loss: 0.6110 Train_acc: 0.6947\n", + "Epoch: 2 loss: 0.6099 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 3 loss: 0.6054 Train_acc: 0.6947\n", + "Epoch: 3 loss: 0.6035 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 4 loss: 0.5990 Train_acc: 0.6947\n", + "Epoch: 4 loss: 0.5966 Train_acc: 0.7176\n", "Test_acc: 0.6316\n", - "Epoch: 5 loss: 0.6021 Train_acc: 0.7099\n", + "Epoch: 5 loss: 0.5909 Train_acc: 0.7252\n", + "Test_acc: 0.6491\n", + "Epoch: 6 loss: 0.5983 Train_acc: 0.7099\n", "Test_acc: 0.6316\n", - "Epoch: 6 loss: 0.5911 Train_acc: 0.7252\n", + "Epoch: 7 loss: 0.5884 Train_acc: 0.7252\n", + "Test_acc: 0.6491\n", + "Epoch: 8 loss: 0.5909 Train_acc: 0.7176\n", "Test_acc: 0.6316\n", - "Epoch: 7 loss: 0.5889 Train_acc: 0.7176\n", + "Epoch: 9 loss: 0.5818 Train_acc: 0.7252\n", "Test_acc: 0.6316\n", - "Epoch: 8 loss: 0.5829 Train_acc: 0.7252\n", - "Test_acc: 0.6842\n", - "Epoch: 9 loss: 0.5786 Train_acc: 0.7252\n", - "Test_acc: 0.6491\n", - "Epoch: 10 loss: 0.5746 Train_acc: 0.7328\n", - "Test_acc: 0.6842\n" + "Epoch: 10 loss: 0.5879 Train_acc: 0.7252\n", + "Test_acc: 0.6316\n" ] } ], @@ -505,24 +522,24 @@ " num_samples = 0\n", " correct = 0\n", " model.train()\n", - " for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(\n", + " for x_0, x_1, adjacency, down_laplacian, up_laplacian, y in zip(\n", " x_0_train,\n", " x_1_train,\n", " adjacency_0_train,\n", - " lower_neighborhood_train,\n", - " upper_neighborhood_train,\n", + " down_laplacian_train,\n", + " up_laplacian_train,\n", " y_train,\n", " strict=True,\n", " ):\n", " x_0 = x_0.float().to(device)\n", " x_1, y = x_1.float().to(device), torch.tensor(y, dtype=torch.long).to(device)\n", " adjacency = adjacency.float().to(device)\n", - " lower_neighborhood, upper_neighborhood = (\n", - " lower_neighborhood.float().to(device),\n", - " upper_neighborhood.float().to(device),\n", + " down_laplacian, up_laplacian = (\n", + " down_laplacian.float().to(device),\n", + " up_laplacian.float().to(device),\n", " )\n", " opt.zero_grad()\n", - " y_hat = model(x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood)\n", + " y_hat = model(x_0, x_1, adjacency, down_laplacian, up_laplacian)\n", " loss = crit(y_hat, y)\n", " correct += (y_hat.argmax() == y).sum().item()\n", " num_samples += 1\n", @@ -538,12 +555,12 @@ " with torch.no_grad():\n", " num_samples = 0\n", " correct = 0\n", - " for x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood, y in zip(\n", + " for x_0, x_1, adjacency, down_laplacian, up_laplacian, y in zip(\n", " x_0_test,\n", " x_1_test,\n", " adjacency_0_test,\n", - " lower_neighborhood_test,\n", - " upper_neighborhood_test,\n", + " down_laplacian_test,\n", + " up_laplacian_test,\n", " y_test,\n", " strict=True,\n", " ):\n", @@ -553,13 +570,11 @@ " torch.tensor(y, dtype=torch.long).to(device),\n", " )\n", " adjacency = adjacency.float().to(device)\n", - " lower_neighborhood, upper_neighborhood = (\n", - " lower_neighborhood.float().to(device),\n", - " upper_neighborhood.float().to(device),\n", - " )\n", - " y_hat = model(\n", - " x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood\n", + " down_laplacian, up_laplacian = (\n", + " down_laplacian.float().to(device),\n", + " up_laplacian.float().to(device),\n", " )\n", + " y_hat = model(x_0, x_1, adjacency, down_laplacian, up_laplacian)\n", " correct += (y_hat.argmax() == y).sum().item()\n", " num_samples += 1\n", " test_acc = correct / num_samples\n", @@ -576,9 +591,9 @@ ], "metadata": { "kernelspec": { - "display_name": "venv_modelx", + "display_name": "venv_tmx", "language": "python", - "name": "venv_modelx" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -591,11 +606,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" - }, - "vscode": { - "interpreter": { - "hash": "97e7f600578393f7b22fad5e1bb04e54aa849deabd28651fd7e27af1b0c8a034" - } } }, "nbformat": 4, From 8d1c1d4989529fc1f925fcef7fb6e39297938b7c Mon Sep 17 00:00:00 2001 From: gbg141 Date: Thu, 4 Apr 2024 20:19:21 +0200 Subject: [PATCH 2/4] CWN Update --- topomodelx/nn/cell/cwn.py | 22 ++++++----- topomodelx/nn/cell/cwn_layer.py | 68 ++++++++++++++++++++++----------- tutorials/cell/cwn_train.ipynb | 38 +++++++++++------- 3 files changed, 83 insertions(+), 45 deletions(-) diff --git a/topomodelx/nn/cell/cwn.py b/topomodelx/nn/cell/cwn.py index 71bb193e..319d1325 100644 --- a/topomodelx/nn/cell/cwn.py +++ b/topomodelx/nn/cell/cwn.py @@ -21,6 +21,8 @@ class CWN(torch.nn.Module): Dimension of hidden features. n_layers : int Number of CWN layers. + **kwargs : optional + Additional arguments CWNLayer. References ---------- @@ -37,6 +39,7 @@ def __init__( in_channels_2, hid_channels, n_layers, + **kwargs, ): super().__init__() @@ -50,6 +53,7 @@ def __init__( in_channels_1=hid_channels, in_channels_2=hid_channels, out_channels=hid_channels, + **kwargs, ) for _ in range(n_layers) ) @@ -59,9 +63,9 @@ def forward( x_0, x_1, x_2, - neighborhood_1_to_1, - neighborhood_2_to_1, - neighborhood_0_to_1, + adjacency_0, + incidence_2, + incidence_1_t, ): """Forward computation through projection, convolutions, linear layers and average pooling. @@ -73,11 +77,11 @@ def forward( Input features on the edges (1-cells). x_2 : torch.Tensor, shape = (n_faces, in_channels_2) Input features on the faces (2-cells). - neighborhood_1_to_1 : torch.Tensor, shape = (n_edges, n_edges) + adjacency_0 : torch.Tensor, shape = (n_edges, n_edges) Upper-adjacency matrix of rank 1. - neighborhood_2_to_1 : torch.Tensor, shape = (n_edges, n_faces) + incidence_2 : torch.Tensor, shape = (n_edges, n_faces) Boundary matrix of rank 2. - neighborhood_0_to_1 : torch.Tensor, shape = (n_edges, n_nodes) + incidence_1_t : torch.Tensor, shape = (n_edges, n_nodes) Coboundary matrix of rank 1. Returns @@ -98,9 +102,9 @@ def forward( x_0, x_1, x_2, - neighborhood_1_to_1, - neighborhood_2_to_1, - neighborhood_0_to_1, + adjacency_0, + incidence_2, + incidence_1_t, ) return x_0, x_1, x_2 diff --git a/topomodelx/nn/cell/cwn_layer.py b/topomodelx/nn/cell/cwn_layer.py index 36b983d0..fb32d1f8 100644 --- a/topomodelx/nn/cell/cwn_layer.py +++ b/topomodelx/nn/cell/cwn_layer.py @@ -56,9 +56,8 @@ class CWNLayer(nn.Module): If None is passed, a default implementation of this module is used (check the docstring of _CWNDefaultUpdate for more detail). - Notes - ----- - This is the architecture proposed for entire complex classification. + **kwargs : optional + Additional arguments for the modules of the CWN layer. References ---------- @@ -78,6 +77,7 @@ def __init__( conv_0_to_1=None, aggregate_fn=None, update_fn=None, + **kwargs, ) -> None: super().__init__() self.conv_1_to_1 = ( @@ -104,9 +104,9 @@ def forward( x_0, x_1, x_2, - neighborhood_1_to_1, - neighborhood_2_to_1, - neighborhood_0_to_1, + adjacency_0, + incidence_2, + incidence_1_t, ): r"""Forward pass. @@ -159,11 +159,11 @@ def forward( Input features on the r-cells. x_2 : torch.Tensor, shape = (n_{r+1}_cells, in_channels_{r+1}) Input features on the (r+1)-cells. - neighborhood_1_to_1 : torch.sparse, shape = (n_{r}_cells, n_{r}_cells) + adjacency_0 : torch.sparse, shape = (n_{r}_cells, n_{r}_cells) Neighborhood matrix mapping r-cells to r-cells (A_{up,r}). - neighborhood_2_to_1 : torch.sparse, shape = (n_{r}_cells, n_{r+1}_cells) + incidence_2 : torch.sparse, shape = (n_{r}_cells, n_{r+1}_cells) Neighborhood matrix mapping (r+1)-cells to r-cells (B_{r+1}). - neighborhood_0_to_1 : torch.sparse, shape = (n_{r}_cells, n_{r-1}_cells) + incidence_1_t : torch.sparse, shape = (n_{r}_cells, n_{r-1}_cells) Neighborhood matrix mapping (r-1)-cells to r-cells (B^T_r). Returns @@ -180,10 +180,8 @@ def forward( Architectures of topological deep learning: a survey on topological neural networks (2023). https://arxiv.org/abs/2304.10031. """ - x_convolved_1_to_1 = self.conv_1_to_1( - x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1 - ) - x_convolved_0_to_1 = self.conv_0_to_1(x_0, x_1, neighborhood_0_to_1) + x_convolved_1_to_1 = self.conv_1_to_1(x_1, x_2, adjacency_0, incidence_2) + x_convolved_0_to_1 = self.conv_0_to_1(x_0, x_1, incidence_1_t) x_aggregated = self.aggregate_fn(x_convolved_1_to_1, x_convolved_0_to_1) return self.update_fn(x_aggregated, x_1) @@ -195,6 +193,15 @@ class _CWNDefaultFirstConv(nn.Module): The self.forward method of this module must be treated as a protocol for the first convolutional step in CWN layer. + + Parameters + ---------- + in_channels_1 : int + Dimension of input features on r-cells (edges in case r = 1). + in_channels_2 : int + Dimension of input features on (r+1)-cells (faces in case r = 1). + out_channels : int + Dimension of output features on r-cells. """ def __init__(self, in_channels_1, in_channels_2, out_channels) -> None: @@ -206,7 +213,7 @@ def __init__(self, in_channels_1, in_channels_2, out_channels) -> None: in_channels_2, out_channels, aggr_norm=False, update_func=None ) - def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1): + def forward(self, x_1, x_2, adjacency_0, incidence_2): r"""Forward pass. Parameters @@ -215,9 +222,9 @@ def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1): Input features on the (r-1)-cells. x_2 : torch.Tensor, shape = (n_{r}_cells, in_channels_{r}) Input features on the r-cells. - neighborhood_1_to_1 : torch.sparse, shape = (n_{r}_cells, n_{r}_cells) + adjacency_0 : torch.sparse, shape = (n_{r}_cells, n_{r}_cells) Neighborhood matrix mapping r-cells to r-cells (A_{up,r}). - neighborhood_2_to_1 : torch.sparse, shape = (n_{r}_cells, n_{r+1}_cells) + incidence_2 : torch.sparse, shape = (n_{r}_cells, n_{r+1}_cells) Neighborhood matrix mapping (r+1)-cells to r-cells (B_{r+1}). Returns @@ -225,8 +232,8 @@ def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1): torch.Tensor, shape = (n_{r}_cells, out_channels) Updated representations on the r-cells. """ - x_up = F.elu(self.conv_1_to_1(x_1, neighborhood_1_to_1)) - x_coboundary = F.elu(self.conv_2_to_1(x_2, neighborhood_2_to_1)) + x_up = F.elu(self.conv_1_to_1(x_1, adjacency_0)) + x_coboundary = F.elu(self.conv_2_to_1(x_2, incidence_2)) return x_up + x_coboundary @@ -236,6 +243,15 @@ class _CWNDefaultSecondConv(nn.Module): The self.forward method of this module must be treated as a protocol for the second convolutional step in CWN layer. + + Parameters + ---------- + in_channels_0 : int + Dimension of input features on (r-1)-cells (nodes in case r = 1). + in_channels_1 : int + Dimension of input features on r-cells (edges in case r = 1). + out_channels : int + Dimension of output features on r-cells. """ def __init__(self, in_channels_0, in_channels_1, out_channels) -> None: @@ -244,7 +260,7 @@ def __init__(self, in_channels_0, in_channels_1, out_channels) -> None: in_channels_0, out_channels, aggr_norm=False, update_func=None ) - def forward(self, x_0, x_1, neighborhood_0_to_1): + def forward(self, x_0, x_1, incidence_1_t): r"""Forward pass. Parameters @@ -253,7 +269,7 @@ def forward(self, x_0, x_1, neighborhood_0_to_1): Input features on the (r-1)-cells. x_1 : torch.Tensor, shape = (n_{r}_cells, in_channels_{r}) Input features on the r-cells. - neighborhood_0_to_1 : torch.sparse, shape = (n_{r}_cells, n_{r-1}_cells) + incidence_1_t : torch.sparse, shape = (n_{r}_cells, n_{r-1}_cells) Neighborhood matrix mapping (r-1)-cells to r-cells (B^T_r). Returns @@ -261,7 +277,7 @@ def forward(self, x_0, x_1, neighborhood_0_to_1): torch.Tensor, shape = (n_{r}_cells, out_channels) Updated representations on the r-cells. """ - return F.elu(self.conv_0_to_1(x_0, neighborhood_0_to_1)) + return F.elu(self.conv_0_to_1(x_0, incidence_1_t)) class _CWNDefaultAggregate(nn.Module): @@ -294,7 +310,15 @@ def forward(self, x, y): class _CWNDefaultUpdate(nn.Module): - r"""Default implementation of an update step in CWNLayer.""" + r"""Default implementation of an update step in CWNLayer. + + Parameters + ---------- + in_channels : int + Dimension of input features. + out_channels : int + Dimension of output features. + """ def __init__(self, in_channels, out_channels) -> None: super().__init__() diff --git a/tutorials/cell/cwn_train.ipynb b/tutorials/cell/cwn_train.ipynb index ac569b87..2eb9b991 100644 --- a/tutorials/cell/cwn_train.ipynb +++ b/tutorials/cell/cwn_train.ipynb @@ -52,7 +52,18 @@ }, "id": "h-kcMSPGNH1v" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import numpy as np\n", "import toponetx.datasets as datasets\n", @@ -318,12 +329,12 @@ " x_0,\n", " x_1,\n", " x_2,\n", - " neighborhood_1_to_1,\n", - " neighborhood_2_to_1,\n", - " neighborhood_0_to_1,\n", + " adjacency_1,\n", + " incidence_2,\n", + " incidence_1_t,\n", " ):\n", " x_0, x_1, x_2 = self.base_model(\n", - " x_0, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1, neighborhood_0_to_1\n", + " x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t\n", " )\n", " x_0 = self.lin_0(x_0)\n", " x_1 = self.lin_1(x_1)\n", @@ -492,13 +503,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch:2, Train Loss: 83.8053 Test Loss: 73.7517\n" + "Epoch:2, Train Loss: 83.8053 Test Loss: 73.7517\n", + "Epoch:4, Train Loss: 81.9551 Test Loss: 50.2781\n", + "Epoch:6, Train Loss: 78.3991 Test Loss: 49.9035\n", + "Epoch:8, Train Loss: 75.8110 Test Loss: 45.7197\n", + "Epoch:10, Train Loss: 74.3838 Test Loss: 40.5566\n" ] } ], "source": [ "test_interval = 2\n", - "num_epochs = 2\n", + "num_epochs = 10\n", "\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", @@ -577,9 +592,9 @@ "provenance": [] }, "kernelspec": { - "display_name": "venv_modelx", + "display_name": "venv_tmx", "language": "python", - "name": "venv_modelx" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -592,11 +607,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } } }, "nbformat": 4, From bf30e09568460afdaf2396917a59358fb55311cb Mon Sep 17 00:00:00 2001 From: gbg141 Date: Thu, 4 Apr 2024 20:29:05 +0200 Subject: [PATCH 3/4] CCXN Updated --- topomodelx/nn/cell/ccxn.py | 12 ++++++++---- topomodelx/nn/cell/ccxn_layer.py | 18 ++++++++---------- tutorials/cell/ccxn_train.ipynb | 8 +++----- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/topomodelx/nn/cell/ccxn.py b/topomodelx/nn/cell/ccxn.py index b6e3e042..9848df73 100644 --- a/topomodelx/nn/cell/ccxn.py +++ b/topomodelx/nn/cell/ccxn.py @@ -20,6 +20,8 @@ class CCXN(torch.nn.Module): Number of CCXN layers. att : bool Whether to use attention. + **kwargs : optional + Additional arguments CCXNLayer. References ---------- @@ -36,6 +38,7 @@ def __init__( in_channels_2, n_layers=2, att=False, + **kwargs, ): super().__init__() @@ -45,11 +48,12 @@ def __init__( in_channels_1=in_channels_1, in_channels_2=in_channels_2, att=att, + **kwargs, ) for _ in range(n_layers) ) - def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): + def forward(self, x_0, x_1, adjacency_0, incidence_2_t): """Forward computation through layers. Parameters @@ -58,9 +62,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): Input features on the nodes (0-cells). x_1 : torch.Tensor, shape = (n_edges, in_channels_1) Input features on the edges (1-cells). - neighborhood_0_to_0 : torch.Tensor, shape = (n_nodes, n_nodes) + adjacency_0 : torch.Tensor, shape = (n_nodes, n_nodes) Adjacency matrix of rank 0 (up). - neighborhood_1_to_2 : torch.Tensor, shape = (n_faces, n_edges) + incidence_2_t : torch.Tensor, shape = (n_faces, n_edges) Transpose of boundary matrix of rank 2. Returns @@ -73,5 +77,5 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): Final hidden states of the faces (2-cells). """ for layer in self.layers: - x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2) + x_0, x_1, x_2 = layer(x_0, x_1, adjacency_0, incidence_2_t) return (x_0, x_1, x_2) diff --git a/topomodelx/nn/cell/ccxn_layer.py b/topomodelx/nn/cell/ccxn_layer.py index 248fd678..e596998b 100644 --- a/topomodelx/nn/cell/ccxn_layer.py +++ b/topomodelx/nn/cell/ccxn_layer.py @@ -25,10 +25,8 @@ class CCXNLayer(torch.nn.Module): Dimension of input features on faces (2-cells). att : bool, default=False Whether to use attention. - - Notes - ----- - This is the architecture proposed for entire complex classification. + **kwargs : optional + Additional arguments for the modules of the CCXN layer. References ---------- @@ -45,7 +43,7 @@ class CCXNLayer(torch.nn.Module): """ def __init__( - self, in_channels_0, in_channels_1, in_channels_2, att: bool = False + self, in_channels_0, in_channels_1, in_channels_2, att: bool = False, **kwargs ) -> None: super().__init__() self.conv_0_to_0 = Conv( @@ -55,7 +53,7 @@ def __init__( in_channels=in_channels_1, out_channels=in_channels_2, att=att ) - def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None): + def forward(self, x_0, x_1, adjacency_0, incidence_2_t, x_2=None): r"""Forward pass. The forward pass was initially proposed in [1]_. @@ -97,9 +95,9 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None): Input features on the nodes of the cell complex. x_1 : torch.Tensor, shape = (n_1_cells, channels) Input features on the edges of the cell complex. - neighborhood_0_to_0 : torch.sparse, shape = (n_0_cells, n_0_cells) + adjacency_0 : torch.sparse, shape = (n_0_cells, n_0_cells) Neighborhood matrix mapping nodes to nodes (A_0_up). - neighborhood_1_to_2 : torch.sparse, shape = (n_2_cells, n_1_cells) + incidence_2_t : torch.sparse, shape = (n_2_cells, n_1_cells) Neighborhood matrix mapping edges to faces (B_2^T). x_2 : torch.Tensor, shape = (n_2_cells, channels) Input features on the faces of the cell complex. @@ -113,10 +111,10 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2, x_2=None): x_0 = torch.nn.functional.relu(x_0) x_1 = torch.nn.functional.relu(x_1) - x_0 = self.conv_0_to_0(x_0, neighborhood_0_to_0) + x_0 = self.conv_0_to_0(x_0, adjacency_0) x_0 = torch.nn.functional.relu(x_0) - x_2 = self.conv_1_to_2(x_1, neighborhood_1_to_2, x_2) + x_2 = self.conv_1_to_2(x_1, incidence_2_t, x_2) x_2 = torch.nn.functional.relu(x_2) return x_0, x_1, x_2 diff --git a/tutorials/cell/ccxn_train.ipynb b/tutorials/cell/ccxn_train.ipynb index dc86c739..a58c1d41 100644 --- a/tutorials/cell/ccxn_train.ipynb +++ b/tutorials/cell/ccxn_train.ipynb @@ -269,10 +269,8 @@ " self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)\n", " self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)\n", "\n", - " def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):\n", - " x_0, x_1, x_2 = self.base_model(\n", - " x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2\n", - " )\n", + " def forward(self, x_0, x_1, adjacency_0, incidence_2_t):\n", + " x_0, x_1, x_2 = self.base_model(x_0, x_1, adjacency_0, incidence_2_t)\n", " x_0 = self.lin_0(x_0)\n", " x_1 = self.lin_1(x_1)\n", " x_2 = self.lin_2(x_2)\n", @@ -436,7 +434,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", + "/Users/gbg141/Documents/Projects/TopoModelX/venv_tmx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", " return F.mse_loss(input, target, reduction=self.reduction)\n" ] }, From db96096330aed5b268c5298c593b81315b6c7df8 Mon Sep 17 00:00:00 2001 From: gbg141 Date: Thu, 4 Apr 2024 21:02:29 +0200 Subject: [PATCH 4/4] Solving test CAN (looping) --- test/nn/cell/test_can.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/nn/cell/test_can.py b/test/nn/cell/test_can.py index 581e20a8..43534402 100644 --- a/test/nn/cell/test_can.py +++ b/test/nn/cell/test_can.py @@ -20,6 +20,7 @@ def test_forward(self): heads=1, n_layers=2, att_lift=False, + pooling=True, ).to(device) x_0 = torch.rand(2, 2)