Skip to content

Commit

Permalink
additional tests added and pytest input fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
allaffa committed Dec 12, 2024
1 parent 2f8e65d commit 4604a63
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,16 @@ def pytest_train_model(mpnn_type, ci_input, overwrite_data=False):
unittest_train_model(mpnn_type, None, None, ci_input, False, overwrite_data)


# Test only models
# Test models that allow edge attributes
@pytest.mark.parametrize(
"mpnn_type",
["GAT", "PNA", "PNAPlus", "CGCNN", "SchNet", "DimeNet", "EGNN", "PNAEq", "PAINN"],
)
def pytest_train_model_lengths(mpnn_type, overwrite_data=False):
unittest_train_model(mpnn_type, None, None, "ci.json", True, overwrite_data)


# Test models that allow edge attributes with global attention mechanisms
@pytest.mark.parametrize(
"global_attn_engine",
["GPS"],
Expand All @@ -231,7 +240,7 @@ def pytest_train_model(mpnn_type, ci_input, overwrite_data=False):
"mpnn_type",
["GAT", "PNA", "PNAPlus", "CGCNN", "SchNet", "DimeNet", "EGNN", "PNAEq", "PAINN"],
)
def pytest_train_model_lengths(
def pytest_train_model_lengths_global_attention(
mpnn_type, global_attn_engine, global_attn_type, overwrite_data=False
):
unittest_train_model(
Expand Down

0 comments on commit 4604a63

Please sign in to comment.