diff --git a/pyproject.toml b/pyproject.toml index f2cc0a05..80a4657c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,3 +121,6 @@ checks = [ "EX01", "SA01" ] +exclude = [ + '\.undocumented_method$', +] diff --git a/test/nn/cell/test_can.py b/test/nn/cell/test_can.py index c4c19eb0..245d4f02 100644 --- a/test/nn/cell/test_can.py +++ b/test/nn/cell/test_can.py @@ -18,7 +18,6 @@ def test_forward(self): out_channels=2, dropout=0.5, heads=1, - num_classes=1, n_layers=2, att_lift=False, ).to(device) @@ -36,5 +35,5 @@ def test_forward(self): adjacency_2 = adjacency_1.float().to(device) incidence_2 = adjacency_1.float().to(device) - y = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2) - assert y.shape == torch.Size([1]) + x_1 = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2) + assert x_1.shape == torch.Size([1, 2]) diff --git a/test/nn/cell/test_ccxn.py b/test/nn/cell/test_ccxn.py index a2fc4d92..9394864f 100644 --- a/test/nn/cell/test_ccxn.py +++ b/test/nn/cell/test_ccxn.py @@ -15,7 +15,6 @@ def test_forward(self): in_channels_0=2, in_channels_1=2, in_channels_2=2, - num_classes=1, n_layers=2, att=False, ).to(device) @@ -33,5 +32,7 @@ def test_forward(self): adjacency_1 = adjacency_1.float().to(device) incidence_2 = incidence_2.float().to(device) - y = model(x_0, x_1, adjacency_1, incidence_2) - assert y.shape == torch.Size([1]) + x_0, x_1, x_2 = model(x_0, x_1, adjacency_1, incidence_2) + assert x_0.shape == torch.Size([2, 2]) + assert x_1.shape == torch.Size([2, 2]) + assert x_2.shape == torch.Size([2, 2]) diff --git a/test/nn/cell/test_cwn.py b/test/nn/cell/test_cwn.py index e7e4f3ab..cb0aafd1 100644 --- a/test/nn/cell/test_cwn.py +++ b/test/nn/cell/test_cwn.py @@ -16,7 +16,6 @@ def test_forward(self): in_channels_1=2, in_channels_2=2, hid_channels=16, - num_classes=1, n_layers=2, ).to(device) @@ -36,5 +35,7 @@ def test_forward(self): incidence_2 = incidence_2.float().to(device) incidence_1_t = incidence_1_t.float().to(device) - y = model(x_0, 
x_1, x_2, adjacency_1, incidence_2, incidence_1_t) - assert y.shape == torch.Size([1]) + x_0, x_1, x_2 = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t) + assert x_0.shape == torch.Size([2, 16]) + assert x_1.shape == torch.Size([2, 16]) + assert x_2.shape == torch.Size([2, 16]) diff --git a/test/nn/combinatorial/test_hmc.py b/test/nn/combinatorial/test_hmc.py index a4371bc3..8270a7b1 100644 --- a/test/nn/combinatorial/test_hmc.py +++ b/test/nn/combinatorial/test_hmc.py @@ -15,7 +15,7 @@ def test_forward(self): intermediate_channels = [2, 2, 2] final_channels = [2, 2, 2] channels_per_layer = [[in_channels, intermediate_channels, final_channels]] - model = HMC(channels_per_layer, negative_slope=0.2, num_classes=2).to(device) + model = HMC(channels_per_layer, negative_slope=0.2).to(device) x_0 = torch.rand(2, 2) x_1 = torch.rand(2, 2) @@ -29,7 +29,7 @@ def test_forward(self): ) adjacency_0 = adjacency_0.float().to(device) - y = model( + x_0, x_1, x_2 = model( x_0, x_1, x_2, @@ -39,4 +39,6 @@ def test_forward(self): adjacency_0, adjacency_0, ) - assert y.shape == torch.Size([2]) + assert x_0.shape == torch.Size([2, 2]) + assert x_1.shape == torch.Size([2, 2]) + assert x_2.shape == torch.Size([2, 2]) diff --git a/test/nn/simplicial/test_san.py b/test/nn/simplicial/test_san.py index f250bc89..d99c613e 100644 --- a/test/nn/simplicial/test_san.py +++ b/test/nn/simplicial/test_san.py @@ -42,7 +42,7 @@ def test_forward(self): in_channels=in_channels, hidden_channels=hidden_channels, out_channels=out_channels, - n_layers=1, + n_layers=3, ) laplacian_down_1 = from_sparse(simplicial_complex.down_laplacian_matrix(rank=1)) laplacian_up_1 = from_sparse(simplicial_complex.up_laplacian_matrix(rank=1)) @@ -50,7 +50,7 @@ def test_forward(self): assert torch.any( torch.isclose( model(x, laplacian_up_1, laplacian_down_1)[0], - torch.tensor([2.8254, -0.9797]), + torch.tensor([-2.5604, -3.5924]), rtol=1e-02, ) ) diff --git a/topomodelx/nn/cell/can.py 
b/topomodelx/nn/cell/can.py index bdcb4e2a..25e8fccb 100644 --- a/topomodelx/nn/cell/can.py +++ b/topomodelx/nn/cell/can.py @@ -17,12 +17,10 @@ class CAN(torch.nn.Module): Number of input channels for the edge-level input. out_channels : int Number of output channels. - num_classes : int - Number of output classes. dropout : float, optional Dropout probability. Default is 0.5. heads : int, optional - Number of attention heads. Default is 3. + Number of attention heads. Default is 2. concat : bool, optional Whether to concatenate the output channels of attention heads. Default is True. skip_connection : bool, optional @@ -33,6 +31,8 @@ class CAN(torch.nn.Module): Number of CAN layers. att_lift : bool, default=True Whether to apply a lift the signal from node-level to edge-level input. + k_pool : float, default=0.5 + The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation. References ---------- @@ -47,14 +47,14 @@ def __init__( in_channels_0, in_channels_1, out_channels, - num_classes, dropout=0.5, - heads=3, + heads=2, concat=True, skip_connection=True, att_activation=torch.nn.LeakyReLU(0.2), n_layers=2, att_lift=True, + k_pool=0.5, ): super().__init__() @@ -98,7 +98,7 @@ def __init__( layers.append( PoolLayer( - k_pool=0.5, + k_pool=k_pool, in_channels_0=out_channels * heads, signal_pool_activation=torch.nn.Sigmoid(), readout=True, @@ -106,8 +106,6 @@ def __init__( ) self.layers = torch.nn.ModuleList(layers) - self.lin_0 = torch.nn.Linear(heads * out_channels, 128) - self.lin_1 = torch.nn.Linear(128, num_classes) def forward( self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood @@ -129,8 +127,8 @@ def forward( Returns ------- - torch.Tensor - Output prediction for the cell complex. + torch.Tensor, shape = (num_pooled_edges, heads * out_channels) + Final hidden representations of pooled edges. 
""" if hasattr(self, "lift_layer"): x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1) @@ -144,10 +142,4 @@ def forward( x_1 = layer(x_1, lower_neighborhood, upper_neighborhood) x_1 = F.dropout(x_1, p=0.5, training=self.training) - # max pooling over all nodes in each graph - x = x_1.max(dim=0)[0] - - # Feed-Foward Neural Network to predict the graph label - out = self.lin_1(torch.nn.functional.relu(self.lin_0(x))) - - return out + return x_1 diff --git a/topomodelx/nn/cell/can_layer.py b/topomodelx/nn/cell/can_layer.py index edbc73b3..d94d3fb4 100644 --- a/topomodelx/nn/cell/can_layer.py +++ b/topomodelx/nn/cell/can_layer.py @@ -289,7 +289,7 @@ class PoolLayer(MessagePassing): Parameters ---------- k_pool : float in (0, 1] - The pooling ratio i.e, the fraction of edges to keep after the pooling operation. + The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation. in_channels_0 : int Number of input channels of the input signal. signal_pool_activation : Callable @@ -323,14 +323,14 @@ def reset_parameters(self) -> None: init.xavier_uniform_(self.att_pool.data, gain=gain) def forward( # type: ignore[override] - self, x_0, lower_neighborhood, upper_neighborhood + self, x, lower_neighborhood, upper_neighborhood ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: r"""Forward pass. Parameters ---------- - x_0 : torch.Tensor, shape = (num_nodes, in_channels_0) - Node signal. + x : torch.Tensor, shape = (n_r_cells, in_channels_r) + Input r-cell signal. lower_neighborhood : torch.Tensor Lower neighborhood matrix. upper_neighborhood : torch.Tensor @@ -339,7 +339,7 @@ def forward( # type: ignore[override] Returns ------- torch.Tensor - Pooled node signal of shape (num_pooled_nodes, in_channels_0). + Pooled r_cell signal of shape (n_r_cells, in_channels_r). 
Notes ----- @@ -351,21 +351,19 @@ def forward( # type: ignore[override] = \phi^t(h_x^t, m_{x}^{(r)}), \forall x\in \mathcal C_r^{t+1} \end{align*} """ - # Compute the output edge signal by applying the activation function - Zp = torch.einsum("nc,ce->ne", x_0, self.att_pool) - # Apply top-k pooling to the edge signal + # Compute the output r-cell signal by applying the activation function + Zp = torch.einsum("nc,ce->ne", x, self.att_pool) + # Apply top-k pooling to the r-cell signal _, top_indices = topk(Zp.view(-1), int(self.k_pool * Zp.size(0))) # Rescale the pooled signal Zp = self.signal_pool_activation(Zp) - out = x_0[top_indices] * Zp[top_indices] + out = x[top_indices] * Zp[top_indices] # Readout operation if self.readout: - out = scatter_add(out, top_indices, dim=0, dim_size=x_0.size(0))[ - top_indices - ] + out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices] - # Update lower and upper neighborhood matrices with the top-k pooled edges + # Update lower and upper neighborhood matrices with the top-k pooled r-cells lower_neighborhood_modified = torch.index_select( lower_neighborhood, 0, top_indices ) diff --git a/topomodelx/nn/cell/ccxn.py b/topomodelx/nn/cell/ccxn.py index e9ee68d7..284334d7 100644 --- a/topomodelx/nn/cell/ccxn.py +++ b/topomodelx/nn/cell/ccxn.py @@ -16,8 +16,6 @@ class CCXN(torch.nn.Module): Dimension of input features on edges. in_channels_2 : int Dimension of input features on faces. - num_classes : int - Number of classes. n_layers : int Number of CCXN layers. 
att : bool @@ -36,7 +34,6 @@ def __init__( in_channels_0, in_channels_1, in_channels_2, - num_classes, n_layers=2, att=False, ): @@ -52,12 +49,9 @@ def __init__( ) ) self.layers = torch.nn.ModuleList(layers) - self.lin_0 = torch.nn.Linear(in_channels_0, num_classes) - self.lin_1 = torch.nn.Linear(in_channels_1, num_classes) - self.lin_2 = torch.nn.Linear(in_channels_2, num_classes) def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): - """Forward computation through layers, then linear layers, then avg pooling. + """Forward computation through layers. Parameters ---------- @@ -72,24 +66,13 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2): Returns ------- - torch.Tensor, shape = (1) - Label assigned to whole complex. + x_0 : torch.Tensor, shape = (n_nodes, in_channels_0) + Final hidden states of the nodes (0-cells). + x_1 : torch.Tensor, shape = (n_edges, in_channels_1) + Final hidden states of the edges (1-cells). + x_2 : torch.Tensor, shape = (n_faces, in_channels_2) + Final hidden states of the faces (2-cells). """ for layer in self.layers: x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2) - x_0 = self.lin_0(x_0) - x_1 = self.lin_1(x_1) - x_2 = self.lin_2(x_2) - # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0. 
- two_dimensional_cells_mean = torch.nanmean(x_2, dim=0) - two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0 - one_dimensional_cells_mean = torch.nanmean(x_1, dim=0) - one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0 - zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0) - zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0 - # Return the sum of the averages - return ( - two_dimensional_cells_mean - + one_dimensional_cells_mean - + zero_dimensional_cells_mean - ) + return (x_0, x_1, x_2) diff --git a/topomodelx/nn/cell/cwn.py b/topomodelx/nn/cell/cwn.py index ab7a85bd..5e656ef2 100644 --- a/topomodelx/nn/cell/cwn.py +++ b/topomodelx/nn/cell/cwn.py @@ -19,8 +19,6 @@ class CWN(torch.nn.Module): Dimension of input features on faces (2-cells). hid_channels : int Dimension of hidden features. - num_classes : int - Number of classes. n_layers : int Number of CWN layers. @@ -38,7 +36,6 @@ def __init__( in_channels_1, in_channels_2, hid_channels, - num_classes, n_layers, ): super().__init__() @@ -58,10 +55,6 @@ def __init__( ) self.layers = torch.nn.ModuleList(layers) - self.lin_0 = torch.nn.Linear(hid_channels, num_classes) - self.lin_1 = torch.nn.Linear(hid_channels, num_classes) - self.lin_2 = torch.nn.Linear(hid_channels, num_classes) - def forward( self, x_0, @@ -90,8 +83,12 @@ def forward( Returns ------- - torch.Tensor, shape = (1) - Label assigned to whole complex. + x_0 : torch.Tensor, shape = (n_nodes, hid_channels) + Final hidden states of the nodes (0-cells). + x_1 : torch.Tensor, shape = (n_edges, hid_channels) + Final hidden states of the edges (1-cells). + x_2 : torch.Tensor, shape = (n_faces, hid_channels) + Final hidden states of the faces (2-cells). 
""" x_0 = F.elu(self.proj_0(x_0)) x_1 = F.elu(self.proj_1(x_1)) @@ -107,23 +104,4 @@ def forward( neighborhood_0_to_1, ) - x_0 = self.lin_0(x_0) - x_1 = self.lin_1(x_1) - x_2 = self.lin_2(x_2) - - # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0. - two_dimensional_cells_mean = torch.nanmean(x_2, dim=0) - two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0 - - one_dimensional_cells_mean = torch.nanmean(x_1, dim=0) - one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0 - - zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0) - zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0 - - # Return the sum of the averages - return ( - two_dimensional_cells_mean - + one_dimensional_cells_mean - + zero_dimensional_cells_mean - ) + return x_0, x_1, x_2 diff --git a/topomodelx/nn/combinatorial/hmc.py b/topomodelx/nn/combinatorial/hmc.py index c7486d1e..2216f9be 100644 --- a/topomodelx/nn/combinatorial/hmc.py +++ b/topomodelx/nn/combinatorial/hmc.py @@ -18,22 +18,24 @@ class HMC(torch.nn.Module): for each input signal (nodes, edges, and faces) for the k-th layer. The second list contains the number of intermediate channels for each input signal (nodes, edges, and faces) for the k-th layer. Finally, the third list contains the number of output channels for - each input signal (nodes, edges, and faces) for the k-th layer . - num_classes : int - Number of classes. + each input signal (nodes, edges, and faces) for the k-th layer. negative_slope : float Negative slope for the LeakyReLU activation. + update_func_attention : str + Update function for the attention mechanism. Default is "relu". + update_func_aggregation : str + Update function for the aggregation mechanism. Default is "relu". 
""" def __init__( self, channels_per_layer, - num_classes, negative_slope=0.2, update_func_attention="relu", update_func_aggregation="relu", ) -> None: def check_channels_consistency(): + """Check that the number of input, intermediate, and output channels is consistent.""" assert len(channels_per_layer) > 0 for i in range(len(channels_per_layer) - 1): assert channels_per_layer[i][2][0] == channels_per_layer[i + 1][0][0] @@ -41,7 +43,6 @@ def check_channels_consistency(): assert channels_per_layer[i][2][2] == channels_per_layer[i + 1][0][2] super().__init__() - self.num_classes = num_classes check_channels_consistency() self.layers = torch.nn.ModuleList( [ @@ -58,10 +59,6 @@ def check_channels_consistency(): ] ) - self.l0 = torch.nn.Linear(channels_per_layer[-1][2][0], num_classes) - self.l1 = torch.nn.Linear(channels_per_layer[-1][2][1], num_classes) - self.l2 = torch.nn.Linear(channels_per_layer[-1][2][2], num_classes) - def forward( self, x_0, @@ -72,7 +69,7 @@ def forward( neighborhood_2_to_2, neighborhood_0_to_1, neighborhood_1_to_2, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass. Parameters @@ -96,8 +93,12 @@ def forward( Returns ------- - y_hat : torch.Tensor, shape=[num_classes] - Vector embedding that represents the probability of the input mesh to belong to each class. + torch.Tensor, shape = (n_nodes, out_channels_0) + Final hidden states of the nodes (0-cells). + torch.Tensor, shape = (n_edges, out_channels_1) + Final hidden states the edges (1-cells). + torch.Tensor, shape = (n_faces, out_channels_2) + Final hidden states of the faces (2-cells). 
""" for layer in self.layers: x_0, x_1, x_2 = layer( @@ -111,13 +112,4 @@ def forward( neighborhood_1_to_2, ) - x_0 = self.l0(x_0) - x_1 = self.l1(x_1) - x_2 = self.l2(x_2) - - # Sum all the elements in the dimension zero - x_0 = torch.nanmean(x_0, dim=0) - x_1 = torch.nanmean(x_1, dim=0) - x_2 = torch.nanmean(x_2, dim=0) - - return x_0 + x_1 + x_2 + return x_0, x_1, x_2 diff --git a/topomodelx/nn/simplicial/san.py b/topomodelx/nn/simplicial/san.py index cca936c8..62e275d4 100644 --- a/topomodelx/nn/simplicial/san.py +++ b/topomodelx/nn/simplicial/san.py @@ -5,7 +5,7 @@ class SAN(torch.nn.Module): - r"""Simplicial Attention Network (SAN) implementation for binary edge classification. + """Simplicial Attention Network (SAN) implementation for binary edge classification. Parameters ---------- @@ -36,6 +36,7 @@ def __init__( n_layers=2, ): super().__init__() + self.in_channels = in_channels self.hidden_channels = hidden_channels self.out_channels = ( @@ -76,6 +77,7 @@ def __init__( n_filters=self.n_filters, ) ) + self.layers = torch.nn.ModuleList(self.layers) def compute_projection_matrix(self, laplacian): """Compute the projection matrix. @@ -92,9 +94,8 @@ def compute_projection_matrix(self, laplacian): torch.Tensor, shape = (n_edges, n_edges) Projection matrix. 
""" - projection_mat = ( - torch.eye(laplacian.shape[0]) - self.epsilon_harmonic * laplacian - ) + eye = torch.eye(laplacian.shape[0]).to(laplacian.device) + projection_mat = eye - self.epsilon_harmonic * laplacian projection_mat = torch.linalg.matrix_power(projection_mat, self.order_harmonic) return projection_mat diff --git a/tutorials/cell/can_train.ipynb b/tutorials/cell/can_train.ipynb index 900fbd40..080058fc 100644 --- a/tutorials/cell/can_train.ipynb +++ b/tutorials/cell/can_train.ipynb @@ -136,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:53.006542411Z", @@ -148,7 +148,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "cuda:2\n" + "cpu\n" ] } ], @@ -176,7 +176,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.279147916Z", @@ -234,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:55.832585216Z", @@ -279,37 +279,70 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class Network(torch.nn.Module):\n", + " def __init__(\n", + " self,\n", + " in_channels_0,\n", + " in_channels_1,\n", + " out_channels,\n", + " num_classes,\n", + " dropout=0.5,\n", + " heads=2,\n", + " n_layers=2,\n", + " att_lift=True,\n", + " ):\n", + " super().__init__()\n", + " self.base_model = CAN(\n", + " in_channels_0,\n", + " in_channels_1,\n", + " out_channels,\n", + " dropout=dropout,\n", + " heads=heads,\n", + " n_layers=n_layers,\n", + " att_lift=att_lift,\n", + " )\n", + " self.lin_0 = torch.nn.Linear(heads * out_channels, 128)\n", + " self.lin_1 = torch.nn.Linear(128, num_classes)\n", + "\n", + " def forward(self, x_0, x_1, adjacency, lower_neighborhood, upper_neighborhood):\n", + " x = self.base_model(x_0, 
x_1, adjacency, lower_neighborhood, upper_neighborhood)\n", + " # max pooling over edges in each graph\n", + " x = x.max(dim=0)[0]\n", + " # Feed-Foward Neural Network to predict the graph label\n", + " out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))\n", + " return torch.sigmoid(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:56.672913561Z", "start_time": "2023-05-31T09:13:56.667986426Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The dimension of input features on nodes, edges and faces are: 7, 4 and 5.\n" - ] - } - ], + "outputs": [], "source": [ "in_channels_0 = x_0_list[0].shape[-1]\n", "in_channels_1 = x_1_list[0].shape[-1]\n", - "in_channels_2 = 5\n", - "print(\n", - " f\"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}.\"\n", - ")\n", - "model = CAN(\n", + "out_channels = 32\n", + "num_classes = 2\n", + "heads = 2\n", + "n_layers = 2\n", + "\n", + "model = Network(\n", " in_channels_0,\n", " in_channels_1,\n", - " 32,\n", + " out_channels,\n", + " num_classes,\n", " dropout=0.5,\n", - " heads=2,\n", - " num_classes=2,\n", - " n_layers=2,\n", + " heads=heads,\n", + " n_layers=n_layers,\n", " att_lift=True,\n", ")\n", "model = model.to(device)" @@ -326,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:40.411845803Z", @@ -337,37 +370,39 @@ { "data": { "text/plain": [ - "CAN(\n", - " (lift_layer): MultiHeadLiftLayer(\n", - " (lifts): LiftLayer()\n", - " )\n", - " (layers): ModuleList(\n", - " (0): CANLayer(\n", - " (lower_att): MultiHeadCellAttention(\n", - " (att_activation): LeakyReLU(negative_slope=0.2)\n", - " (lin): Linear(in_features=11, out_features=64, bias=False)\n", - " )\n", - " (upper_att): MultiHeadCellAttention(\n", - " (att_activation): 
LeakyReLU(negative_slope=0.2)\n", + "Network(\n", + " (base_model): CAN(\n", + " (lift_layer): MultiHeadLiftLayer(\n", + " (lifts): LiftLayer()\n", + " )\n", + " (layers): ModuleList(\n", + " (0): CANLayer(\n", + " (lower_att): MultiHeadCellAttention(\n", + " (att_activation): LeakyReLU(negative_slope=0.2)\n", + " (lin): Linear(in_features=11, out_features=64, bias=False)\n", + " )\n", + " (upper_att): MultiHeadCellAttention(\n", + " (att_activation): LeakyReLU(negative_slope=0.2)\n", + " (lin): Linear(in_features=11, out_features=64, bias=False)\n", + " )\n", " (lin): Linear(in_features=11, out_features=64, bias=False)\n", + " (aggregation): Aggregation()\n", " )\n", - " (lin): Linear(in_features=11, out_features=64, bias=False)\n", - " (aggregation): Aggregation()\n", - " )\n", - " (1): CANLayer(\n", - " (lower_att): MultiHeadCellAttention(\n", - " (att_activation): LeakyReLU(negative_slope=0.2)\n", + " (1): CANLayer(\n", + " (lower_att): MultiHeadCellAttention(\n", + " (att_activation): LeakyReLU(negative_slope=0.2)\n", + " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " )\n", + " (upper_att): MultiHeadCellAttention(\n", + " (att_activation): LeakyReLU(negative_slope=0.2)\n", + " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " )\n", " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " (aggregation): Aggregation()\n", " )\n", - " (upper_att): MultiHeadCellAttention(\n", - " (att_activation): LeakyReLU(negative_slope=0.2)\n", - " (lin): Linear(in_features=64, out_features=64, bias=False)\n", + " (2): PoolLayer(\n", + " (signal_pool_activation): Sigmoid()\n", " )\n", - " (lin): Linear(in_features=64, out_features=64, bias=False)\n", - " (aggregation): Aggregation()\n", - " )\n", - " (2): PoolLayer(\n", - " (signal_pool_activation): Sigmoid()\n", " )\n", " )\n", " (lin_0): Linear(in_features=64, out_features=128, bias=True)\n", @@ -375,7 +410,7 @@ ")" ] }, - "execution_count": 8, + "execution_count": 7, 
"metadata": {}, "output_type": "execute_result" } @@ -395,7 +430,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:41.150933630Z", @@ -428,7 +463,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:42.918836083Z", @@ -440,28 +475,32 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 0.6348 Train_acc: 0.6718\n", + "Epoch: 1 loss: 0.6200 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 2 loss: 0.6101 Train_acc: 0.6947\n", + "Epoch: 2 loss: 0.6110 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 3 loss: 0.6008 Train_acc: 0.6947\n", + "Epoch: 3 loss: 0.6054 Train_acc: 0.6947\n", "Test_acc: 0.5965\n", - "Epoch: 4 loss: 0.5888 Train_acc: 0.7099\n", + "Epoch: 4 loss: 0.5990 Train_acc: 0.6947\n", "Test_acc: 0.6316\n", - "Epoch: 5 loss: 0.5850 Train_acc: 0.7252\n", - "Test_acc: 0.7368\n", - "Epoch: 6 loss: 0.5841 Train_acc: 0.7328\n", - "Test_acc: 0.6491\n", - "Epoch: 7 loss: 0.5772 Train_acc: 0.7328\n", + "Epoch: 5 loss: 0.6021 Train_acc: 0.7099\n", + "Test_acc: 0.6316\n", + "Epoch: 6 loss: 0.5911 Train_acc: 0.7252\n", + "Test_acc: 0.6316\n", + "Epoch: 7 loss: 0.5889 Train_acc: 0.7176\n", + "Test_acc: 0.6316\n", + "Epoch: 8 loss: 0.5829 Train_acc: 0.7252\n", + "Test_acc: 0.6842\n", + "Epoch: 9 loss: 0.5786 Train_acc: 0.7252\n", "Test_acc: 0.6491\n", - "Epoch: 8 loss: 0.5375 Train_acc: 0.7405\n", - "Test_acc: 0.7193\n" + "Epoch: 10 loss: 0.5746 Train_acc: 0.7328\n", + "Test_acc: 0.6842\n" ] } ], "source": [ "test_interval = 1\n", - "num_epochs = 2\n", + "num_epochs = 10\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", " num_samples = 0\n", @@ -525,13 +564,20 @@ " test_acc = correct / num_samples\n", " print(f\"Test_acc: {test_acc:.4f}\", flush=True)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": 
[], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.11.6 64-bit", + "display_name": "venv_modelx", "language": "python", - "name": "python3" + "name": "venv_modelx" }, "language_info": { "codemirror_mode": { @@ -543,7 +589,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.3" }, "vscode": { "interpreter": { diff --git a/tutorials/cell/ccxn_train.ipynb b/tutorials/cell/ccxn_train.ipynb index 8ac5d2ad..bc4df8ae 100644 --- a/tutorials/cell/ccxn_train.ipynb +++ b/tutorials/cell/ccxn_train.ipynb @@ -4,6 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "\n", "# Train a Convolutional Cell Complex Network (CCXN)\n", "\n", "We create and train a simplified version of the CCXN originally proposed in [Hajij et. al : Cell Complex Neural Networks (2020)](https://arxiv.org/pdf/2010.00743.pdf).\n", @@ -66,7 +67,10 @@ "from topomodelx.utils.sparse import from_sparse\n", "\n", "torch.manual_seed(0)\n", - "np.random.seed(0)" + "np.random.seed(0)\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" ] }, { @@ -90,7 +94,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "cuda\n" + "cpu\n" ] } ], @@ -242,6 +246,56 @@ { "cell_type": "code", "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class Network(torch.nn.Module):\n", + " def __init__(\n", + " self,\n", + " in_channels_0,\n", + " in_channels_1,\n", + " in_channels_2,\n", + " num_classes,\n", + " n_layers=2,\n", + " att=False,\n", + " ):\n", + " super().__init__()\n", + " self.base_model = CCXN(\n", + " in_channels_0,\n", + " in_channels_1,\n", + " in_channels_2,\n", + " n_layers=n_layers,\n", + " att=att,\n", + " )\n", + " self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)\n", + " self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)\n", + " self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)\n", + "\n", + " def forward(self, x_0, x_1, neighborhood_0_to_0, 
neighborhood_1_to_2):\n", + " x_0, x_1, x_2 = self.base_model(\n", + " x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2\n", + " )\n", + " x_0 = self.lin_0(x_0)\n", + " x_1 = self.lin_1(x_1)\n", + " x_2 = self.lin_2(x_2)\n", + " # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.\n", + " two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n", + " two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n", + " one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n", + " one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n", + " zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n", + " zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n", + " # Return the sum of the averages\n", + " return (\n", + " two_dimensional_cells_mean\n", + " + one_dimensional_cells_mean\n", + " + zero_dimensional_cells_mean\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:56.672913561Z", @@ -261,13 +315,46 @@ "in_channels_0 = x_0s[0].shape[-1]\n", "in_channels_1 = x_1s[0].shape[-1]\n", "in_channels_2 = 5\n", + "num_classes = 2\n", "print(\n", " f\"The dimension of input features on nodes, edges and faces are: {in_channels_0}, {in_channels_1} and {in_channels_2}.\"\n", ")\n", - "model = CCXN(in_channels_0, in_channels_1, in_channels_2, num_classes=1, n_layers=2)\n", + "model = Network(in_channels_0, in_channels_1, in_channels_2, num_classes, n_layers=2)\n", "model = model.to(device)" ] }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Network(\n", + " (base_model): CCXN(\n", + " (layers): ModuleList(\n", + " (0-1): 2 x CCXNLayer(\n", + " (conv_0_to_0): Conv()\n", + " (conv_1_to_2): Conv()\n", + " )\n", + " )\n", + " )\n", + " (lin_0): Linear(in_features=6, out_features=2, bias=True)\n", + " (lin_1): Linear(in_features=10, 
out_features=2, bias=True)\n", + " (lin_2): Linear(in_features=5, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -279,7 +366,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:40.411845803Z", @@ -302,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:41.150933630Z", @@ -332,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:42.918836083Z", @@ -340,11 +427,18 @@ } }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch:2, Train Loss: 83.8803 Test Loss: 72.8717\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", + "/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n", " return F.mse_loss(input, target, reduction=self.reduction)\n" ] }, @@ -352,11 +446,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch:2, Train Loss: 82.9399 Test Loss: 72.3905\n", - "Epoch:4, Train Loss: 79.9278 Test Loss: 76.1475\n", - "Epoch:6, Train Loss: 76.7409 Test Loss: 75.5150\n", - "Epoch:8, Train Loss: 74.2580 Test Loss: 76.6475\n", - "Epoch:10, Train Loss: 72.1732 Test Loss: 78.3276\n" + "Epoch:4, Train Loss: 80.8463 Test Loss: 74.7231\n", + "Epoch:6, Train Loss: 77.9684 Test Loss: 75.5384\n", + "Epoch:8, Train Loss: 75.5704 Test Loss: 76.0005\n", + "Epoch:10, Train Loss: 73.3453 Test Loss: 78.1194\n" ] } ], @@ -424,7 +517,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:20:01.334080388Z", @@ -433,8 +526,8 @@ }, "outputs": [], "source": [ - "model = CCXN(\n", - " in_channels_0, in_channels_1, in_channels_2, num_classes=2, n_layers=2, att=True\n", + "model = Network(\n", + " in_channels_0, in_channels_1, in_channels_2, num_classes, n_layers=2, att=True\n", ")\n", "model = model.to(device)\n", "crit = torch.nn.CrossEntropyLoss()\n", @@ -458,7 +551,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:21:57.205551344Z", @@ -466,29 +559,21 @@ } }, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([2])). This will likely lead to incorrect results due to broadcasting. 
Please ensure they have the same size.\n", - " return F.mse_loss(input, target, reduction=self.reduction)\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch:2, Train Loss: 96.2627 Test Loss: 67.7411\n", - "Epoch:4, Train Loss: 83.8081 Test Loss: 60.9365\n", - "Epoch:6, Train Loss: 78.7461 Test Loss: 64.3800\n", - "Epoch:8, Train Loss: 76.0083 Test Loss: 67.3870\n", - "Epoch:10, Train Loss: 73.6383 Test Loss: 67.6792\n" + "Epoch:2, Train Loss: 86.7494 Test Loss: 81.5152\n", + "Epoch:4, Train Loss: 79.9747 Test Loss: 85.7426\n", + "Epoch:6, Train Loss: 76.6452 Test Loss: 89.2596\n", + "Epoch:8, Train Loss: 74.2124 Test Loss: 88.7206\n", + "Epoch:10, Train Loss: 72.5803 Test Loss: 87.3506\n" ] } ], "source": [ "test_interval = 2\n", - "num_epochs = 2\n", + "num_epochs = 10\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", " model.train()\n", @@ -544,7 +629,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.6 64-bit", + "display_name": "venv_modelx", "language": "python", "name": "python3" }, @@ -559,11 +644,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" - }, - "vscode": { - "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" - } } }, "nbformat": 4, diff --git a/tutorials/cell/cwn_train.ipynb b/tutorials/cell/cwn_train.ipynb index 0e60cf6e..a2c4691b 100644 --- a/tutorials/cell/cwn_train.ipynb +++ b/tutorials/cell/cwn_train.ipynb @@ -94,7 +94,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "cuda\n" + "cpu\n" ] } ], @@ -289,6 +289,68 @@ { "cell_type": "code", "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class Network(torch.nn.Module):\n", + " def __init__(\n", + " self,\n", + " in_channels_0,\n", + " in_channels_1,\n", + " in_channels_2,\n", + " hid_channels=16,\n", + " num_classes=1,\n", + " n_layers=2,\n", + " ):\n", + " super().__init__()\n", + " self.base_model = CWN(\n", + " 
in_channels_0,\n", + " in_channels_1,\n", + " in_channels_2,\n", + " hid_channels=hid_channels,\n", + " n_layers=n_layers,\n", + " )\n", + " self.lin_0 = torch.nn.Linear(hid_channels, num_classes)\n", + " self.lin_1 = torch.nn.Linear(hid_channels, num_classes)\n", + " self.lin_2 = torch.nn.Linear(hid_channels, num_classes)\n", + "\n", + " def forward(\n", + " self,\n", + " x_0,\n", + " x_1,\n", + " x_2,\n", + " neighborhood_1_to_1,\n", + " neighborhood_2_to_1,\n", + " neighborhood_0_to_1,\n", + " ):\n", + " x_0, x_1, x_2 = self.base_model(\n", + " x_0, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1, neighborhood_0_to_1\n", + " )\n", + " x_0 = self.lin_0(x_0)\n", + " x_1 = self.lin_1(x_1)\n", + " x_2 = self.lin_2(x_2)\n", + "\n", + " # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.\n", + " two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)\n", + " two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0\n", + "\n", + " one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)\n", + " one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0\n", + "\n", + " zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)\n", + " zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0\n", + "\n", + " # Return the sum of the averages\n", + " return (\n", + " two_dimensional_cells_mean\n", + " + one_dimensional_cells_mean\n", + " + zero_dimensional_cells_mean\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:13:56.672913561Z", @@ -318,7 +380,7 @@ " f\"The dimensions of input features on nodes, edges and faces are \"\n", " f\"{in_channels_0}, {in_channels_1} and {in_channels_2}, respectively.\"\n", ")\n", - "model = CWN(\n", + "model = Network(\n", " in_channels_0,\n", " in_channels_1,\n", " in_channels_2,\n", @@ -342,7 +404,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, 
"metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:40.411845803Z", @@ -367,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:41.150933630Z", @@ -414,7 +476,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-05-31T09:19:42.918836083Z", @@ -431,11 +493,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch:2, Train Loss: 83.8053 Test Loss: 73.7517\n", - "Epoch:4, Train Loss: 81.9552 Test Loss: 50.2781\n", - "Epoch:6, Train Loss: 78.3991 Test Loss: 49.9034\n", - "Epoch:8, Train Loss: 75.8107 Test Loss: 45.7201\n", - "Epoch:10, Train Loss: 74.3833 Test Loss: 40.5558\n" + "Epoch:2, Train Loss: 83.8053 Test Loss: 73.7517\n" ] } ], @@ -518,9 +576,9 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.9.6 64-bit", + "display_name": "venv_modelx", "language": "python", - "name": "python3" + "name": "venv_modelx" }, "language_info": { "codemirror_mode": { diff --git a/tutorials/combinatorial/hmc_train.ipynb b/tutorials/combinatorial/hmc_train.ipynb index 27726b20..0bd18935 100644 --- a/tutorials/combinatorial/hmc_train.ipynb +++ b/tutorials/combinatorial/hmc_train.ipynb @@ -156,7 +156,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "cuda\n" + "cpu\n" ] } ], @@ -606,6 +606,68 @@ " return test_accuracy" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We generate our Network, combining HOAN model with the appropriate readout for the considered task" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class Network(torch.nn.Module):\n", + " def __init__(\n", + " self,\n", + " channels_per_layer,\n", + " negative_slope=0.2,\n", + " num_classes=2,\n", + " ):\n", + " super().__init__()\n", + " self.base_model = HMC(\n", + " channels_per_layer,\n", + " negative_slope,\n", + " 
)\n", + " self.l0 = torch.nn.Linear(channels_per_layer[-1][2][0], num_classes)\n", + " self.l1 = torch.nn.Linear(channels_per_layer[-1][2][1], num_classes)\n", + " self.l2 = torch.nn.Linear(channels_per_layer[-1][2][2], num_classes)\n", + "\n", + " def forward(\n", + " self,\n", + " x_0,\n", + " x_1,\n", + " x_2,\n", + " neighborhood_0_to_0,\n", + " neighborhood_1_to_1,\n", + " neighborhood_2_to_2,\n", + " neighborhood_0_to_1,\n", + " neighborhood_1_to_2,\n", + " ):\n", + " x_0, x_1, x_2 = self.base_model(\n", + " x_0,\n", + " x_1,\n", + " x_2,\n", + " neighborhood_0_to_0,\n", + " neighborhood_1_to_1,\n", + " neighborhood_2_to_2,\n", + " neighborhood_0_to_1,\n", + " neighborhood_1_to_2,\n", + " )\n", + " x_0 = self.l0(x_0)\n", + " x_1 = self.l1(x_1)\n", + " x_2 = self.l2(x_2)\n", + "\n", + " # Average over dimension zero, ignoring NaN entries\n", + " x_0 = torch.nanmean(x_0, dim=0)\n", + " x_1 = torch.nanmean(x_1, dim=0)\n", + " x_2 = torch.nanmean(x_2, dim=0)\n", + "\n", + " return x_0 + x_1 + x_2" + ] + }, { "cell_type": "markdown", "metadata": { @@ -620,7 +682,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-08-24T06:55:28.220563261Z", @@ -639,7 +701,7 @@ "\n", "channels_per_layer = [[in_channels, intermediate_channels, final_channels]]\n", "# defube HOAN mesh classifier\n", - "model = HMC(\n", + "model = Network(\n", " channels_per_layer, negative_slope=0.2, num_classes=training_dataset.num_classes()\n", ")\n", "\n", @@ -649,6 +711,57 @@ "trainer = Trainer(model, training_dataloader, testing_dataloader, 0.001, device)" ] }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Network(\n", + " (base_model): HMC(\n", + " (layers): ModuleList(\n", + " (0): HMCLayer(\n", + " (hbs_0_level1): HBS(\n", + " (weight): ParameterList( (0): Parameter containing: [torch.float32 of size 6x60])\n", + " (att_weight): ParameterList( (0):
Parameter containing: [torch.float32 of size 120x1])\n", + " )\n", + " (hbns_0_1_level1): HBNS()\n", + " (hbns_1_2_level1): HBNS()\n", + " (hbs_0_level2): HBS(\n", + " (weight): ParameterList( (0): Parameter containing: [torch.float32 of size 60x60])\n", + " (att_weight): ParameterList( (0): Parameter containing: [torch.float32 of size 120x1])\n", + " )\n", + " (hbns_0_1_level2): HBNS()\n", + " (hbs_1_level2): HBS(\n", + " (weight): ParameterList( (0): Parameter containing: [torch.float32 of size 60x60])\n", + " (att_weight): ParameterList( (0): Parameter containing: [torch.float32 of size 120x1])\n", + " )\n", + " (hbns_1_2_level2): HBNS()\n", + " (hbs_2_level2): HBS(\n", + " (weight): ParameterList( (0): Parameter containing: [torch.float32 of size 60x60])\n", + " (att_weight): ParameterList( (0): Parameter containing: [torch.float32 of size 120x1])\n", + " )\n", + " (aggr): Aggregation()\n", + " )\n", + " )\n", + " )\n", + " (l0): Linear(in_features=60, out_features=30, bias=True)\n", + " (l1): Linear(in_features=60, out_features=30, bias=True)\n", + " (l2): Linear(in_features=60, out_features=30, bias=True)\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -658,7 +771,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-08-24T07:01:54.496249166Z", @@ -670,20 +783,28 @@ } }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/gbg141/Documents/Projects/TopoModelX/topomodelx/nn/combinatorial/hmc_layer.py:683: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. 
(Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/SparseCsrTensorImpl.cpp:56.)\n", + " A_p = torch.sparse.mm(A_p, neighborhood)\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 0 loss: 3.6020 Train_acc: 0.0333\n", - "Test_acc: 0.0917\n", - "Epoch: 1 loss: 3.2161 Train_acc: 0.0750\n", - "Test_acc: 0.1333\n", - "Epoch: 2 loss: 2.9366 Train_acc: 0.1437\n", + "Epoch: 0 loss: 3.5569 Train_acc: 0.0292\n", + "Test_acc: 0.0667\n", + "Epoch: 1 loss: 3.2807 Train_acc: 0.0688\n", + "Test_acc: 0.1583\n", + "Epoch: 2 loss: 2.9899 Train_acc: 0.1125\n", + "Test_acc: 0.1417\n", + "Epoch: 3 loss: 2.6567 Train_acc: 0.1792\n", "Test_acc: 0.1583\n", - "Epoch: 3 loss: 2.6148 Train_acc: 0.2167\n", - "Test_acc: 0.2500\n", - "Epoch: 4 loss: 2.3257 Train_acc: 0.2875\n", - "Test_acc: 0.2750\n" + "Epoch: 4 loss: 2.3474 Train_acc: 0.2583\n", + "Test_acc: 0.3250\n" ] } ], @@ -719,9 +840,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv_modelx", "language": "python", - "name": "python3" + "name": "venv_modelx" }, "language_info": { "codemirror_mode": { diff --git a/tutorials/simplicial/san_train.ipynb b/tutorials/simplicial/san_train.ipynb index efd2cb5a..e3d52c21 100644 --- a/tutorials/simplicial/san_train.ipynb +++ b/tutorials/simplicial/san_train.ipynb @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 1, "metadata": { "id": "ZNrtWfL10pEe" }, @@ -80,8 +80,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + " No module named 'igraph'. 
If you need to use hypernetx.algorithms.hypergraph_modularity, please install additional packages by running the following command: pip install .['all']\n" ] } ], @@ -99,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 2, "metadata": { "id": "Z05cyYcw0pEh", "outputId": "0ba2482b-dc68-451b-95bd-d2ea97e04378" @@ -135,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 3, "metadata": { "id": "BreEb4B00pEi", "outputId": "fed8fb52-16b7-418e-812c-7f29bf32de1c" @@ -156,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 4, "metadata": { "id": "fdT4Zjsp0pEi", "outputId": "054e2a94-3503-421d-d10c-8a9ceb21cce9" @@ -168,7 +167,7 @@ "(34, 78, 45, 11, 2)" ] }, - "execution_count": 21, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -200,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 5, "metadata": { "id": "nukmcsOJ0pEj" }, @@ -238,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 6, "metadata": { "id": "EBf26K5N0pEj", "outputId": "22acb245-384f-4d62-9e59-b55bdea3353e" @@ -286,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 7, "metadata": { "id": "MmP56FAH0pEk" }, @@ -306,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 8, "metadata": { "id": "K2zU2_jL0pEk" }, @@ -326,7 +325,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 9, "metadata": { "id": "cS7o2U620pEl" }, @@ -351,7 +350,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 10, "metadata": { "id": "Ub-GWaMm0pEl" }, @@ -418,7 +417,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -440,7 +439,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 12, "metadata": { "id": "2n-lFfkJ0pEm", "outputId": 
"3c633dee-f904-49ba-c169-f21499488cb4" @@ -452,7 +451,7 @@ "(torch.Size([78, 78]), torch.Size([78, 78]), torch.Size([78, 2]))" ] }, - "execution_count": 44, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -463,9 +462,31 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Network(\n", + " (base_model): SAN(\n", + " (layers): ModuleList(\n", + " (0): SANLayer(\n", + " (conv_down): SANConv()\n", + " (conv_up): SANConv()\n", + " (conv_harmonic): Conv()\n", + " )\n", + " )\n", + " )\n", + " (linear): Linear(in_features=16, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "n_layers = 1\n", "model = Network(\n", @@ -474,7 +495,8 @@ " out_channels=out_channels,\n", " n_layers=n_layers,\n", ")\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)" + "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n", + "model" ] }, { @@ -497,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 14, "metadata": { "id": "sWopn2U60pEn", "outputId": "a2153bef-91eb-4f99-8aad-3166b4966c26" @@ -507,238 +529,67 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch: 1 loss: 0.7225 Train_acc: 0.7000\n", - "Epoch: 2 loss: 0.7185 Train_acc: 0.7000\n", - "Epoch: 3 loss: 0.7139 Train_acc: 0.7000\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch: 4 loss: 0.7091 Train_acc: 0.7333\n", - "Epoch: 5 loss: 0.7044 Train_acc: 0.7333\n", - "Epoch: 6 loss: 0.6999 Train_acc: 0.7333\n", - "Epoch: 7 loss: 0.6958 Train_acc: 0.7333\n", - "Epoch: 8 loss: 0.6922 Train_acc: 0.7333\n", - "Epoch: 9 loss: 0.6890 Train_acc: 0.7333\n", - "Epoch: 10 loss: 0.6862 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 11 loss: 0.6837 Train_acc: 0.7333\n", - "Epoch: 12 loss: 0.6816 Train_acc: 0.7333\n", - 
"Epoch: 13 loss: 0.6798 Train_acc: 0.7333\n", - "Epoch: 14 loss: 0.6782 Train_acc: 0.7333\n", - "Epoch: 15 loss: 0.6768 Train_acc: 0.7333\n", - "Epoch: 16 loss: 0.6756 Train_acc: 0.7333\n", - "Epoch: 17 loss: 0.6746 Train_acc: 0.7333\n", - "Epoch: 18 loss: 0.6737 Train_acc: 0.7333\n", - "Epoch: 19 loss: 0.6730 Train_acc: 0.7333\n", - "Epoch: 20 loss: 0.6723 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 21 loss: 0.6717 Train_acc: 0.7333\n", - "Epoch: 22 loss: 0.6712 Train_acc: 0.7333\n", - "Epoch: 23 loss: 0.6708 Train_acc: 0.7333\n", - "Epoch: 24 loss: 0.6704 Train_acc: 0.7333\n", - "Epoch: 25 loss: 0.6700 Train_acc: 0.7333\n", - "Epoch: 26 loss: 0.6697 Train_acc: 0.7333\n", - "Epoch: 27 loss: 0.6694 Train_acc: 0.7333\n", - "Epoch: 28 loss: 0.6692 Train_acc: 0.7333\n", - "Epoch: 29 loss: 0.6689 Train_acc: 0.7333\n", - "Epoch: 30 loss: 0.6687 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 31 loss: 0.6685 Train_acc: 0.7333\n", - "Epoch: 32 loss: 0.6683 Train_acc: 0.7333\n", - "Epoch: 33 loss: 0.6682 Train_acc: 0.7333\n", - "Epoch: 34 loss: 0.6680 Train_acc: 0.7333\n", - "Epoch: 35 loss: 0.6678 Train_acc: 0.7333\n", - "Epoch: 36 loss: 0.6677 Train_acc: 0.7333\n", - "Epoch: 37 loss: 0.6676 Train_acc: 0.7333\n", - "Epoch: 38 loss: 0.6674 Train_acc: 0.7333\n", - "Epoch: 39 loss: 0.6673 Train_acc: 0.7333\n", - "Epoch: 40 loss: 0.6672 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 41 loss: 0.6670 Train_acc: 0.7333\n", - "Epoch: 42 loss: 0.6669 Train_acc: 0.7333\n", - "Epoch: 43 loss: 0.6668 Train_acc: 0.7333\n", - "Epoch: 44 loss: 0.6666 Train_acc: 0.7333\n", - "Epoch: 45 loss: 0.6665 Train_acc: 0.7333\n", - "Epoch: 46 loss: 0.6664 Train_acc: 0.7333\n", - "Epoch: 47 loss: 0.6663 Train_acc: 0.7333\n", - "Epoch: 48 loss: 0.6661 Train_acc: 0.7333\n", - "Epoch: 49 loss: 0.6660 Train_acc: 0.7333\n", - "Epoch: 50 loss: 0.6659 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 51 loss: 0.6657 Train_acc: 0.7333\n", - "Epoch: 52 loss: 0.6656 
Train_acc: 0.7333\n", - "Epoch: 53 loss: 0.6654 Train_acc: 0.7333\n", - "Epoch: 54 loss: 0.6653 Train_acc: 0.7333\n", - "Epoch: 55 loss: 0.6652 Train_acc: 0.7333\n", - "Epoch: 56 loss: 0.6650 Train_acc: 0.7333\n", - "Epoch: 57 loss: 0.6648 Train_acc: 0.7333\n", - "Epoch: 58 loss: 0.6647 Train_acc: 0.7333\n", - "Epoch: 59 loss: 0.6645 Train_acc: 0.7333\n", - "Epoch: 60 loss: 0.6644 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 61 loss: 0.6642 Train_acc: 0.7333\n", - "Epoch: 62 loss: 0.6640 Train_acc: 0.7333\n", - "Epoch: 63 loss: 0.6638 Train_acc: 0.7333\n", - "Epoch: 64 loss: 0.6637 Train_acc: 0.7333\n", - "Epoch: 65 loss: 0.6635 Train_acc: 0.7333\n", - "Epoch: 66 loss: 0.6633 Train_acc: 0.7333\n", - "Epoch: 67 loss: 0.6631 Train_acc: 0.7333\n", - "Epoch: 68 loss: 0.6629 Train_acc: 0.7333\n", - "Epoch: 69 loss: 0.6627 Train_acc: 0.7333\n", - "Epoch: 70 loss: 0.6624 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 71 loss: 0.6622 Train_acc: 0.7333\n", - "Epoch: 72 loss: 0.6620 Train_acc: 0.7333\n", - "Epoch: 73 loss: 0.6618 Train_acc: 0.7333\n", - "Epoch: 74 loss: 0.6615 Train_acc: 0.7333\n", - "Epoch: 75 loss: 0.6613 Train_acc: 0.7333\n", - "Epoch: 76 loss: 0.6610 Train_acc: 0.7333\n", - "Epoch: 77 loss: 0.6608 Train_acc: 0.7333\n", - "Epoch: 78 loss: 0.6605 Train_acc: 0.7333\n", - "Epoch: 79 loss: 0.6603 Train_acc: 0.7333\n", - "Epoch: 80 loss: 0.6600 Train_acc: 0.7333\n", - "Test_acc: 0.2500\n", - "Epoch: 81 loss: 0.6597 Train_acc: 0.7333\n", - "Epoch: 82 loss: 0.6594 Train_acc: 0.7333\n", - "Epoch: 83 loss: 0.6592 Train_acc: 0.7333\n", - "Epoch: 84 loss: 0.6589 Train_acc: 0.7333\n", - "Epoch: 85 loss: 0.6586 Train_acc: 0.7333\n", - "Epoch: 86 loss: 0.6583 Train_acc: 0.7667\n", - "Epoch: 87 loss: 0.6580 Train_acc: 0.7667\n", - "Epoch: 88 loss: 0.6577 Train_acc: 0.7667\n", - "Epoch: 89 loss: 0.6574 Train_acc: 0.7667\n", - "Epoch: 90 loss: 0.6571 Train_acc: 0.7667\n", - "Test_acc: 0.2500\n", - "Epoch: 91 loss: 0.6568 Train_acc: 0.7667\n", - "Epoch: 
92 loss: 0.6565 Train_acc: 0.7667\n", - "Epoch: 93 loss: 0.6562 Train_acc: 0.7667\n", - "Epoch: 94 loss: 0.6559 Train_acc: 0.7667\n", - "Epoch: 95 loss: 0.6556 Train_acc: 0.7667\n", - "Epoch: 96 loss: 0.6553 Train_acc: 0.8000\n", - "Epoch: 97 loss: 0.6550 Train_acc: 0.8000\n", - "Epoch: 98 loss: 0.6548 Train_acc: 0.8000\n", - "Epoch: 99 loss: 0.6545 Train_acc: 0.8000\n", - "Epoch: 100 loss: 0.6543 Train_acc: 0.8000\n", - "Test_acc: 0.2500\n", - "Epoch: 101 loss: 0.6541 Train_acc: 0.8000\n", - "Epoch: 102 loss: 0.6539 Train_acc: 0.8333\n", - "Epoch: 103 loss: 0.6537 Train_acc: 0.8333\n", - "Epoch: 104 loss: 0.6535 Train_acc: 0.8333\n", - "Epoch: 105 loss: 0.6534 Train_acc: 0.8333\n", - "Epoch: 106 loss: 0.6532 Train_acc: 0.8333\n", - "Epoch: 107 loss: 0.6531 Train_acc: 0.8333\n", - "Epoch: 108 loss: 0.6529 Train_acc: 0.8333\n", - "Epoch: 109 loss: 0.6528 Train_acc: 0.8333\n", - "Epoch: 110 loss: 0.6526 Train_acc: 0.8333\n", - "Test_acc: 0.2500\n", - "Epoch: 111 loss: 0.6525 Train_acc: 0.8333\n", - "Epoch: 112 loss: 0.6524 Train_acc: 0.8333\n", - "Epoch: 113 loss: 0.6522 Train_acc: 0.8333\n", - "Epoch: 114 loss: 0.6521 Train_acc: 0.8333\n", - "Epoch: 115 loss: 0.6520 Train_acc: 0.8333\n", - "Epoch: 116 loss: 0.6518 Train_acc: 0.8333\n", - "Epoch: 117 loss: 0.6517 Train_acc: 0.8333\n", - "Epoch: 118 loss: 0.6516 Train_acc: 0.8333\n", - "Epoch: 119 loss: 0.6515 Train_acc: 0.8333\n", - "Epoch: 120 loss: 0.6513 Train_acc: 0.8333\n", - "Test_acc: 0.2500\n", - "Epoch: 121 loss: 0.6512 Train_acc: 0.8333\n", - "Epoch: 122 loss: 0.6511 Train_acc: 0.8333\n", - "Epoch: 123 loss: 0.6510 Train_acc: 0.8333\n", - "Epoch: 124 loss: 0.6509 Train_acc: 0.8333\n", - "Epoch: 125 loss: 0.6508 Train_acc: 0.8333\n", - "Epoch: 126 loss: 0.6507 Train_acc: 0.8333\n", - "Epoch: 127 loss: 0.6506 Train_acc: 0.8333\n", - "Epoch: 128 loss: 0.6505 Train_acc: 0.8333\n", - "Epoch: 129 loss: 0.6504 Train_acc: 0.8333\n", - "Epoch: 130 loss: 0.6504 Train_acc: 0.8333\n", - "Test_acc: 0.2500\n", - "Epoch: 
131 loss: 0.6503 Train_acc: 0.8333\n", - "Epoch: 132 loss: 0.6502 Train_acc: 0.8333\n", - "Epoch: 133 loss: 0.6501 Train_acc: 0.8333\n", - "Epoch: 134 loss: 0.6500 Train_acc: 0.8333\n", - "Epoch: 135 loss: 0.6500 Train_acc: 0.8333\n", - "Epoch: 136 loss: 0.6499 Train_acc: 0.8333\n", - "Epoch: 137 loss: 0.6498 Train_acc: 0.8333\n", - "Epoch: 138 loss: 0.6498 Train_acc: 0.8333\n", - "Epoch: 139 loss: 0.6497 Train_acc: 0.8333\n", - "Epoch: 140 loss: 0.6496 Train_acc: 0.8333\n", - "Test_acc: 0.2500\n", - "Epoch: 141 loss: 0.6495 Train_acc: 0.8333\n", - "Epoch: 142 loss: 0.6495 Train_acc: 0.8333\n", - "Epoch: 143 loss: 0.6494 Train_acc: 0.8333\n", - "Epoch: 144 loss: 0.6494 Train_acc: 0.8333\n", - "Epoch: 145 loss: 0.6493 Train_acc: 0.8333\n", - "Epoch: 146 loss: 0.6492 Train_acc: 0.8333\n", - "Epoch: 147 loss: 0.6492 Train_acc: 0.8333\n", - "Epoch: 148 loss: 0.6491 Train_acc: 0.8333\n", - "Epoch: 149 loss: 0.6491 Train_acc: 0.8333\n", - "Epoch: 150 loss: 0.6490 Train_acc: 0.8333\n", - "Test_acc: 0.2500\n", - "Epoch: 151 loss: 0.6490 Train_acc: 0.8333\n", - "Epoch: 152 loss: 0.6489 Train_acc: 0.8333\n", - "Epoch: 153 loss: 0.6489 Train_acc: 0.8333\n", - "Epoch: 154 loss: 0.6488 Train_acc: 0.8333\n", - "Epoch: 155 loss: 0.6488 Train_acc: 0.8333\n", - "Epoch: 156 loss: 0.6487 Train_acc: 0.8333\n", - "Epoch: 157 loss: 0.6487 Train_acc: 0.8333\n", - "Epoch: 158 loss: 0.6486 Train_acc: 0.8333\n", - "Epoch: 159 loss: 0.6486 Train_acc: 0.8333\n", - "Epoch: 160 loss: 0.6485 Train_acc: 0.8333\n", + "Epoch: 1 loss: 0.7247 Train_acc: 0.3000\n", + "Epoch: 2 loss: 0.7226 Train_acc: 0.6667\n", + "Epoch: 3 loss: 0.7203 Train_acc: 0.6667\n", + "Epoch: 4 loss: 0.7167 Train_acc: 0.6667\n", + "Epoch: 5 loss: 0.7115 Train_acc: 0.6667\n", + "Epoch: 6 loss: 0.7057 Train_acc: 0.7000\n", + "Epoch: 7 loss: 0.6999 Train_acc: 0.7000\n", + "Epoch: 8 loss: 0.6931 Train_acc: 0.7000\n", + "Epoch: 9 loss: 0.6874 Train_acc: 0.7000\n", + "Epoch: 10 loss: 0.6814 Train_acc: 0.7000\n", "Test_acc: 
0.2500\n", - "Epoch: 161 loss: 0.6485 Train_acc: 0.8333\n", - "Epoch: 162 loss: 0.6484 Train_acc: 0.8333\n", - "Epoch: 163 loss: 0.6484 Train_acc: 0.8333\n", - "Epoch: 164 loss: 0.6483 Train_acc: 0.8333\n", - "Epoch: 165 loss: 0.6483 Train_acc: 0.8333\n", - "Epoch: 166 loss: 0.6483 Train_acc: 0.8333\n", - "Epoch: 167 loss: 0.6482 Train_acc: 0.8333\n", - "Epoch: 168 loss: 0.6482 Train_acc: 0.8333\n", - "Epoch: 169 loss: 0.6481 Train_acc: 0.8333\n", - "Epoch: 170 loss: 0.6481 Train_acc: 0.8333\n", + "Epoch: 11 loss: 0.6760 Train_acc: 0.7333\n", + "Epoch: 12 loss: 0.6716 Train_acc: 0.7333\n", + "Epoch: 13 loss: 0.6669 Train_acc: 0.7333\n", + "Epoch: 14 loss: 0.6624 Train_acc: 0.7000\n", + "Epoch: 15 loss: 0.6586 Train_acc: 0.7000\n", + "Epoch: 16 loss: 0.6551 Train_acc: 0.7667\n", + "Epoch: 17 loss: 0.6528 Train_acc: 0.7667\n", + "Epoch: 18 loss: 0.6508 Train_acc: 0.7667\n", + "Epoch: 19 loss: 0.6491 Train_acc: 0.7667\n", + "Epoch: 20 loss: 0.6479 Train_acc: 0.7667\n", "Test_acc: 0.2500\n", - "Epoch: 171 loss: 0.6481 Train_acc: 0.8333\n", - "Epoch: 172 loss: 0.6480 Train_acc: 0.8333\n", - "Epoch: 173 loss: 0.6480 Train_acc: 0.8333\n", - "Epoch: 174 loss: 0.6480 Train_acc: 0.8333\n", - "Epoch: 175 loss: 0.6479 Train_acc: 0.8333\n", - "Epoch: 176 loss: 0.6479 Train_acc: 0.8333\n", - "Epoch: 177 loss: 0.6479 Train_acc: 0.8333\n", - "Epoch: 178 loss: 0.6478 Train_acc: 0.8333\n", - "Epoch: 179 loss: 0.6478 Train_acc: 0.8333\n", - "Epoch: 180 loss: 0.6477 Train_acc: 0.8333\n", + "Epoch: 21 loss: 0.6463 Train_acc: 0.7667\n", + "Epoch: 22 loss: 0.6452 Train_acc: 0.7667\n", + "Epoch: 23 loss: 0.6438 Train_acc: 0.7333\n", + "Epoch: 24 loss: 0.6433 Train_acc: 0.7333\n", + "Epoch: 25 loss: 0.6419 Train_acc: 0.7667\n", + "Epoch: 26 loss: 0.6412 Train_acc: 0.7667\n", + "Epoch: 27 loss: 0.6397 Train_acc: 0.7667\n", + "Epoch: 28 loss: 0.6391 Train_acc: 0.7667\n", + "Epoch: 29 loss: 0.6379 Train_acc: 0.7667\n", + "Epoch: 30 loss: 0.6369 Train_acc: 0.7667\n", "Test_acc: 0.2500\n", - 
"Epoch: 181 loss: 0.6477 Train_acc: 0.8333\n", - "Epoch: 182 loss: 0.6477 Train_acc: 0.8333\n", - "Epoch: 183 loss: 0.6477 Train_acc: 0.8333\n", - "Epoch: 184 loss: 0.6476 Train_acc: 0.8333\n", - "Epoch: 185 loss: 0.6476 Train_acc: 0.8333\n", - "Epoch: 186 loss: 0.6476 Train_acc: 0.8333\n", - "Epoch: 187 loss: 0.6475 Train_acc: 0.8333\n", - "Epoch: 188 loss: 0.6475 Train_acc: 0.8333\n", - "Epoch: 189 loss: 0.6475 Train_acc: 0.8333\n", - "Epoch: 190 loss: 0.6474 Train_acc: 0.8333\n", + "Epoch: 31 loss: 0.6360 Train_acc: 0.7667\n", + "Epoch: 32 loss: 0.6347 Train_acc: 0.7667\n", + "Epoch: 33 loss: 0.6333 Train_acc: 0.7667\n", + "Epoch: 34 loss: 0.6317 Train_acc: 0.7667\n", + "Epoch: 35 loss: 0.6298 Train_acc: 0.7667\n", + "Epoch: 36 loss: 0.6282 Train_acc: 0.7667\n", + "Epoch: 37 loss: 0.6272 Train_acc: 0.7667\n", + "Epoch: 38 loss: 0.6267 Train_acc: 0.8000\n", + "Epoch: 39 loss: 0.6265 Train_acc: 0.8000\n", + "Epoch: 40 loss: 0.6262 Train_acc: 0.8000\n", "Test_acc: 0.2500\n", - "Epoch: 191 loss: 0.6474 Train_acc: 0.8333\n", - "Epoch: 192 loss: 0.6474 Train_acc: 0.8333\n", - "Epoch: 193 loss: 0.6474 Train_acc: 0.8333\n", - "Epoch: 194 loss: 0.6473 Train_acc: 0.8333\n", - "Epoch: 195 loss: 0.6473 Train_acc: 0.8333\n", - "Epoch: 196 loss: 0.6473 Train_acc: 0.8333\n", - "Epoch: 197 loss: 0.6472 Train_acc: 0.8333\n", - "Epoch: 198 loss: 0.6472 Train_acc: 0.8333\n", - "Epoch: 199 loss: 0.6472 Train_acc: 0.8333\n", - "Epoch: 200 loss: 0.6472 Train_acc: 0.8333\n", + "Epoch: 41 loss: 0.6260 Train_acc: 0.8000\n", + "Epoch: 42 loss: 0.6259 Train_acc: 0.8000\n", + "Epoch: 43 loss: 0.6260 Train_acc: 0.8000\n", + "Epoch: 44 loss: 0.6261 Train_acc: 0.7667\n", + "Epoch: 45 loss: 0.6260 Train_acc: 0.7667\n", + "Epoch: 46 loss: 0.6258 Train_acc: 0.8000\n", + "Epoch: 47 loss: 0.6255 Train_acc: 0.8000\n", + "Epoch: 48 loss: 0.6252 Train_acc: 0.8000\n", + "Epoch: 49 loss: 0.6249 Train_acc: 0.8000\n", + "Epoch: 50 loss: 0.6245 Train_acc: 0.8000\n", "Test_acc: 0.2500\n" ] } ], "source": [ 
"test_interval = 10\n", - "num_epochs = 200\n", + "num_epochs = 50\n", "for epoch_i in range(1, num_epochs + 1):\n", " epoch_loss = []\n", " model.train()\n",