Skip to content

Commit

Permalink
Merge pull request #50 from OpenFreeEnergy/twin_star
Browse files Browse the repository at this point in the history
Twin star Network
  • Loading branch information
RiesBen authored Jun 11, 2024
2 parents 65b9316 + 6a043ca commit 193a55d
Show file tree
Hide file tree
Showing 18 changed files with 187 additions and 31 deletions.
Binary file modified .img/network_layouts.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions src/konnektor/network_planners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
NNodeEdgesNetworkGenerator)
## Starmap Like Networks
from .generators.star_network_generator import StarNetworkGenerator, RadialLigandNetworkPlanner
from .generators.twin_star_network_generator import TwinStarNetworkGenerator
from .generators.clustered_network_generator import StarrySkyNetworkGenerator

## MST like Networks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def concatenate_networks(self, nodesA: list[int], nodesB: list[int],
nx.Graph
the resulting graph, containing both subgraphs.
"""

# The initial "weights" are Scores, which need to be translated to weights.
weights = list(map(lambda x: 1-x, weights))
wedges_map = {(e[0], e[1]): w for e, w in zip(edges, weights)}
wedges = [(e[0], e[1], w) for e, w in zip(edges, weights)]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def _translate_input(self, edges: List[Tuple[int, int]],
# build Edges:
w_edges = []
nodes = []
# The initial "weights" are Scores, which need to be translated to weights.
weights = list(map(lambda x: 1-x, weights))
for e, w in zip(edges, weights):
w_edges.append((e[0], e[1], w))
nodes.extend(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def generate_network(self, edges: list[tuple[int, int]],
weights: list[float], n_edges:int=None) -> nx.Graph:
wedges = []
nodes = []
# The initial "weights" are Scores, which need to be translated to weights.
weights = list(map(lambda x: 1-x, weights))
for edge, weight in zip(edges, weights):
wedges.append([edge[0], edge[1], weight])
nodes.extend(list(edge))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def generate_network(self, edges: list[tuple[int, int]],

w_edges = []
nodes = []
# The initial "weights" are Scores, which need to be translated to weights.
weights = list(map(lambda x: 1-x, weights))
for e, w in zip(edges, weights):
w_edges.append((e[0], e[1], w))
nodes.extend(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

from typing import Callable
from typing import Callable, Iterable

import networkx as nx
import numpy as np
Expand All @@ -12,12 +12,15 @@

class RadialNetworkAlgorithm(_AbstractNetworkAlgorithm):

def __init__(self, metric_aggregation_method: Callable = None):
def __init__(self, metric_aggregation_method: Callable = None, n_centers: int = 1):
self.metric_aggregation_method = metric_aggregation_method
self.n_centers = n_centers

def _central_lig_selection(self, edges: list[tuple[int, int]],
weights: list[float]) -> int:
weights: list[float]) -> Iterable[int]:
nodes = set([n for e in edges for n in e])
# The initial "weights" are Scores, which need to be translated to weights.
weights = list(map(lambda x: 1-x, weights))
edge_weights = list(zip(edges, weights))

node_scores = {n: [e_s[1] for e_s in edge_weights if (n in e_s[0])] for
Expand All @@ -33,9 +36,8 @@ def _central_lig_selection(self, edges: list[tuple[int, int]],
aggregated_scores = list(
map(lambda x: (x[0], np.sum(x[1])), filtered_node_scores.items()))
sorted_node_scores = list(sorted(aggregated_scores, key=lambda x: x[1]))

opt_node = sorted_node_scores[0]
return opt_node
opt_nodes = sorted_node_scores[:self.n_centers]
return opt_nodes

def generate_network(self, edges: list[tuple[int, int]],
weights: list[float],
Expand Down Expand Up @@ -75,12 +77,16 @@ def generate_network(self, edges: list[tuple[int, int]],
"""

if (central_node is None):
central_node, avg_score = self._central_lig_selection(edges=edges,
weights=weights)
central_nodes = self._central_lig_selection(edges=edges,
weights=weights, )
elif isinstance(central_node, (SmallMoleculeComponent, str)):
central_nodes = [(central_node, 1)]
else:
raise ValueError("invalide central node type: "+str(type(central_node)))

wedges = []
for edge, weight in zip(edges, weights):
if (central_node in edge):
if any(central_node in edge for central_node, avg_score in central_nodes):
wedges.append([edge[0], edge[1], weight])

# Todo: Warning if something was not connected to the central ligand?
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

from typing import Iterable

from gufe import Component, LigandNetwork, AtomMapper

from konnektor.network_planners._networkx_implementations import \
RadialNetworkAlgorithm
from ._abstract_network_generator import NetworkGenerator
from .maximal_network_generator import MaximalNetworkGenerator


class TwinStarNetworkGenerator(NetworkGenerator):

