Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to check if there is a route with a given set of starting molecules #34

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions syntheseus/search/analysis/starting_molecule_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Code related to starting molecules match metric (called exact set-wise match in FusionRetro)."""

from __future__ import annotations

import itertools
import typing
from collections.abc import Iterable
from typing import Optional

from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import ANDOR_NODE, AndNode, AndOrGraph, OrNode

T = typing.TypeVar("T")


def is_route_with_starting_mols(
graph: AndOrGraph,
starting_mols: set[Molecule],
forbidden_nodes: typing.Optional[set[ANDOR_NODE]] = None,
) -> bool:
"""Checks whether there is a route in graph matching the starting mols."""
forbidden_nodes = forbidden_nodes or set()
return _is_route_with_starting_mols(
graph,
graph.root_node,
starting_mols,
forbidden_nodes,
)


def split_into_subsets(A: list[T], k: int) -> Iterable[list[list[T]]]:
"""
Enumerate all possible ways to create k subsets from a list A such that none of the k subsets are empty,
the union of the sets is A, and the order of the subsets *does* matter.

This function is easiest to explain by example:

A single partition is just the list A itself.

>>> list(split_into_subsets([1, 2], 1))
>>> [[[1, 2]]]

For k=2, there are 4 possible subsets meaning 16 pairs of subsets,
but only 8 of them have non-empty subsets which include every element once.

>>> list(split_into_subsets([1, 2,], 2))
>>> [[[1], [2]], [[1], [1,2]], [[2], [1]], [[2], [1,2]], [[1,2], [1]], [[1,2], [2]]]
>>> [[[1], [2]], [[1], [1, 2]], [[2], [1]], [[2], [1, 2]], [[1, 2], [1]], [[1, 2], [2]], [[1, 2], [1, 2]]]

The implementation just uses itertools.combinations and itertools.products to enumerate all possible partions,
and simply rejects those which do not sum up to the entire set.
It is not very efficient for large A or large k, so use with caution.

NOTE: the efficiency of this method could definitely be improved later.
"""

# Check 1: elements of list are unique
assert len(set(A)) == len(A)

# Check 2: k is valid
assert k >= 1

# Base case: list is empty
if len(A) == 0:
return

# Iterate through all subsets
power_set_non_empty = itertools.chain.from_iterable(
itertools.combinations(A, r) for r in range(1, len(A) + 1)
)
for subsets in itertools.product(power_set_non_empty, repeat=k):
if set(itertools.chain.from_iterable(subsets)) == set(A):
yield [list(s) for s in subsets]


def _is_solvable_from_starting_mols(
graph: AndOrGraph,
starting_mols: set[Molecule],
forbidden_nodes: Optional[set[ANDOR_NODE]] = None,
) -> dict[ANDOR_NODE, bool]:
"""Get whether each node is solvable only from a specified set of starting molecules."""
forbidden_nodes = forbidden_nodes or set()

# Which nodes are solvable because they contain a starting molecule?
node_to_contains_start_mol = {
n: (isinstance(n, OrNode) and n.mol in starting_mols) for n in graph.nodes()
}
node_to_solvable = {n: False for n in graph.nodes()}

# Do passes through all nodes
update_happened = True
while update_happened:
update_happened = False
for n in graph.nodes():
successors_are_solvable = [node_to_solvable[c] for c in graph.successors(n)]
if n in forbidden_nodes:
new_solvable = False # regardless of successors, forbidden nodes are not solvable
elif isinstance(n, OrNode):
new_solvable = any(successors_are_solvable) or node_to_contains_start_mol[n]
elif isinstance(n, AndNode):
new_solvable = all(successors_are_solvable)
else:
raise ValueError

if new_solvable != node_to_solvable[n]:
node_to_solvable[n] = new_solvable
update_happened = True

return node_to_solvable


def _starting_mols_under_each_node(graph: AndOrGraph) -> dict[ANDOR_NODE, set[Molecule]]:
"""Get set of molecules reachable under each node in the graph."""

# Initialize to empty sets, except for nodes with purchasable mols
node_to_mols: dict[ANDOR_NODE, set[Molecule]] = {n: set() for n in graph.nodes()}
for n in graph.nodes():
if isinstance(n, OrNode):
node_to_mols[n].add(n.mol)
Comment on lines +115 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says "purchasable", but is this reflected in the code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch: I used to check whether a molecule is purchasable but removed it to make the method more general (i.e. the nodes marked purchasable in the graph don't need to be the starting molecules). I will fix this by changing the comment.


# Do passes through all nodes
update_happened = True
while update_happened:
update_happened = False
for n in graph.nodes():
for c in graph.successors(n):
if not (node_to_mols[c] <= node_to_mols[n]):
node_to_mols[n].update(node_to_mols[c])
update_happened = True

return node_to_mols


def _is_route_with_starting_mols(
graph: AndOrGraph,
start_node: OrNode,
starting_mols: set[Molecule],
forbidden_nodes: set[ANDOR_NODE],
node_to_solvable: Optional[dict[ANDOR_NODE, bool]] = None,
node_to_reachable_starting_mols: Optional[dict[ANDOR_NODE, set[Molecule]]] = None,
) -> bool:
"""
Recursive method to check whether there is a route in the graph,
starting from `start_node` and excluding `forbidden_nodes`,
whose leaves and exactly `starting_mols`.

