From ebe1e2ae8383bc46ce76a960eb2cf05a372fb1a3 Mon Sep 17 00:00:00 2001 From: Nina Miolane Date: Wed, 20 Sep 2023 05:41:00 -0700 Subject: [PATCH] Remove template layer everywhere: code, tests + docs --- docs/api/nn/hypergraph/index.rst | 1 - docs/api/nn/hypergraph/template_layer.rst | 6 - docs/tutorials/index.rst | 18 +- test/nn/hypergraph/test_template_layer.py | 31 -- topomodelx/nn/hypergraph/template_layer.py | 71 --- tutorials/hypergraph/template_train.ipynb | 494 --------------------- 6 files changed, 9 insertions(+), 612 deletions(-) delete mode 100644 docs/api/nn/hypergraph/template_layer.rst delete mode 100644 test/nn/hypergraph/test_template_layer.py delete mode 100644 topomodelx/nn/hypergraph/template_layer.py delete mode 100644 tutorials/hypergraph/template_train.ipynb diff --git a/docs/api/nn/hypergraph/index.rst b/docs/api/nn/hypergraph/index.rst index 14f09431d..a5bdf282e 100644 --- a/docs/api/nn/hypergraph/index.rst +++ b/docs/api/nn/hypergraph/index.rst @@ -23,7 +23,6 @@ The Base class is composed of primarily these classes: hypergat hypersage_layer hypersage - template_layer unigcn_layer unigcn unigcnii_layer diff --git a/docs/api/nn/hypergraph/template_layer.rst b/docs/api/nn/hypergraph/template_layer.rst deleted file mode 100644 index 43c05a385..000000000 --- a/docs/api/nn/hypergraph/template_layer.rst +++ /dev/null @@ -1,6 +0,0 @@ -************** -Template_Layer -************** - -.. automodule:: topomodelx.nn.hypergraph.template_layer - :members: diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index c3845472d..bef58e0ce 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -4,9 +4,9 @@ Tutorials ========= ---------------------------------------------------------- -Tutorials for Topological Neural Nets on Cellular Complex ---------------------------------------------------------- +--------------------------------------------- +Topological Neural Nets on Cellular Complexes +--------------------------------------------- .. nbgallery:: :maxdepth: 1 @@ -15,9 +15,9 @@ Tutorials for Topological Neural Nets on Cellular Complex ../notebooks/cell/* ----------------------------------------------------- -Tutorials for Topological Neural Nets on Hypergraphs ----------------------------------------------------- +-------------------------------------- +Topological Neural Nets on Hypergraphs +-------------------------------------- .. nbgallery:: :maxdepth: 1 @@ -25,9 +25,9 @@ Tutorials for Topological Neural Nets on Hypergraphs ../notebooks/hypergraph/* ------------------------------------------------------------ -Tutorials for Topological Neural Nets on Simplicial Complex ------------------------------------------------------------ +----------------------------------------------- +Topological Neural Nets on Simplicial Complexes +----------------------------------------------- .. nbgallery:: :maxdepth: 1 diff --git a/test/nn/hypergraph/test_template_layer.py b/test/nn/hypergraph/test_template_layer.py deleted file mode 100644 index b25b78bfc..000000000 --- a/test/nn/hypergraph/test_template_layer.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Test the template layer.""" -import pytest -import torch - -from topomodelx.nn.hypergraph.template_layer import TemplateLayer - - -class TestTemplateLayer: - """Test the template layer.""" - - @pytest.fixture - def template_layer(self): - """Return a template layer.""" - in_channels = 10 - intermediate_channels = 20 - out_channels = 30 - return TemplateLayer(in_channels, intermediate_channels, out_channels) - - def test_forward(self, template_layer): - """Test the forward pass of the template layer.""" - x_2 = torch.randn(3, 10) - incidence_2 = torch.tensor([[1, 0, 1], [0, 1, 1]], dtype=torch.float32) - output = template_layer.forward(x_2, incidence_2) - assert output.shape == (3, 30) - - def test_forward_with_invalid_input(self, template_layer): - """Test the forward pass of the template layer with invalid input.""" - x_2 = torch.randn(4, 10) - incidence_2 = torch.tensor([[1, 0, 1], [0, 1, 1]], dtype=torch.float32) - with pytest.raises(ValueError): - template_layer.forward(x_2, incidence_2) diff --git a/topomodelx/nn/hypergraph/template_layer.py b/topomodelx/nn/hypergraph/template_layer.py deleted file mode 100644 index 2212c725d..000000000 --- a/topomodelx/nn/hypergraph/template_layer.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Template Layer with two conv passing steps.""" -import torch - -from topomodelx.base.conv import Conv - - -class TemplateLayer(torch.nn.Module): - """Template Layer with two conv passing steps. - - A two-step message passing layer. - - Parameters - ---------- - in_channels : int - Dimension of input features. - intermediate_channels : int - Dimension of intermediate features. - out_channels : int - Dimension of output features. - """ - - def __init__( - self, - in_channels, - intermediate_channels, - out_channels, - ) -> None: - super().__init__() - - self.conv_level1_1_to_0 = Conv( - in_channels=in_channels, - out_channels=intermediate_channels, - aggr_norm=True, - update_func="sigmoid", - ) - self.conv_level2_0_to_1 = Conv( - in_channels=intermediate_channels, - out_channels=out_channels, - aggr_norm=True, - update_func="sigmoid", - ) - - def reset_parameters(self) -> None: - r"""Reset learnable parameters.""" - self.conv_level1_1_to_0.reset_parameters() - self.conv_level2_0_to_1.reset_parameters() - - def forward(self, x_1, incidence_1): - r"""Forward computation. - - Parameters - ---------- - x_1 : torch.Tensor, shape=[n_edges, in_channels] - Input features on the edges of the simplicial complex. - incidence_1 : torch.sparse - shape=[n_nodes, n_edges] - Incidence matrix mapping edges to nodes (B_1). - - Returns - ------- - x_1 : torch.Tensor, shape=[n_edges, out_channels] - Output features on the edges of the simplicial complex. - """ - incidence_1_transpose = incidence_1.transpose(1, 0) - if x_1.shape[-2] != incidence_1.shape[-1]: - raise ValueError( - f"Shape of input face features does not have the correct number of edges {incidence_1.shape[-1]}." - ) - x_0 = self.conv_level1_1_to_0(x_1, incidence_1) - x_1 = self.conv_level2_0_to_1(x_0, incidence_1_transpose) - return x_1 diff --git a/tutorials/hypergraph/template_train.ipynb b/tutorials/hypergraph/template_train.ipynb deleted file mode 100644 index f0e13c0e2..000000000 --- a/tutorials/hypergraph/template_train.ipynb +++ /dev/null @@ -1,494 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train a (template) Hypergraph Neural Network\n", - "\n", - "In this notebook, we will create and train a two-step message passing network in the hypergraph domain. We will use a benchmark dataset, shrec16, a collection of 3D meshes, to train the model to perform classification at the level of the hypergraph. " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:51.222779223Z", - "start_time": "2023-06-01T16:14:49.575421023Z" - } - }, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np\n", - "from sklearn.model_selection import train_test_split\n", - "\n", - "import toponetx.datasets as datasets\n", - "from topomodelx.nn.hypergraph.template_layer import TemplateLayer" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If GPU's are available, we will make use of them. Otherwise, this will run on CPU." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:51.959770754Z", - "start_time": "2023-06-01T16:14:51.956096841Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cpu\n" - ] - } - ], - "source": [ - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(device)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pre-processing\n", - "\n", - "## Import data ##\n", - "\n", - "The first step is to import the dataset, shrec 16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a hypergraph.\n", - "\n", - "We will also retrieve:\n", - "- input signal on the edges for each of these hypergraphs, as that will be what we feed the model in input\n", - "- the label associated to the hypergraph" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:53.022151550Z", - "start_time": "2023-06-01T16:14:52.949636599Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Loading shrec 16 small dataset...\n", - "\n", - "done!\n" - ] - } - ], - "source": [ - "shrec, _ = datasets.mesh.shrec_16(size=\"small\")\n", - "\n", - "shrec = {key: np.array(value) for key, value in shrec.items()}\n", - "x_0s = shrec[\"node_feat\"]\n", - "x_1s = shrec[\"edge_feat\"]\n", - "x_2s = shrec[\"face_feat\"]\n", - "\n", - "ys = shrec[\"label\"]\n", - "simplexes = shrec[\"complexes\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The 6th simplicial complex has 252 nodes with features of dimension 6.\n", - "The 6th simplicial complex has 750 edges with features of dimension 10.\n", - "The 6th simplicial complex has 500 faces with features of dimension 7.\n" - ] - } - ], - "source": [ - "i_complex = 6\n", - "print(\n", - " f\"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}.\"\n", - ")\n", - "print(\n", - " f\"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}.\"\n", - ")\n", - "print(\n", - " f\"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}.\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Define neighborhood structures and lift into hypergraph domain. ##\n", - "\n", - "Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\\text{nodes} \\times n_\\text{edges}$.\n", - "\n", - "Once we have recorded the incidence matrix (note that all incidence amtrices in the hypergraph domain must be unsigned), we lift each simplicial complex into a hypergraph. The pairwise edges will become pairwise hyperedges, and faces in the simplciial complex will become 3-wise hyperedges." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:53.022151550Z", - "start_time": "2023-06-01T16:14:52.949636599Z" - } - }, - "outputs": [], - "source": [ - "hg_list = []\n", - "incidence_1_list = []\n", - "for simplex in simplexes:\n", - " incidence_1 = simplex.incidence_matrix(rank=1, signed=False)\n", - " incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()\n", - " incidence_1_list.append(incidence_1)\n", - " hg = simplex.to_hypergraph()\n", - " hg_list.append(hg)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "matrix([[1., 1., 1., ..., 0., 0., 0.],\n", - " [1., 0., 0., ..., 0., 0., 0.],\n", - " [0., 1., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 0., 0., 0.],\n", - " [0., 0., 0., ..., 1., 1., 1.]], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "simplex.incidence_matrix(rank=1, signed=False).todense()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The 6th hypergraph has an incidence matrix of shape torch.Size([252, 750]).\n" - ] - } - ], - "source": [ - "i_complex = 6\n", - "print(\n", - " f\"The {i_complex}th hypergraph has an incidence matrix of shape {incidence_1_list[i_complex].shape}.\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Create the Neural Network\n", - "\n", - "Using the TemplateLayer class, we create a neural network with stacked layers." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:55.343005145Z", - "start_time": "2023-06-01T16:14:55.339481459Z" - } - }, - "outputs": [], - "source": [ - "channels_edge = x_1s[0].shape[1]\n", - "channels_node = x_0s[0].shape[1]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:56.033274119Z", - "start_time": "2023-06-01T16:14:56.029056913Z" - } - }, - "outputs": [], - "source": [ - "class TemplateNN(torch.nn.Module):\n", - " \"\"\"Neural network implementation of Template for hypergraph classification.\n", - "\n", - " Parameters\n", - " ---------\n", - " channels_edge : int\n", - " Dimension of edge features\n", - " channels_node : int\n", - " Dimension of node features\n", - " n_layer : 2\n", - " Amount of message passing layers.\n", - "\n", - " \"\"\"\n", - "\n", - " def __init__(self, channels_edge, channels_node, n_layers=2):\n", - " super().__init__()\n", - " layers = []\n", - " for _ in range(n_layers):\n", - " layers.append(\n", - " TemplateLayer(\n", - " in_channels=channels_edge,\n", - " intermediate_channels=channels_node,\n", - " out_channels=channels_edge,\n", - " )\n", - " )\n", - " self.layers = torch.nn.ModuleList(layers)\n", - " self.linear = torch.nn.Linear(channels_edge, 1)\n", - "\n", - " def forward(self, x_1, incidence_1):\n", - " \"\"\"Forward computation through layers, then linear layer, then global max pooling.\n", - "\n", - " Parameters\n", - " ---------\n", - " x_1 : tensor\n", - " shape = [n_edges, channels_edge]\n", - " Edge features.\n", - "\n", - " incidence_1 : tensor\n", - " shape = [n_nodes, n_edges]\n", - " Boundary matrix of rank 1.\n", - "\n", - " Returns\n", - " --------\n", - " _ : tensor\n", - " shape = [1]\n", - " Label assigned to whole complex.\n", - " \"\"\"\n", - " for layer in self.layers:\n", - " x_1 = layer(x_1, incidence_1)\n", - " pooled_x = torch.max(x_1, dim=0)[0]\n", - " return torch.sigmoid(self.linear(pooled_x))[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train the Neural Network\n", - "\n", - "We specify the model, the loss, and an optimizer." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:58.153514385Z", - "start_time": "2023-06-01T16:14:57.243596119Z" - } - }, - "outputs": [], - "source": [ - "model = TemplateNN(channels_edge, channels_node, n_layers=2)\n", - "model = model.to(device)\n", - "opt = torch.optim.Adam(model.parameters(), lr=0.1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Split the dataset into train and test sets." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:14:59.046068930Z", - "start_time": "2023-06-01T16:14:59.037648626Z" - } - }, - "outputs": [], - "source": [ - "test_size = 0.2\n", - "x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)\n", - "incidence_1_train, incidence_1_test = train_test_split(\n", - " incidence_1_list, test_size=test_size, shuffle=False\n", - ")\n", - "y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The following cell performs the training, looping over the network for a low amount of epochs. We keep training minimal for the purpose of rapid testing." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "loss_fn = torch.nn.MSELoss()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "ExecuteTime": { - "end_time": "2023-06-01T16:15:01.683216142Z", - "start_time": "2023-06-01T16:15:00.727075750Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 1 loss: 274.6126\n", - "Test_loss: 529.0002\n", - "Epoch: 2 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 3 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 4 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 5 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 6 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 7 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 8 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 9 loss: 274.6126\n", - "Test_loss: 529.0001\n", - "Epoch: 10 loss: 274.6126\n", - "Test_loss: 529.0001\n" - ] - } - ], - "source": [ - "test_interval = 1\n", - "num_epochs = 10\n", - "for epoch_i in range(1, num_epochs + 1):\n", - " epoch_loss = []\n", - " model.train()\n", - " for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train):\n", - " x_1 = torch.tensor(x_1)\n", - " x_1, incidence_1, y = (\n", - " x_1.float().to(device),\n", - " incidence_1.float().to(device),\n", - " torch.tensor(y, dtype=torch.float).to(device),\n", - " )\n", - " opt.zero_grad()\n", - " y_hat = model(x_1, incidence_1)\n", - " loss = loss_fn(y_hat, y)\n", - "\n", - " loss.backward()\n", - " opt.step()\n", - " epoch_loss.append(loss.item())\n", - "\n", - " print(\n", - " f\"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}\",\n", - " flush=True,\n", - " )\n", - " if epoch_i % test_interval == 0:\n", - " with torch.no_grad():\n", - " for x_1, incidence_1, y in zip(x_1_test, incidence_1_test, y_test):\n", - " x_1 = torch.tensor(x_1)\n", - " x_1, incidence_1, y = (\n", - " x_1.float().to(device),\n", - " incidence_1.float().to(device),\n", - " torch.tensor(y, dtype=torch.float).to(device),\n", - " )\n", - " y_hat = model(x_1, incidence_1)\n", - " loss = loss_fn(y_hat, y)\n", - "\n", - " print(f\"Test_loss: {loss:.4f}\", flush=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.6 64-bit", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.6" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}