def __init__(self, mapper: AtomMapper, scorer, n_centers: int =2,
n_processes: int = 1,
_initial_edge_lister: NetworkGenerator = None):
"""
The Twin Star Ligand Network Planner , set's n ligands ligand into the center of a graph and connects all other ligands to each center.
Parameters
----------
mapper : AtomMapper
the atom mapper is required, to define the connection between two ligands.
scorer : AtomMappingScorer
scoring function evaluating an atom mapping, and giving a score between [0,1].
n_centers: int, optional
the number of centers in the network. (default: 2)
n_processes: int, optional
number of processes that can be used for the network generation. (default: 1)
_initial_edge_lister: LigandNetworkPlanner, optional
this LigandNetworkPlanner is used to give the initial set of edges. For standard usage, the Maximal NetworPlanner is used.
However in large scale approaches, it might be interesting to use the heuristicMaximalNetworkPlanner.. (default: MaximalNetworkPlanner)
"""
if _initial_edge_lister is None:
_initial_edge_lister = MaximalNetworkGenerator(mapper=mapper,
scorer=scorer,
n_processes=n_processes)

super().__init__(mapper=mapper, scorer=scorer,
network_generator=RadialNetworkAlgorithm(n_centers=n_centers),
n_processes=n_processes,
_initial_edge_lister=_initial_edge_lister)

self.n_centers = n_centers


def generate_ligand_network(self, components: Iterable[Component]) -> LigandNetwork:
"""
generate a twin star map network for the given compounds.
Parameters
----------
components: Iterable[Component]
the components to be used for the LigandNetwork
Returns
-------
LigandNetwork
a star like network.
"""
components = list(components)


# Full Graph Construction
initial_network = self._initial_edge_lister.generate_ligand_network(
components=components)
mappings = initial_network.edges

# Translate Mappings to graphable:
edge_map = {(components.index(m.componentA),
components.index(m.componentB)): m for m in mappings}
edges = list(sorted(edge_map.keys()))
weights = [edge_map[k].annotations['score'] for k in edges]

rg = self.network_generator.generate_network(edges=edges,
weights=weights)
selected_mappings = [edge_map[k] for k in rg.edges]


return LigandNetwork(edges=selected_mappings, nodes=components)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import pytest
import networkx as nx
import numpy as np

Expand All @@ -14,12 +15,28 @@ def test_radial_network_generation_find_center(nine_mols_edges):
weights = [e[2] for e in nine_mols_edges]

gen = RadialNetworkAlgorithm()
c_node, avg_weight = gen._central_lig_selection(edges, weights)
c_node, avg_weight = gen._central_lig_selection(edges, weights)[0]

assert c_node == "lig_14" # Check central node
assert np.round(avg_weight, 2) == 2.86
assert c_node == "lig_10" # Check central node
np.testing.assert_allclose(avg_weight, 2.055, rtol=0.01)


@pytest.mark.parametrize('n_centers', [2,3,4])
def test_radial_network_generation_find_centers(nine_mols_edges, n_centers):
edges = [(e[0], e[1]) for e in nine_mols_edges]
weights = [e[2] for e in nine_mols_edges]

gen = RadialNetworkAlgorithm(n_centers=n_centers)
centers = gen._central_lig_selection(edges, weights)

print(centers)
expected_centers = ['lig_10', 'lig_8', 'lig_9', 'lig_16']
expected_weights = [ 2.0551095953189917, 3.6524873109359146, 4.270400420741822, 4.543886935944357]
for i, (cID, avg_weight) in enumerate(centers):
print(cID, avg_weight)
assert cID == expected_centers[i] # Check central node
np.testing.assert_allclose(avg_weight, expected_weights[i], rtol=0.01)

def test_radial_network_generation_without_center(nine_mols_edges):
edges = [(e[0], e[1]) for e in nine_mols_edges]
weights = [e[2] for e in nine_mols_edges]
Expand All @@ -29,7 +46,7 @@ def test_radial_network_generation_without_center(nine_mols_edges):
g = gen.generate_network(edges, weights)

assert len(nodes) - 1 == len(g.edges)
assert all(["lig_14" in e for e in g.edges]) # check central node
assert all(["lig_10" in e for e in g.edges]) # check central node
assert all([e[0] != e[1] for e in g.edges]) # No self connectivity
assert isinstance(g, nx.Graph)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import itertools

import numpy as np
from gufe import LigandNetwork
from sklearn.cluster import KMeans
from konnektor.network_analysis import get_is_connected
from konnektor.network_analysis import get_is_connected, get_graph_score
from konnektor.network_planners.generators.clustered_network_generator import \
ClusteredNetworkGenerator
from konnektor.network_tools.clustering.component_diversity_clustering import ComponentsDiversityClusterer
Expand All @@ -15,7 +14,7 @@
def test_clustered_network_planner():
n_compounds = 40
components, genMapper, genScorer = build_random_dataset(
n_compounds=n_compounds)
n_compounds=n_compounds, rand_seed=42)