To prune the search early, we use the `node_to_solvable` dictionary,
which contains True if a node *might* be solvable from only the starting molecules.
"""
assert start_node in graph

# Compute node to solvable if not provided
if node_to_solvable is None:
node_to_solvable = _is_solvable_from_starting_mols(graph, starting_mols, forbidden_nodes)
if node_to_reachable_starting_mols is None:
node_to_reachable_starting_mols = _starting_mols_under_each_node(graph)

# Base case 1: starting mols is empty
if len(starting_mols) == 0:
return False

# Base case 2: start node is forbidden
if start_node in forbidden_nodes:
return False

# Base case 3: start node not solvable
if (
not node_to_solvable[start_node]
or not node_to_reachable_starting_mols[start_node] >= starting_mols
):
return False

# Base case 4: there is just one starting molecule and this OrNode contains it.
if len(starting_mols) == 1 and list(starting_mols)[0] == start_node.mol:
return True

# Main case: the required starting molecules are reachable,
# but we just need to check whether they are reachable within a single synthesis route.
# We do this by explicitly trying to find this synthesis route.
for rxn_child in graph.successors(start_node):
# If the starting molecules are not reachable from this reaction child, abort the search
if (
node_to_solvable[rxn_child]
and node_to_reachable_starting_mols[rxn_child] >= starting_mols
):
# Also abort search if any grandchildren are forbidden
grandchildren = list(graph.successors(rxn_child))
if not any(gc in forbidden_nodes for gc in grandchildren):
# Main recurisve call: we partition K molecules among N children and check whether
# each child is solvable with its allocated molecules.
for start_mol_partition in split_into_subsets(
list(starting_mols), len(grandchildren)
):
for gc, allocated_start_mols in zip(grandchildren, start_mol_partition):
assert isinstance(gc, OrNode)
if not _is_route_with_starting_mols(
graph=graph,
start_node=gc,
starting_mols=set(allocated_start_mols),
forbidden_nodes=forbidden_nodes | {start_node, rxn_child},
node_to_solvable=node_to_solvable,
node_to_reachable_starting_mols=node_to_reachable_starting_mols,
):
break
else: # i.e. loop finished without breaking
return True

# If the method has not returned at this point then there is no route
return False
92 changes: 92 additions & 0 deletions syntheseus/tests/search/analysis/test_starting_molecule_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pytest

from syntheseus.search.analysis import starting_molecule_match
from syntheseus.search.chem import Molecule
from syntheseus.search.graph.and_or import AndOrGraph


@pytest.mark.parametrize(
"A,k,expected_partitions",
[
(
[1, 2],
1,
[[[1, 2]]],
),
(
[1, 2],
2,
[
[[1], [2]],
[[1], [1, 2]],
[[2], [1]],
[[2], [1, 2]],
[[1, 2], [1]],
[[1, 2], [2]],
[[1, 2], [1, 2]],
],
),
([], 1, []), # empty list has no partitions
([], 2, []), # test again with k=2
],
)
def test_split_into_subsets_valid(A, k, expected_partitions):
output = list(starting_molecule_match.split_into_subsets(A, k))
assert output == expected_partitions


@pytest.mark.parametrize(
"A,k",
[
([1, 2, 3], -1), # negative k
([1, 2, 3], 0), # k=0
([1, 2, 2], 0), # list has duplicates
],
)
def test_split_into_subsets_invalid(A, k):
with pytest.raises(AssertionError):
list(starting_molecule_match.split_into_subsets(A, k))


class TestStartingMoleculeMatch:
@pytest.mark.parametrize(
"starting_smiles,expected_ans",
[
("COCS", True), # Is starting molecule
("CO.CS", True), # One of the routes
("CC", True), # Another route
("CS.CC", True), # Another route
("CO.CC", True), # Can be a route if CO occurs twice and is reacted in one of them
("COCC.CC", False), # both mols are in graph, but not part of same route
("", False), # an empty set should always be False
],
)
def test_small_andorgraph(
self, andor_graph_non_minimal: AndOrGraph, starting_smiles: str, expected_ans: bool
):
starting_mols = {Molecule(s) for s in starting_smiles.split(".")}
match = starting_molecule_match.is_route_with_starting_mols(
andor_graph_non_minimal, starting_mols
)
assert match == expected_ans

@pytest.mark.parametrize(
"starting_smiles,expected_ans",
[
("CC.COC", True), # small route, should be in there
("CCCO.O", True), # another route from docstring
("CCCO.C", True), # this route exists (although C is not purchasable here)
("CCCOC", True), # this is just the root node
("", False), # an empty set should always be False
("C.O", True), # should be possible to decompose into just C,O
("CCCO.CC", False), # too many atoms
],
)
def test_large_andorgraph(
self, andor_graph_with_many_routes: AndOrGraph, starting_smiles: str, expected_ans: bool
):
starting_mols = {Molecule(s) for s in starting_smiles.split(".")}
match = starting_molecule_match.is_route_with_starting_mols(
andor_graph_with_many_routes, starting_mols
)
assert match == expected_ans
2 changes: 1 addition & 1 deletion syntheseus/tests/search/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def retrosynthesis_task6() -> RetrosynthesisTask:
CCCOC -> CC + COC

There are 2 routes of length 2:
CCCO -> CCCO + C
CCCOC -> CCCO + C
C -> O

CCCOC -> CCCOO
Expand Down
Loading