Skip to content

Commit

Permalink
Add PDVN into the search CLI (#46)
Browse files Browse the repository at this point in the history
This PR adds the PDVN search algorithm to `cli/search`, and the
`search_algorithm` argument now supports `retro_star`, `mcts`, and
`pdvn`.

---------

Co-authored-by: Guoqing Liu <guoqingliu@microsoft.com>
Co-authored-by: Krzysztof Maziarz <krmaziar@microsoft.com>
  • Loading branch information
3 people authored Dec 13, 2023
1 parent 5ed4907 commit fc8a4b0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
### Added

- Add a general CLI endpoint ([#44](https://github.com/microsoft/syntheseus/pull/44)) ([@kmaziarz])
- Add support for PDVN to the search CLI ([#46](https://github.com/microsoft/syntheseus/pull/46)) ([@fiberleif])

### Changed

Expand Down
40 changes: 37 additions & 3 deletions syntheseus/cli/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Script for running end-to-end retrosynthetic search.
The supported single-step model types are listed in `syntheseus/reaction_prediction/cli/eval.py`;
each can be combined with either MCTS or Retro* to perform search.
The supported single-step model types are listed in `syntheseus/cli/eval_single_step.py`;
each can be combined with Retro*, MCTS, or PDVN to perform search.
Example invocation:
python ./syntheseus/cli/search.py \
Expand Down Expand Up @@ -38,6 +38,7 @@
from syntheseus.search.algorithms.best_first.retro_star import RetroStarSearch
from syntheseus.search.algorithms.mcts import base as mcts_base
from syntheseus.search.algorithms.mcts.molset import MolSetMCTS
from syntheseus.search.algorithms.pdvn import PDVN_MCTS
from syntheseus.search.analysis.route_extraction import iter_routes_time_order
from syntheseus.search.analysis.solution_time import get_first_solution_time
from syntheseus.search.chem import Molecule
Expand Down Expand Up @@ -79,6 +80,27 @@ class MCTSConfig:
bound_function_class: str = "pucb_bound"


@dataclass
class PDVNConfig:
max_expansion_depth: int = 10

value_function_syn_class: str = "ConstantNodeEvaluator"
value_function_syn_kwargs: Dict[str, Any] = field(default_factory=lambda: {"constant": 0.5})

value_function_cost_class: str = "ConstantNodeEvaluator"
value_function_cost_kwargs: Dict[str, Any] = field(default_factory=lambda: {"constant": 0.0})

and_node_cost_fn_class: str = "ConstantNodeEvaluator"
and_node_cost_fn_kwargs: Dict[str, Any] = field(default_factory=lambda: {"constant": 0.1})

policy_class: str = "ReactionModelProbPolicy"
policy_kwargs: Dict[str, Any] = field(default_factory=dict)

c_dead: float = 5.0
bound_constant: float = 1e2
bound_function_class: str = "pucb_bound"


@dataclass
class SearchConfig(BackwardModelConfig):
"""Config for running search for given search targets."""
Expand All @@ -104,9 +126,10 @@ class SearchConfig(BackwardModelConfig):
reaction_model_use_cache: bool = True # Whether to cache the results

# Fields configuring the search algorithm
search_algorithm: str = "retro_star" # Either "mcts" or "retro_star"
search_algorithm: str = "retro_star" # "retro_star", "mcts", or "pdvn"
retro_star_config: RetroStarConfig = RetroStarConfig()
mcts_config: MCTSConfig = MCTSConfig()
pdvn_config: PDVNConfig = PDVNConfig()

# Fields configuring what to save after the run
save_graph: bool = True # Whether to save the full reaction graph (can be large)
Expand Down Expand Up @@ -198,6 +221,17 @@ def build_node_evaluator(key: str) -> None:
del alg_kwargs["bound_function_class"]

alg = MolSetMCTS(**alg_kwargs)
elif config.search_algorithm == "pdvn":
alg_kwargs.update(cast(Dict[str, Any], OmegaConf.to_container(config.pdvn_config)))
build_node_evaluator("value_function_syn")
build_node_evaluator("value_function_cost")
build_node_evaluator("and_node_cost_fn")
build_node_evaluator("policy")

alg_kwargs["bound_function"] = lookup_by_name(mcts_base, alg_kwargs["bound_function_class"])
del alg_kwargs["bound_function_class"]

alg = PDVN_MCTS(**alg_kwargs)
else:
raise NotImplementedError(f"Unsupported search algorithm {config.search_algorithm}")

Expand Down

0 comments on commit fc8a4b0

Please sign in to comment.