diff --git a/gufe/network.py b/gufe/network.py index e02be31a..59929d68 100644 --- a/gufe/network.py +++ b/gufe/network.py @@ -1,7 +1,8 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/gufe -from typing import Iterable, Optional +from typing import Generator, Iterable, Optional +from typing_extensions import Self # Self is included in typing as of python 3.11 import networkx as nx from .tokenization import GufeTokenizable @@ -10,6 +11,7 @@ from .transformations import Transformation + class AlchemicalNetwork(GufeTokenizable): _edges: frozenset[Transformation] _nodes: frozenset[ChemicalSystem] @@ -102,7 +104,7 @@ def _to_dict(self) -> dict: "name": self.name} @classmethod - def _from_dict(cls, d: dict): + def _from_dict(cls, d: dict) -> Self: return cls(nodes=frozenset(d['nodes']), edges=frozenset(d['edges']), name=d.get('name')) @@ -116,6 +118,21 @@ def to_graphml(self) -> str: raise NotImplementedError @classmethod - def from_graphml(cls, str): + def from_graphml(cls, str) -> Self: """Currently not implemented""" raise NotImplementedError + + @classmethod + def _from_nx_graph(cls, nx_graph) -> Self: + """Create an alchemical network from a networkx representation.""" + chemical_systems = [n for n in nx_graph.nodes()] + transformations = [e[2]['object'] for e in nx_graph.edges(data=True)] + return cls(nodes=chemical_systems, edges=transformations) + + def connected_subgraphs(self) -> Generator[Self, None, None]: + """Return a generator of all connected subgraphs of the alchemical network.""" + node_groups = nx.weakly_connected_components(self.graph) + for node_group in node_groups: + nx_subgraph = self.graph.subgraph(node_group) + alc_subgraph = self._from_nx_graph(nx_subgraph) + yield(alc_subgraph) diff --git a/gufe/tests/conftest.py b/gufe/tests/conftest.py index 62487740..3080f469 100644 --- a/gufe/tests/conftest.py +++ b/gufe/tests/conftest.py @@ -256,9 +256,8 @@ def complex_equilibrium(solvated_complex): protocol=DummyProtocol(settings=DummyProtocol.default_settings()) ) - @pytest.fixture -def benzene_variants_star_map( +def benzene_variants_star_map_transformations( benzene, toluene, phenol, @@ -320,7 +319,15 @@ def benzene_variants_star_map( mapping=None, ) - return gufe.AlchemicalNetwork( - list(solvated_ligand_transformations.values()) - + list(solvated_complex_transformations.values()) - ) + return list(solvated_ligand_transformations.values()), list(solvated_complex_transformations.values()) + + +@pytest.fixture +def benzene_variants_star_map(benzene_variants_star_map_transformations): + solvated_ligand_transformations, solvated_complex_transformations = benzene_variants_star_map_transformations + return gufe.AlchemicalNetwork(solvated_ligand_transformations+solvated_complex_transformations) + +@pytest.fixture +def benzene_variants_ligand_star_map(benzene_variants_star_map_transformations): + solvated_ligand_transformations, _ = benzene_variants_star_map_transformations + return gufe.AlchemicalNetwork(solvated_ligand_transformations) diff --git a/gufe/tests/test_alchemicalnetwork.py b/gufe/tests/test_alchemicalnetwork.py index ef148bac..8231c51d 100644 --- a/gufe/tests/test_alchemicalnetwork.py +++ b/gufe/tests/test_alchemicalnetwork.py @@ -47,3 +47,28 @@ def test_connectivity(self, benzene_variants_star_map): else: edges = alnet.graph.edges(node) assert len(edges) == 0 + + def test_connected_subgraphs_multiple_subgraphs(self, benzene_variants_star_map): + """Identify two separate networks and one floating nodes as subgraphs.""" + # remove an edge to create a network w/ two subnetworks and one floating node + edge_list = [e for e in benzene_variants_star_map.edges] + alnet = benzene_variants_star_map.copy_with_replacements(edges=edge_list[:-1]) + + subgraphs = [subgraph for subgraph in alnet.connected_subgraphs()] + + assert set([len(subgraph.nodes) for subgraph in subgraphs]) == {6,7,1} + + # which graph has the removed node is not deterministic, so we just + # check that one graph is all-solvent and the other is all-protein + for subgraph in subgraphs: + components = [frozenset(n.components.keys()) for n in subgraph.nodes] + if {'solvent','protein','ligand'} in components: + assert set(components) == {frozenset({'solvent','protein','ligand'})} + else: + assert set(components) == {frozenset({'solvent','ligand'})} + + def test_connected_subgraphs_one_subgraph(self, benzene_variants_ligand_star_map): + """Return the same network if it only contains one connected component.""" + alnet = benzene_variants_ligand_star_map + subgraphs = [subgraph for subgraph in alnet.connected_subgraphs()] + assert subgraphs == [alnet]