diff --git a/hydragnn/models/Base.py b/hydragnn/models/Base.py
index 665fb07e7..9ab828df3 100644
--- a/hydragnn/models/Base.py
+++ b/hydragnn/models/Base.py
@@ -18,6 +18,8 @@ import sys
 
 from hydragnn.utils.distributed import get_device
 
+import inspect
+
 
 class Base(Module):
     def __init__(
@@ -152,23 +154,48 @@ def _init_node_conv(self):
         if len(node_feature_ind) == 0:
             return
         # In this part, each head has same number of convolutional layers, but can have different output dimension
-        self.convs_node_hidden.append(
-            self.get_conv(self.hidden_dim, self.hidden_dim_node[0])
-        )
-        self.batch_norms_node_hidden.append(BatchNorm(self.hidden_dim_node[0]))
-        for ilayer in range(self.num_conv_layers_node - 1):
+        if "last_layer" in inspect.signature(self.get_conv).parameters:
             self.convs_node_hidden.append(
                 self.get_conv(
-                    self.hidden_dim_node[ilayer], self.hidden_dim_node[ilayer + 1]
+                    self.hidden_dim, self.hidden_dim_node[0], last_layer=False
                 )
             )
+        else:
+            self.convs_node_hidden.append(
+                self.get_conv(self.hidden_dim, self.hidden_dim_node[0])
+            )
+        self.batch_norms_node_hidden.append(BatchNorm(self.hidden_dim_node[0]))
+        for ilayer in range(self.num_conv_layers_node - 1):
+            # This check is needed because the "get_conv" method of SCFStack takes one additional argument called last_layer
+            if "last_layer" in inspect.signature(self.get_conv).parameters:
+                self.convs_node_hidden.append(
+                    self.get_conv(
+                        self.hidden_dim_node[ilayer],
+                        self.hidden_dim_node[ilayer + 1],
+                        last_layer=False,
+                    )
+                )
+            else:
+                self.convs_node_hidden.append(
+                    self.get_conv(
+                        self.hidden_dim_node[ilayer], self.hidden_dim_node[ilayer + 1]
+                    )
+                )
             self.batch_norms_node_hidden.append(
                 BatchNorm(self.hidden_dim_node[ilayer + 1])
             )
         for ihead in node_feature_ind:
-            self.convs_node_output.append(
-                self.get_conv(self.hidden_dim_node[-1], self.head_dims[ihead])
-            )
+            # This check is needed because the "get_conv" method of SCFStack takes one additional argument called last_layer
+            if "last_layer" in inspect.signature(self.get_conv).parameters:
+                self.convs_node_output.append(
+                    self.get_conv(
+                        self.hidden_dim_node[-1], self.head_dims[ihead], last_layer=True
+                    )
+                )
+            else:
+                self.convs_node_output.append(
+                    self.get_conv(self.hidden_dim_node[-1], self.head_dims[ihead])
+                )
             self.batch_norms_node_output.append(BatchNorm(self.head_dims[ihead]))
 
     def _multihead(self):
@@ -277,9 +304,10 @@ def forward(self, data):
             else:
                 if self.node_NN_type == "conv":
                     for conv, batch_norm in zip(headloc[0::2], headloc[1::2]):
-                        x_node = self.activation_function(
-                            batch_norm(conv(x=x, edge_index=data.edge_index))
-                        )
+                        c, pos = conv(x=x, pos=pos, **conv_args)
+                        c = batch_norm(c)
+                        x = self.activation_function(c)
+                    x_node = x
                 else:
                     x_node = headloc(x=x, batch=data.batch)
             outputs.append(x_node)
diff --git a/hydragnn/models/DIMEStack.py b/hydragnn/models/DIMEStack.py
index e36d70988..1949d3406 100644
--- a/hydragnn/models/DIMEStack.py
+++ b/hydragnn/models/DIMEStack.py
@@ -102,6 +102,7 @@ def get_conv(self, input_dim, output_dim):
             out_channels=output_dim,
             num_layers=1,
             act=SiLU(),
+            output_initializer="glorot_orthogonal",
         )
         return Sequential(
             "x, pos, rbf, sbf, i, j, idx_kj, idx_ji",
diff --git a/tests/inputs/ci_conv_head.json b/tests/inputs/ci_conv_head.json
new file mode 100644
index 000000000..c697b4a64
--- /dev/null
+++ b/tests/inputs/ci_conv_head.json
@@ -0,0 +1,79 @@
+{
+  "Verbosity": {
+    "level": 0
+  },
+  "Dataset": {
+    "name": "unit_test_singlehead",
+    "format": "unit_test",
+    "compositional_stratified_splitting": true,
+    "rotational_invariance": false,
+    "path": {
+      "train": "dataset/unit_test_singlehead_train",
+      "test": "dataset/unit_test_singlehead_test",
+      "validate": "dataset/unit_test_singlehead_validate"
+    },
+    "node_features": {
+      "name": ["x","x2","x3"],
+      "dim": [1, 1, 1],
+      "column_index": [0, 6, 7]
+    },
+    "graph_features":{
+      "name": [ "sum_x_x2_x3"],
+      "dim": [1],
+      "column_index": [0]
+    }
+  },
+  "NeuralNetwork": {
+    "Architecture": {
+      "model_type": "PNA",
+      "radius": 2.0,
+      "max_neighbours": 100,
+      "num_gaussians": 50,
+      "envelope_exponent": 5,
+      "int_emb_size": 64,
+      "basis_emb_size": 8,
+      "out_emb_size": 128,
+      "num_after_skip": 2,
+      "num_before_skip": 1,
+      "num_radial": 6,
+      "num_spherical": 7,
+      "num_filters": 126,
+      "periodic_boundary_conditions": false,
+      "hidden_dim": 20,
+      "num_conv_layers": 2,
+      "output_heads": {
+        "node": {
+          "num_headlayers": 2,
+          "dim_headlayers": [20,10],
+          "type": "conv"
+        }
+      },
+      "task_weights": [1.0]
+    },
+    "Variables_of_interest": {
+      "input_node_features": [0],
+      "output_names": ["x"],
+      "output_index": [0],
+      "type": ["node"],
+      "denormalize_output": false
+    },
+    "Training": {
+      "num_epoch": 100,
+      "perc_train": 0.7,
+      "EarlyStopping": false,
+      "patience": 10,
+      "loss_function_type": "mse",
+      "batch_size": 32,
+      "Optimizer": {
+        "type": "AdamW",
+        "use_zero_redundancy": false,
+        "learning_rate": 0.02
+      }
+    }
+  },
+  "Visualization": {
+    "plot_init_solution": true,
+    "plot_hist_solution": false,
+    "create_plots": true
+  }
+}
diff --git a/tests/test_graphs.py b/tests/test_graphs.py
index 0ba9073f3..35f0b5d7a 100755
--- a/tests/test_graphs.py
+++ b/tests/test_graphs.py
@@ -139,6 +139,9 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
         thresholds["PNA"] = [0.10, 0.10]
     if use_lengths and "vector" in ci_input:
         thresholds["PNA"] = [0.2, 0.15]
+    if ci_input == "ci_conv_head.json":
+        thresholds["GIN"] = [0.25, 0.40]
+
     verbosity = 2
 
     for ihead in range(len(true_values)):
@@ -199,3 +202,10 @@ def pytest_train_equivariant_model(model_type, overwrite_data=False):
 @pytest.mark.parametrize("model_type", ["PNA"])
 def pytest_train_model_vectoroutput(model_type, overwrite_data=False):
     unittest_train_model(model_type, "ci_vectoroutput.json", True, overwrite_data)
+
+
+@pytest.mark.parametrize(
+    "model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "SchNet", "DimeNet", "EGNN"]
+)
+def pytest_train_model_conv_head(model_type, overwrite_data=False):
+    unittest_train_model(model_type, "ci_conv_head.json", False, overwrite_data)
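
Note on the inspect.signature dispatch used in Base._init_node_conv above: most stacks define get_conv(self, input_dim, output_dim), while SCFStack's get_conv takes an extra last_layer keyword, so the patch inspects the callable before deciding how to invoke it. Below is a minimal, standalone sketch of that pattern; the helper names (supports_last_layer, build_conv, and the two toy get_conv_* factories) are illustrative stand-ins, not part of the patch.

import inspect


def supports_last_layer(fn):
    # True when fn declares a parameter named "last_layer".
    return "last_layer" in inspect.signature(fn).parameters


def get_conv_plain(input_dim, output_dim):
    # Stand-in for the common get_conv(input_dim, output_dim) signature.
    return f"conv({input_dim}->{output_dim})"


def get_conv_scf(input_dim, output_dim, last_layer=False):
    # Stand-in for a get_conv that, like SCFStack's, also accepts last_layer.
    return f"conv({input_dim}->{output_dim}, last_layer={last_layer})"


def build_conv(get_conv, input_dim, output_dim, last_layer):
    # Forward last_layer only to implementations that declare it,
    # mirroring the if/else branches added in Base._init_node_conv.
    if supports_last_layer(get_conv):
        return get_conv(input_dim, output_dim, last_layer=last_layer)
    return get_conv(input_dim, output_dim)


print(build_conv(get_conv_plain, 20, 10, last_layer=True))  # conv(20->10)
print(build_conv(get_conv_scf, 20, 10, last_layer=True))    # conv(20->10, last_layer=True)

A helper along these lines would also collapse the three duplicated if/else blocks in _init_node_conv into single calls, which may be worth a follow-up cleanup.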