from konnektor.network_planners import (RadialLigandNetworkPlanner,
MstConcatenator)
Expand All @@ -37,3 +36,5 @@ def test_clustered_network_planner():
assert len(planner.clusters) == 3
assert len(ligand_network.edges) == 3*((n_compounds//3)-1) + (3 * concatenator.n_connecting_edges) + 1
assert get_is_connected(ligand_network)

np.testing.assert_allclose(get_graph_score(ligand_network), 25.708691, rtol=0.01)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

from konnektor.network_analysis import get_is_connected, get_node_number_cycles
import numpy as np
from konnektor.network_analysis import get_is_connected, get_node_number_cycles, get_graph_score
from konnektor.network_planners import CyclicNetworkGenerator
from konnektor.utils.toy_data import build_random_dataset

Expand All @@ -10,7 +11,7 @@ def test_cyclic_network_planner():
n_compounds = 8
ncycles = 2
components, genMapper, genScorer = build_random_dataset(
n_compounds=n_compounds)
n_compounds=n_compounds, rand_seed=42)

planner = CyclicNetworkGenerator(
mapper=genMapper, scorer=genScorer, cycle_sizes=3,
Expand All @@ -24,3 +25,5 @@ def test_cyclic_network_planner():
assert get_is_connected(network)
nnode_cycles = get_node_number_cycles(network)
assert all(v >= ncycles for k, v in nnode_cycles.items())

np.testing.assert_allclose(get_graph_score(network), 10.347529, rtol=0.01)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import gufe
import networkx as nx
import numpy as np
import pytest
from gufe import LigandNetwork

Expand All @@ -11,6 +12,7 @@
atom_mapping_basic_test_files,
mol_from_smiles, genScorer,
GenAtomMapper, ErrorMapper)
from konnektor.network_analysis import get_graph_score


def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files):
Expand All @@ -25,7 +27,7 @@ def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files):

assert isinstance(network, LigandNetwork)
assert list(network.edges)

np.testing.assert_allclose(get_graph_score(network), 0.066667, rtol=0.001)

@pytest.fixture(scope='session')
def minimal_spanning_network(toluene_vs_others):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import numpy as np
from gufe import LigandNetwork

from konnektor.network_analysis import get_is_connected
from konnektor.network_analysis import get_is_connected, get_graph_score
from konnektor.network_planners import NNodeEdgesNetworkGenerator
from konnektor.tests.network_planners.conf import (
atom_mapping_basic_test_files,
Expand All @@ -26,3 +27,5 @@ def test_nedges_network_mappers(atom_mapping_basic_test_files):
assert len(network.nodes) == len(ligands)
assert len(network.edges) <= len(ligands) * 2
assert get_is_connected(network)

np.testing.assert_allclose(get_graph_score(network), 0.066667, rtol=0.01)
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/konnektor

import gufe
import numpy as np
import networkx as nx
import pytest

import gufe
from gufe import LigandNetwork

from konnektor.network_planners import \
Expand All @@ -12,6 +14,7 @@
atom_mapping_basic_test_files,
mol_from_smiles, genScorer,
GenAtomMapper, ErrorMapper)
from konnektor.network_analysis import get_graph_score


def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files):
Expand All @@ -27,6 +30,7 @@ def test_minimal_spanning_network_mappers(atom_mapping_basic_test_files):

assert isinstance(network, LigandNetwork)
assert list(network.edges)
np.testing.assert_allclose(get_graph_score(network), 0.066667, rtol=0.01)


@pytest.fixture(scope='session')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@pytest.mark.parametrize('as_list', [False])
def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others,
def test_star_network(atom_mapping_basic_test_files, toluene_vs_others,
as_list):
toluene, others = toluene_vs_others
central_ligand_name = 'toluene'
Expand All @@ -38,7 +38,7 @@ def test_radial_network(atom_mapping_basic_test_files, toluene_vs_others,
for mapping in network.edges)


def test_radial_network_with_scorer(toluene_vs_others):
def test_star_network_with_scorer(toluene_vs_others):
toluene, others = toluene_vs_others

mapper = GenAtomMapper()
Expand All @@ -58,7 +58,7 @@ def test_radial_network_with_scorer(toluene_vs_others):
edge.componentA_to_componentB)


def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others):
def test_star_network_multiple_mappers_no_scorer(toluene_vs_others):
toluene, others = toluene_vs_others
# in this one, we should always take the bad mapper
mapper = BadMapper()
Expand All @@ -73,7 +73,7 @@ def test_radial_network_multiple_mappers_no_scorer(toluene_vs_others):
assert edge.componentA_to_componentB == {0: 0}


def test_radial_network_failure(atom_mapping_basic_test_files):
def test_star_network_failure(atom_mapping_basic_test_files):
nigel = SmallMoleculeComponent(mol_from_smiles('N'))

mapper = ErrorMapper()
Expand Down
Loading

0 comments on commit 193a55d

Please sign in to comment.