From 10c63fb58251a15e8808f61bf942d7bad7d5b0af Mon Sep 17 00:00:00 2001 From: Austin Tripp Date: Tue, 14 Nov 2023 15:43:42 +0000 Subject: [PATCH 1/5] Add set partition function. --- .../analysis/starting_molecule_match.py | 62 +++++++++++++++ .../analysis/test_starting_molecule_match.py | 75 +++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 syntheseus/search/analysis/starting_molecule_match.py create mode 100644 syntheseus/tests/search/analysis/test_starting_molecule_match.py diff --git a/syntheseus/search/analysis/starting_molecule_match.py b/syntheseus/search/analysis/starting_molecule_match.py new file mode 100644 index 00000000..d0b960a7 --- /dev/null +++ b/syntheseus/search/analysis/starting_molecule_match.py @@ -0,0 +1,62 @@ +"""Code related to starting molecules match metric (called exact set-wise match in FusionRetro).""" + +from __future__ import annotations + +import itertools +import typing +from collections.abc import Iterable + +T = typing.TypeVar("T") + + +def partition_set(A: list[T], k: int) -> Iterable[list[list[T]]]: + """ + Enumerates all possible ways to partition a list A into k disjoint subsets which are non-empty + (in all possible orders). + + This function is easiest to explain by example: + + A single partition is just the list A itself. + + >>> list(partition_set([1, 2, 3], 1)) + >>> [[[1, 2, 3]]] + + For k=2, there are 3 possible partitions each with two possible orders. + + >>> list(partition_set([1, 2, 3], 2)) + >>> [[[1], [2, 3]], [[2], [1, 3]], [[3], [1, 2]], [[1, 2], [3]], [[1, 3], [2]], [[2, 3], [1]]] + + For k=3, there is only 1 possible partition but 6 possible orders. + + >>> list(partition_set([1, 2, 3], 3)) + >>> [[[1], [2], [3]], [[1], [3], [2]], [[2], [1], [3]], [[2], [3], [1]], [[3], [1], [2]], [[3], [2], [1]]] + + This function uses a recursive implementation. 
+ First it partitions A into 2 (where the second partition has at least k-1 elements), + then it recursively partitions the second partition into (k-1) partitions. + """ + + # Check 1: elements of list are unique + assert len(set(A)) == len(A) + + # Check 2: k is valid + assert k >= 1 + + # Base case 1: list is empty + if len(A) == 0: + return + + # Base case 2: just a single partition + if k == 1: + yield [A] + return + + # Main case: partition A into two parts, then recursively partition the second part + max_size_of_first_partition = len(A) - k + 1 + for first_partition_size in range(1, max_size_of_first_partition + 1): + for first_partition in itertools.combinations(A, first_partition_size): + # Find the remaining elements to partition. + # NOTE: this assumes that the elements of A are unique. + remaining_elements = [x for x in A if x not in first_partition] + for subsequent_partitions in partition_set(remaining_elements, k - 1): + yield [list(first_partition)] + subsequent_partitions diff --git a/syntheseus/tests/search/analysis/test_starting_molecule_match.py b/syntheseus/tests/search/analysis/test_starting_molecule_match.py new file mode 100644 index 00000000..96f93e8f --- /dev/null +++ b/syntheseus/tests/search/analysis/test_starting_molecule_match.py @@ -0,0 +1,75 @@ +import pytest + +from syntheseus.search.analysis import starting_molecule_match + + +@pytest.mark.parametrize( + "A,k,expected_partitions", + [ + ([1, 2, 3], 1, [[[1, 2, 3]]]), + ( + [1, 2, 3], + 2, + [ + [[1], [2, 3]], + [[2], [1, 3]], + [[3], [1, 2]], + [[1, 2], [3]], + [[1, 3], [2]], + [[2, 3], [1]], + ], + ), + ( + [1, 2, 3], + 3, + [ + [[1], [2], [3]], + [[1], [3], [2]], + [[2], [1], [3]], + [[2], [3], [1]], + [[3], [1], [2]], + [[3], [2], [1]], + ], + ), + ([1, 2, 3], 4, []), + ([1, 2, 3], 5, []), + ( + [1, 2, 3, 4], + 2, + [ + [[1], [2, 3, 4]], + [[2], [1, 3, 4]], + [[3], [1, 2, 4]], + [[4], [1, 2, 3]], + [[1, 2], [3, 4]], + [[1, 3], [2, 4]], + [[1, 4], [2, 3]], + [[2, 3], [1, 4]], + 
[[2, 4], [1, 3]], + [[3, 4], [1, 2]], + [[1, 2, 3], [4]], + [[1, 2, 4], [3]], + [[1, 3, 4], [2]], + [[2, 3, 4], [1]], + ], + ), + ([], 1, []), # empty list has no partitions + ([], 2, []), # test again with k=2 + ], +) +def test_partition_set_valid(A, k, expected_partitions): + output = list(starting_molecule_match.partition_set(A, k)) + assert output == expected_partitions + + +@pytest.mark.parametrize( + "A,k", + [ + ([1, 2, 3], -1), # negative k + ([1, 2, 3], 0), # k=0 + ([1, 2, 2], 0), # list has duplicates + ], +) +def test_partition_set_invalid(A, k): + with pytest.raises(AssertionError): + list(starting_molecule_match.partition_set(A, k)) From 402306a1d09d034a2d01816886e7042f8e97c10b Mon Sep 17 00:00:00 2001 From: Austin Tripp Date: Tue, 14 Nov 2023 16:47:30 +0000 Subject: [PATCH 2/5] Fix small typo in docstring. --- syntheseus/tests/search/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syntheseus/tests/search/conftest.py b/syntheseus/tests/search/conftest.py index 35d9f755..037502be 100644 --- a/syntheseus/tests/search/conftest.py +++ b/syntheseus/tests/search/conftest.py @@ -248,7 +248,7 @@ def retrosynthesis_task6() -> RetrosynthesisTask: CCCOC -> CC + COC There are 2 routes of length 2: - CCCO -> CCCO + C + CCCOC -> CCCO + C C -> O CCCOC -> CCCOO From 3a0372849d22a7c0c684e6844076993c09188ce1 Mon Sep 17 00:00:00 2001 From: Austin Tripp Date: Tue, 14 Nov 2023 16:51:59 +0000 Subject: [PATCH 3/5] Add full method to check for starting molecule matches. 
--- .../analysis/starting_molecule_match.py | 104 ++++++++++++++++++ .../analysis/test_starting_molecule_match.py | 50 +++++++++ 2 files changed, 154 insertions(+) diff --git a/syntheseus/search/analysis/starting_molecule_match.py b/syntheseus/search/analysis/starting_molecule_match.py index d0b960a7..a4547972 100644 --- a/syntheseus/search/analysis/starting_molecule_match.py +++ b/syntheseus/search/analysis/starting_molecule_match.py @@ -6,9 +6,28 @@ import typing from collections.abc import Iterable +from syntheseus.search.chem import Molecule +from syntheseus.search.graph.and_or import ANDOR_NODE, AndOrGraph, OrNode + T = typing.TypeVar("T") +def is_route_with_starting_mols( + graph: AndOrGraph, + starting_mols: set[Molecule], + forbidden_nodes: typing.Optional[set[ANDOR_NODE]] = None, +) -> bool: + """Checks whether there is a route in graph matching the starting mols.""" + forbidden_nodes = forbidden_nodes or set() + return _is_route_with_starting_mols( + graph, + graph.root_node, + starting_mols, + forbidden_nodes, + _starting_mols_under_each_node(graph), + ) + + def partition_set(A: list[T], k: int) -> Iterable[list[list[T]]]: """ Enumerates all possible ways to partition a list A into k disjoint subsets which are non-empty @@ -60,3 +79,88 @@ def partition_set(A: list[T], k: int) -> Iterable[list[list[T]]]: remaining_elements = [x for x in A if x not in first_partition] for subsequent_partitions in partition_set(remaining_elements, k - 1): yield [list(first_partition)] + subsequent_partitions + + +def _starting_mols_under_each_node(graph: AndOrGraph) -> dict[ANDOR_NODE, set[Molecule]]: + """Get set of molecules reachable under each node in the graph.""" + + # Initialize to empty sets, except for nodes with purchasable mols + node_to_mols: dict[ANDOR_NODE, set[Molecule]] = {n: set() for n in graph.nodes()} + for n in graph.nodes(): + if isinstance(n, OrNode): + node_to_mols[n].add(n.mol) + + # Do passes through all nodes + update_happened = True + while 
update_happened: + update_happened = False + for n in graph.nodes(): + for c in graph.successors(n): + if not (node_to_mols[c] <= node_to_mols[n]): + node_to_mols[n].update(node_to_mols[c]) + update_happened = True + + return node_to_mols + + +def _is_route_with_starting_mols( + graph: AndOrGraph, + start_node: OrNode, + starting_mols: set[Molecule], + forbidden_nodes: set[ANDOR_NODE], + node_to_all_reachable_starting_mols: dict[ANDOR_NODE, set[Molecule]], +) -> bool: + """ + Recursive method to check whether there is a route in the graph, + starting from `start_node` and excluding `forbidden_nodes`, + whose leaves and exactly `starting_mols`. + + To prune the search, we use the `node_to_all_reachable_starting_mols` dictionary, + which contains the set of all purchasable molecules reachable under each node (not necessarily part of a single route though). + We use this to prune the search early: if some molecules cannot be reached at all, there is no point checking whether + they might be reachable from a single route. + """ + assert start_node in graph + + # Base case 1: starting mols is empty + if len(starting_mols) == 0: + return False + + # Base case 2: start node is forbidden + if start_node in forbidden_nodes: + return False + + # Base case 3: starting are not reachable at all from the start node + if not (starting_mols <= node_to_all_reachable_starting_mols[start_node]): + return False + + # Base case 4: there is just one starting molecule and this OrNode contains it. + if len(starting_mols) == 1 and list(starting_mols)[0] == start_node.mol: + return True + + # Main case: the required starting molecules are reachable, + # but we just need to check whether they are reachable within a single synthesis route. + # We do this by explicitly trying to find this synthesis route. 
+ for rxn_child in graph.successors(start_node): + # If the starting molecules are not reachable from this reaction child, abort the search + if node_to_all_reachable_starting_mols[rxn_child] >= starting_mols: + # Also abort search if any grandchildren are forbidden + grandchildren = list(graph.successors(rxn_child)) + if not any(gc in forbidden_nodes for gc in grandchildren): + # Main recurisve call: we partition K molecules among N children and check whether + for start_mol_partition in partition_set(list(starting_mols), len(grandchildren)): + for gc, allocated_start_mols in zip(grandchildren, start_mol_partition): + assert isinstance(gc, OrNode) + if not _is_route_with_starting_mols( + graph, + gc, + set(allocated_start_mols), + forbidden_nodes | {start_node, rxn_child}, + node_to_all_reachable_starting_mols, + ): + break + else: # i.e. loop finished without breaking + return True + + # If the method has not returned at this point then there is no route + return False diff --git a/syntheseus/tests/search/analysis/test_starting_molecule_match.py b/syntheseus/tests/search/analysis/test_starting_molecule_match.py index 96f93e8f..b89a567b 100644 --- a/syntheseus/tests/search/analysis/test_starting_molecule_match.py +++ b/syntheseus/tests/search/analysis/test_starting_molecule_match.py @@ -1,6 +1,8 @@ import pytest from syntheseus.search.analysis import starting_molecule_match +from syntheseus.search.chem import Molecule +from syntheseus.search.graph.and_or import AndOrGraph @pytest.mark.parametrize( @@ -73,3 +75,51 @@ def test_partition_set_valid(A, k, expected_partitions): def test_partition_set_invalid(A, k): with pytest.raises(AssertionError): list(starting_molecule_match.partition_set(A, k)) + + +class TestStartingMoleculeMatch: + @pytest.mark.parametrize( + "starting_smiles,expected_ans", + [ + ("COCS", True), # Is starting molecule + ("CO.CS", True), # One of the routes + ("CC", True), # Another route + ("CS.CC", True), # Another route + ("CO.CC", True), # Can 
be a route if CO occurs twice and is reacted in one of them + ("COCC.CC", False), # both mols are in graph, but not part of same route + ("", False), # an empty set should always be False + ], + ) + def test_small_andorgraph( + self, andor_graph_non_minimal: AndOrGraph, starting_smiles: str, expected_ans: bool + ): + starting_mols = {Molecule(s) for s in starting_smiles.split(".")} + assert ( + starting_molecule_match.is_route_with_starting_mols( + andor_graph_non_minimal, starting_mols + ) + == expected_ans + ) + + @pytest.mark.parametrize( + "starting_smiles,expected_ans", + [ + ("CC.COC", True), # small route, should be in there + ("CCCO.O", True), # another route from docstring + ("CCCO.C", True), # this route exists (although C is not purchasable here) + ("CCCOC", True), # this is just the root node + ("", False), # an empty set should always be False + ("C.O", True), # should be possible to decompose into just C,O + ("CCCO.CC", False), # too many atoms + ], + ) + def test_large_andorgraph( + self, andor_graph_with_many_routes: AndOrGraph, starting_smiles: str, expected_ans: bool + ): + starting_mols = {Molecule(s) for s in starting_smiles.split(".")} + assert ( + starting_molecule_match.is_route_with_starting_mols( + andor_graph_with_many_routes, starting_mols + ) + == expected_ans + ) From aee50ccbc01f92426a7d32c4b70103fffa8042f0 Mon Sep 17 00:00:00 2001 From: Austin Tripp Date: Tue, 14 Nov 2023 17:26:01 +0000 Subject: [PATCH 4/5] Fix bug where the same starting molecule was forbidden from being used in multiple branches. 
--- .../analysis/starting_molecule_match.py | 56 ++++++++---------- .../analysis/test_starting_molecule_match.py | 59 +++++-------------- 2 files changed, 40 insertions(+), 75 deletions(-) diff --git a/syntheseus/search/analysis/starting_molecule_match.py b/syntheseus/search/analysis/starting_molecule_match.py index a4547972..5ec89700 100644 --- a/syntheseus/search/analysis/starting_molecule_match.py +++ b/syntheseus/search/analysis/starting_molecule_match.py @@ -28,31 +28,30 @@ def is_route_with_starting_mols( ) -def partition_set(A: list[T], k: int) -> Iterable[list[list[T]]]: +def split_into_subsets(A: list[T], k: int) -> Iterable[list[list[T]]]: """ - Enumerates all possible ways to partition a list A into k disjoint subsets which are non-empty - (in all possible orders). + Enumerate all possible ways to create k subsets from a list A such that none of the k subsets are empty, + the union of the sets is A, and the order of the subsets *does* matter. This function is easiest to explain by example: A single partition is just the list A itself. - >>> list(partition_set([1, 2, 3], 1)) - >>> [[[1, 2, 3]]] + >>> list(split_into_subsets([1, 2], 1)) + >>> [[[1, 2]]] - For k=2, there are 3 possible partitions each with two possible orders. + For k=2, there are 4 possible subsets meaning 16 pairs of subsets, + but only 7 of them have non-empty subsets which together include every element at least once. - >>> list(partition_set([1, 2, 3], 2)) - >>> [[[1], [2, 3]], [[2], [1, 3]], [[3], [1, 2]], [[1, 2], [3]], [[1, 3], [2]], [[2, 3], [1]]] + >>> list(split_into_subsets([1, 2], 2)) + >>> [[[1], [2]], [[1], [1,2]], [[2], [1]], [[2], [1,2]], [[1,2], [1]], [[1,2], [2]]] + >>> [[[1], [2]], [[1], [1, 2]], [[2], [1]], [[2], [1, 2]], [[1, 2], [1]], [[1, 2], [2]], [[1, 2], [1, 2]]] - For k=3, there is only 1 possible partition but 6 possible orders. 
+ The implementation just uses itertools.combinations and itertools.product to enumerate all possible partitions, + and simply rejects those which do not sum up to the entire set. + It is not very efficient for large A or large k, so use with caution. - >>> list(partition_set([1, 2, 3], 3)) - >>> [[[1], [2], [3]], [[1], [3], [2]], [[2], [1], [3]], [[2], [3], [1]], [[3], [1], [2]], [[3], [2], [1]]] - - This function uses a recursive implementation. - First it partitions A into 2 (where the second partition has at least k-1 elements), - then it recursively partitions the second partition into (k-1) partitions. + NOTE: the efficiency of this method could definitely be improved later. """ # Check 1: elements of list are unique @@ -61,24 +60,17 @@ def partition_set(A: list[T], k: int) -> Iterable[list[list[T]]]: # Check 2: k is valid assert k >= 1 - # Base case 1: list is empty + # Base case: list is empty if len(A) == 0: return - # Base case 2: just a single partition - if k == 1: - yield [A] - return - - # Main case: partition A into two parts, then recursively partition the second part - max_size_of_first_partition = len(A) - k + 1 - for first_partition_size in range(1, max_size_of_first_partition + 1): - for first_partition in itertools.combinations(A, first_partition_size): - # Find the remaining elements to partition. - # NOTE: this assumes that the elements of A are unique. 
- remaining_elements = [x for x in A if x not in first_partition] - for subsequent_partitions in partition_set(remaining_elements, k - 1): - yield [list(first_partition)] + subsequent_partitions + # Iterate through all subsets + power_set_non_empty = itertools.chain.from_iterable( + itertools.combinations(A, r) for r in range(1, len(A) + 1) + ) + for subsets in itertools.product(power_set_non_empty, repeat=k): + if set(itertools.chain.from_iterable(subsets)) == set(A): + yield [list(s) for s in subsets] def _starting_mols_under_each_node(graph: AndOrGraph) -> dict[ANDOR_NODE, set[Molecule]]: @@ -148,7 +140,9 @@ def _is_route_with_starting_mols( grandchildren = list(graph.successors(rxn_child)) if not any(gc in forbidden_nodes for gc in grandchildren): # Main recurisve call: we partition K molecules among N children and check whether - for start_mol_partition in partition_set(list(starting_mols), len(grandchildren)): + for start_mol_partition in split_into_subsets( + list(starting_mols), len(grandchildren) + ): for gc, allocated_start_mols in zip(grandchildren, start_mol_partition): assert isinstance(gc, OrNode) if not _is_route_with_starting_mols( diff --git a/syntheseus/tests/search/analysis/test_starting_molecule_match.py b/syntheseus/tests/search/analysis/test_starting_molecule_match.py index b89a567b..1efb093b 100644 --- a/syntheseus/tests/search/analysis/test_starting_molecule_match.py +++ b/syntheseus/tests/search/analysis/test_starting_molecule_match.py @@ -8,59 +8,30 @@ @pytest.mark.parametrize( "A,k,expected_partitions", [ - ([1, 2, 3], 1, [[[1, 2, 3]]]), ( - [1, 2, 3], - 2, - [ - [[1], [2, 3]], - [[2], [1, 3]], - [[3], [1, 2]], - [[1, 2], [3]], - [[1, 3], [2]], - [[2, 3], [1]], - ], - ), - ( - [1, 2, 3], - 3, - [ - [[1], [2], [3]], - [[1], [3], [2]], - [[2], [1], [3]], - [[2], [3], [1]], - [[3], [1], [2]], - [[3], [2], [1]], - ], + [1, 2], + 1, + [[[1, 2]]], ), - ([1, 2, 3], 4, []), - ([1, 2, 3], 5, []), ( - [1, 2, 3, 4], + [1, 2], 2, [ - [[1], [2, 3, 
4]], - [[2], [1, 3, 4]], - [[3], [1, 2, 4]], - [[4], [1, 2, 3]], - [[1, 2], [3, 4]], - [[1, 3], [2, 4]], - [[1, 4], [2, 3]], - [[2, 3], [1, 4]], - [[2, 4], [1, 3]], - [[3, 4], [1, 2]], - [[1, 2, 3], [4]], - [[1, 2, 4], [3]], - [[1, 3, 4], [2]], - [[2, 3, 4], [1]], + [[1], [2]], + [[1], [1, 2]], + [[2], [1]], + [[2], [1, 2]], + [[1, 2], [1]], + [[1, 2], [2]], + [[1, 2], [1, 2]], ], ), ([], 1, []), # empty list has no partitions ([], 2, []), # test again with k=2 ], ) -def test_partition_set_valid(A, k, expected_partitions): - output = list(starting_molecule_match.partition_set(A, k)) +def test_split_into_subsets_valid(A, k, expected_partitions): + output = list(starting_molecule_match.split_into_subsets(A, k)) assert output == expected_partitions @@ -72,9 +43,9 @@ def test_partition_set_valid(A, k, expected_partitions): ([1, 2, 2], 0), # list has duplicates ], ) -def test_partition_set_invalid(A, k): +def test_split_into_subsets_invalid(A, k): with pytest.raises(AssertionError): - list(starting_molecule_match.partition_set(A, k)) + list(starting_molecule_match.split_into_subsets(A, k)) class TestStartingMoleculeMatch: From 545cd306b6e76644b9e8edd3b569a22f5a04ff91 Mon Sep 17 00:00:00 2001 From: Austin Tripp Date: Wed, 15 Nov 2023 15:15:04 +0000 Subject: [PATCH 5/5] Add extra pruning check for increased efficiency (whether the molecules are solvable at all). 
--- .../analysis/starting_molecule_match.py | 79 +++++++++++++++---- .../analysis/test_starting_molecule_match.py | 16 ++-- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/syntheseus/search/analysis/starting_molecule_match.py b/syntheseus/search/analysis/starting_molecule_match.py index 5ec89700..e53425dd 100644 --- a/syntheseus/search/analysis/starting_molecule_match.py +++ b/syntheseus/search/analysis/starting_molecule_match.py @@ -5,9 +5,10 @@ import itertools import typing from collections.abc import Iterable +from typing import Optional from syntheseus.search.chem import Molecule -from syntheseus.search.graph.and_or import ANDOR_NODE, AndOrGraph, OrNode +from syntheseus.search.graph.and_or import ANDOR_NODE, AndNode, AndOrGraph, OrNode T = typing.TypeVar("T") @@ -24,7 +25,6 @@ def is_route_with_starting_mols( graph.root_node, starting_mols, forbidden_nodes, - _starting_mols_under_each_node(graph), ) @@ -73,6 +73,42 @@ def split_into_subsets(A: list[T], k: int) -> Iterable[list[list[T]]]: yield [list(s) for s in subsets] +def _is_solvable_from_starting_mols( + graph: AndOrGraph, + starting_mols: set[Molecule], + forbidden_nodes: Optional[set[ANDOR_NODE]] = None, +) -> dict[ANDOR_NODE, bool]: + """Get whether each node is solvable only from a specified set of starting molecules.""" + forbidden_nodes = forbidden_nodes or set() + + # Which nodes are solvable because they contain a starting molecule? 
+ node_to_contains_start_mol = { + n: (isinstance(n, OrNode) and n.mol in starting_mols) for n in graph.nodes() + } + node_to_solvable = {n: False for n in graph.nodes()} + + # Do passes through all nodes + update_happened = True + while update_happened: + update_happened = False + for n in graph.nodes(): + successors_are_solvable = [node_to_solvable[c] for c in graph.successors(n)] + if n in forbidden_nodes: + new_solvable = False # regardless of successors, forbidden nodes are not solvable + elif isinstance(n, OrNode): + new_solvable = any(successors_are_solvable) or node_to_contains_start_mol[n] + elif isinstance(n, AndNode): + new_solvable = all(successors_are_solvable) + else: + raise ValueError + + if new_solvable != node_to_solvable[n]: + node_to_solvable[n] = new_solvable + update_happened = True + + return node_to_solvable + + def _starting_mols_under_each_node(graph: AndOrGraph) -> dict[ANDOR_NODE, set[Molecule]]: """Get set of molecules reachable under each node in the graph.""" @@ -100,20 +136,25 @@ def _is_route_with_starting_mols( start_node: OrNode, starting_mols: set[Molecule], forbidden_nodes: set[ANDOR_NODE], - node_to_all_reachable_starting_mols: dict[ANDOR_NODE, set[Molecule]], + node_to_solvable: Optional[dict[ANDOR_NODE, bool]] = None, + node_to_reachable_starting_mols: Optional[dict[ANDOR_NODE, set[Molecule]]] = None, ) -> bool: """ Recursive method to check whether there is a route in the graph, starting from `start_node` and excluding `forbidden_nodes`, whose leaves and exactly `starting_mols`. - To prune the search, we use the `node_to_all_reachable_starting_mols` dictionary, - which contains the set of all purchasable molecules reachable under each node (not necessarily part of a single route though). - We use this to prune the search early: if some molecules cannot be reached at all, there is no point checking whether - they might be reachable from a single route. 
+ To prune the search early, we use the `node_to_solvable` dictionary, + which contains True if a node *might* be solvable from only the starting molecules. """ assert start_node in graph + # Compute node to solvable if not provided + if node_to_solvable is None: + node_to_solvable = _is_solvable_from_starting_mols(graph, starting_mols, forbidden_nodes) + if node_to_reachable_starting_mols is None: + node_to_reachable_starting_mols = _starting_mols_under_each_node(graph) + # Base case 1: starting mols is empty if len(starting_mols) == 0: return False @@ -122,8 +163,11 @@ def _is_route_with_starting_mols( if start_node in forbidden_nodes: return False - # Base case 3: starting are not reachable at all from the start node - if not (starting_mols <= node_to_all_reachable_starting_mols[start_node]): + # Base case 3: start node not solvable + if ( + not node_to_solvable[start_node] + or not node_to_reachable_starting_mols[start_node] >= starting_mols + ): return False # Base case 4: there is just one starting molecule and this OrNode contains it. @@ -135,22 +179,27 @@ def _is_route_with_starting_mols( # We do this by explicitly trying to find this synthesis route. for rxn_child in graph.successors(start_node): # If the starting molecules are not reachable from this reaction child, abort the search - if node_to_all_reachable_starting_mols[rxn_child] >= starting_mols: + if ( + node_to_solvable[rxn_child] + and node_to_reachable_starting_mols[rxn_child] >= starting_mols + ): # Also abort search if any grandchildren are forbidden grandchildren = list(graph.successors(rxn_child)) if not any(gc in forbidden_nodes for gc in grandchildren): # Main recurisve call: we partition K molecules among N children and check whether + # each child is solvable with its allocated molecules. 
for start_mol_partition in split_into_subsets( list(starting_mols), len(grandchildren) ): for gc, allocated_start_mols in zip(grandchildren, start_mol_partition): assert isinstance(gc, OrNode) if not _is_route_with_starting_mols( - graph, - gc, - set(allocated_start_mols), - forbidden_nodes | {start_node, rxn_child}, - node_to_all_reachable_starting_mols, + graph=graph, + start_node=gc, + starting_mols=set(allocated_start_mols), + forbidden_nodes=forbidden_nodes | {start_node, rxn_child}, + node_to_solvable=node_to_solvable, + node_to_reachable_starting_mols=node_to_reachable_starting_mols, ): break else: # i.e. loop finished without breaking diff --git a/syntheseus/tests/search/analysis/test_starting_molecule_match.py b/syntheseus/tests/search/analysis/test_starting_molecule_match.py index 1efb093b..47728005 100644 --- a/syntheseus/tests/search/analysis/test_starting_molecule_match.py +++ b/syntheseus/tests/search/analysis/test_starting_molecule_match.py @@ -65,12 +65,10 @@ def test_small_andorgraph( self, andor_graph_non_minimal: AndOrGraph, starting_smiles: str, expected_ans: bool ): starting_mols = {Molecule(s) for s in starting_smiles.split(".")} - assert ( - starting_molecule_match.is_route_with_starting_mols( - andor_graph_non_minimal, starting_mols - ) - == expected_ans + match = starting_molecule_match.is_route_with_starting_mols( + andor_graph_non_minimal, starting_mols ) + assert match == expected_ans @pytest.mark.parametrize( "starting_smiles,expected_ans", @@ -88,9 +86,7 @@ def test_large_andorgraph( self, andor_graph_with_many_routes: AndOrGraph, starting_smiles: str, expected_ans: bool ): starting_mols = {Molecule(s) for s in starting_smiles.split(".")} - assert ( - starting_molecule_match.is_route_with_starting_mols( - andor_graph_with_many_routes, starting_mols - ) - == expected_ans + match = starting_molecule_match.is_route_with_starting_mols( + andor_graph_with_many_routes, starting_mols ) + assert match == expected_ans