Skip to content

Commit

Permalink
Automatically download a default checkpoint if it is not provided (#48)
Browse files Browse the repository at this point in the history
This PR extends the single-step model classes to automatically download
the default (trained on USPTO-50K) model checkpoint if no model
directory is provided. This makes it even easier to get started with
`syntheseus` as one does not have to manually download and unzip the
files. Once downloaded, the checkpoints are cached.

On top of this, I also extend the integration CI to run a new
single-step model test that checks that the supported models do run and
return reasonable predictions for a simple input. While locally the test
works for all 6 `pip`-installable models, on GitHub the GPU is not
available, and setting `device="cpu"` uncovers that some of the models
appear to only work on GPU. Thus, for now the test only covers
Chemformer, LocalRetro and MEGAN; the other models will be added in the
future.

When running those tests, I found that the links we provided for
LocalRetro and RetroKNN were not quite correct (the initial upload
missed some of the files, which I later corrected, but the links were
still pointing to the first upload rather than the corrected one).
Morever, I found that RetroKNN was zipped using a method that `python`
could not deal with, so I re-generated the file and re-uploaded it
(again fixing the link).
  • Loading branch information
kmaziarz authored Dec 15, 2023
1 parent 2310e89 commit fcf0b44
Show file tree
Hide file tree
Showing 15 changed files with 275 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ jobs:
- name: Install syntheseus with all single-step models
run: |
pip install .[all]
- name: Run single-step model tests
run: |
python -m pytest ./syntheseus/tests/reaction_prediction/inference/test_models.py
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

- Add a general CLI endpoint ([#44](https://github.com/microsoft/syntheseus/pull/44)) ([@kmaziarz])
- Add support for PDVN to the search CLI ([#46](https://github.com/microsoft/syntheseus/pull/46)) ([@fiberleif])
- Add initial static documentation ([#45](https://github.com/microsoft/syntheseus/pull/45)) ([@kmaziarz])

### Changed

- Simplify single-step model setup ([#41](https://github.com/microsoft/syntheseus/pull/41)) ([@kmaziarz])
- Simplify single-step model setup ([#41](https://github.com/microsoft/syntheseus/pull/41), [#48](https://github.com/microsoft/syntheseus/pull/48)) ([@kmaziarz])
- Refactor single-step evaluation script and move it to cli/ ([#43](https://github.com/microsoft/syntheseus/pull/43)) ([@kmaziarz])
- Return model predictions as dataclasses instead of pydantic models ([#47](https://github.com/microsoft/syntheseus/pull/47)) ([@kmaziarz])

Expand All @@ -23,9 +24,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
### Changed

- Select search hyperparameters depending on which algorithm and single-step model are used ([#30](https://github.com/microsoft/syntheseus/pull/30)) ([@kmaziarz])
- Add option to override time tolerance in algorithm tests ([#25](https://github.com/microsoft/syntheseus/pull/25)) ([@austint])
- Improve the heuristic used for estimating diversity ([#22](https://github.com/microsoft/syntheseus/pull/22), [#28](https://github.com/microsoft/syntheseus/pull/28)) ([@kmaziarz])
- Improve the aesthetics of `README.md` ([#19](https://github.com/microsoft/syntheseus/pull/19)) ([@kmaziarz])

### Added

Expand Down
11 changes: 8 additions & 3 deletions docs/single_step.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
Syntheseus currently supports 7 established single-step models. For convenience, for each model we also include a checkpoint trained on USPTO-50K.
Syntheseus currently supports 7 established single-step models.

For convenience, for each model we include a default checkpoint trained on USPTO-50K.
If no checkpoint directory is provided during model loading, `syntheseus` will automatically download a default checkpoint and cache it on disk for future use.
The default path for the cache is `$HOME/.cache/torch/syntheseus`, but it can be overriden by setting the `SYNTHESEUS_CACHE_DIR` environment variable.
See table below for the links to the default checkpoints.

| Model checkpoint link | Source |
|----------------------------------------------------------------|--------|
| [Chemformer](https://figshare.com/ndownloader/files/42009888) | finetuned by us starting from checkpoint released by authors |
| [GLN](https://figshare.com/ndownloader/files/42012720) | released by authors |
| [LocalRetro](https://figshare.com/ndownloader/files/42012729) | trained by us |
| [LocalRetro](https://figshare.com/ndownloader/files/42287319) | trained by us |
| [MEGAN](https://figshare.com/ndownloader/files/42012732) | trained by us |
| [MHNreact](https://figshare.com/ndownloader/files/42012777) | trained by us |
| [RetroKNN](https://figshare.com/ndownloader/files/42012786) | trained by us |
| [RetroKNN](https://figshare.com/ndownloader/files/43636584) | trained by us |
| [RootAligned](https://figshare.com/ndownloader/files/42012792) | released by authors |

??? note "More advanced datasets"
Expand Down
46 changes: 46 additions & 0 deletions syntheseus/reaction_prediction/inference/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import Optional, Union

from syntheseus.interface.models import (
BackwardReactionModel,
ForwardReactionModel,
InputType,
OutputType,
ReactionModel,
)
from syntheseus.reaction_prediction.utils.downloading import get_default_model_dir_from_cache


class ExternalReactionModel(ReactionModel[InputType, OutputType]):
"""Base class for the external reaction models, abstracting out common functinality."""

def __init__(
self, model_dir: Optional[Union[str, Path]] = None, device: Optional[str] = None
) -> None:
import torch

self.model_dir = Path(model_dir or self.get_default_model_dir())
self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")

def get_default_model_dir(self) -> Path:
model_dir = get_default_model_dir_from_cache(self.name, is_forward=self.is_forward())

if model_dir is None:
raise ValueError(
f"Could not obtain a default checkpoint for model {self.name}, "
"please provide an explicit value for `model_dir`"
)

return model_dir

@property
def name(self) -> str:
return self.__class__.__name__.removesuffix("Model")


class ExternalBackwardReactionModel(ExternalReactionModel, BackwardReactionModel):
pass


class ExternalForwardReactionModel(ExternalReactionModel, ForwardReactionModel):
pass
18 changes: 8 additions & 10 deletions syntheseus/reaction_prediction/inference/chemformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

import sys
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast
from typing import Any, Dict, List, Tuple, cast

from syntheseus.interface.bag import Bag
from syntheseus.interface.models import InputType, OutputType, PredictionList, ReactionModel
from syntheseus.interface.models import InputType, OutputType, PredictionList
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.inference.base import ExternalReactionModel
from syntheseus.reaction_prediction.utils.inference import (
get_module_path,
get_unique_file_in_dir,
Expand All @@ -21,18 +22,18 @@
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class ChemformerModel(ReactionModel[InputType, OutputType]):
def __init__(
self, model_dir: Union[str, Path], device: str = "cuda:0", is_forward: bool = False
) -> None:
class ChemformerModel(ExternalReactionModel[InputType, OutputType]):
def __init__(self, *args, is_forward: bool = False, **kwargs) -> None:
"""Initializes the Chemformer model wrapper.
Assumed format of the model directory:
- `model_dir` contains the model checkpoint as the only `*.ckpt` file
"""
self._is_forward = is_forward
super().__init__(*args, **kwargs)

# There should be exaclty one `*.ckpt` file under `model_dir`.
chkpt_path = get_unique_file_in_dir(model_dir, pattern="*.ckpt")
chkpt_path = get_unique_file_in_dir(self.model_dir, pattern="*.ckpt")

import chemformer

Expand All @@ -44,9 +45,6 @@ def __init__(
from chemformer.molbart.decoder import DecodeSampler
from chemformer.molbart.models.pre_train import BARTModel

self._is_forward = is_forward
self.device = device

# Vocab path for the tokenizer is relative from Chemformer dir.
self.tokenizer = util.load_tokeniser(
Path(chemformer_root_dir) / util.DEFAULT_VOCAB_PATH, util.DEFAULT_CHEM_TOKEN_START
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
backward:
Chemformer: https://figshare.com/ndownloader/files/42009888
GLN: https://figshare.com/ndownloader/files/42012720
LocalRetro: https://figshare.com/ndownloader/files/42287319
MEGAN: https://figshare.com/ndownloader/files/42012732
MHNreact: https://figshare.com/ndownloader/files/42012777
RetroKNN: https://figshare.com/ndownloader/files/43636584
RootAligned: https://figshare.com/ndownloader/files/42012792
forward:
Chemformer: https://figshare.com/ndownloader/files/42012708
30 changes: 14 additions & 16 deletions syntheseus/reaction_prediction/inference/gln.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,36 @@

import sys
from pathlib import Path
from typing import List, Union
from typing import List

from syntheseus.interface.models import BackwardPredictionList, BackwardReactionModel
from syntheseus.interface.models import BackwardPredictionList
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
from syntheseus.reaction_prediction.utils.inference import process_raw_smiles_outputs
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class GLNModel(BackwardReactionModel):
def __init__(
self,
model_dir: Union[str, Path],
device: str = "cuda:0",
dataset_name: str = "schneider50k",
) -> None:
class GLNModel(ExternalBackwardReactionModel):
def __init__(self, *args, dataset_name: str = "schneider50k", **kwargs) -> None:
"""Initializes the GLN model wrapper.
Assumed format of the model directory:
- `model_dir` contains files necessary to build `RetroGLN`
- `model_dir/{dataset_name}.ckpt` is the model checkpoint
- `model_dir/cooked_{dataset_name}/atom_list.txt` is the atom type list
"""
super().__init__(*args, **kwargs)

import torch

chkpt_path = Path(model_dir) / f"{dataset_name}.ckpt"
args = {
"dropbox": model_dir,
chkpt_path = Path(self.model_dir) / f"{dataset_name}.ckpt"
gln_args = {
"dropbox": self.model_dir,
"data_name": dataset_name,
"model_for_test": chkpt_path,
"tpl_name": "default",
"f_atoms": Path(model_dir) / f"cooked_{dataset_name}" / "atom_list.txt",
"gpu": torch.device(device).index,
"f_atoms": Path(self.model_dir) / f"cooked_{dataset_name}" / "atom_list.txt",
"gpu": torch.device(self.device).index,
}

# Suppress most of the prints from GLN's internals. This only works on messages that
Expand All @@ -50,14 +48,14 @@ def __init__(
from gln.common.cmd_args import cmd_args

sys.argv = []
for name, value in args.items():
for name, value in gln_args.items():
setattr(cmd_args, name, value)
sys.argv += [f"-{name}", str(value)]

# The global state hackery has to happen before this.
from gln.test.model_inference import RetroGLN

self.model = RetroGLN(model_dir, chkpt_path)
self.model = RetroGLN(self.model_dir, chkpt_path)

def get_parameters(self):
return self.model.gln.parameters()
Expand Down
18 changes: 10 additions & 8 deletions syntheseus/reaction_prediction/inference/local_retro.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

import sys
from pathlib import Path
from typing import Any, List, Union
from typing import Any, List

from syntheseus.interface.models import BackwardPredictionList, BackwardReactionModel
from syntheseus.interface.models import BackwardPredictionList
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
from syntheseus.reaction_prediction.utils.inference import (
get_module_path,
get_unique_file_in_dir,
Expand All @@ -21,15 +22,16 @@
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class LocalRetroModel(BackwardReactionModel):
def __init__(self, model_dir: Union[str, Path], device: str = "cuda:0") -> None:
class LocalRetroModel(ExternalBackwardReactionModel):
def __init__(self, *args, **kwargs) -> None:
"""Initializes the LocalRetro model wrapper.
Assumed format of the model directory:
- `model_dir` contains the model checkpoint as the only `*.pth` file
- `model_dir` contains the config as the only `*.json` file
- `model_dir/data` contains `*.csv` data files needed by LocalRetro
"""
super().__init__(*args, **kwargs)

import local_retro
from local_retro import scripts
Expand All @@ -41,13 +43,13 @@ def __init__(self, model_dir: Union[str, Path], device: str = "cuda:0") -> None:
from local_retro.Retrosynthesis import load_templates
from local_retro.scripts.utils import init_featurizer, load_model

data_dir = Path(model_dir) / "data"
data_dir = Path(self.model_dir) / "data"
self.args = init_featurizer(
{
"mode": "test",
"device": device,
"model_path": get_unique_file_in_dir(model_dir, pattern="*.pth"),
"config_path": get_unique_file_in_dir(model_dir, pattern="*.json"),
"device": self.device,
"model_path": get_unique_file_in_dir(self.model_dir, pattern="*.pth"),
"config_path": get_unique_file_in_dir(self.model_dir, pattern="*.json"),
"data_dir": data_dir,
"rxn_class_given": False,
}
Expand Down
19 changes: 10 additions & 9 deletions syntheseus/reaction_prediction/inference/megan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import os
import sys
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Optional

from rdkit import Chem

from syntheseus.interface.models import BackwardPredictionList, BackwardReactionModel
from syntheseus.interface.models import BackwardPredictionList
from syntheseus.interface.molecule import Molecule
from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
from syntheseus.reaction_prediction.utils.inference import (
get_module_path,
get_unique_file_in_dir,
Expand All @@ -25,14 +26,14 @@
from syntheseus.reaction_prediction.utils.misc import suppress_outputs


class MEGANModel(BackwardReactionModel):
class MEGANModel(ExternalBackwardReactionModel):
def __init__(
self,
model_dir: Union[str, Path],
device: str = "cuda:0",
*args,
n_max_atoms: int = 200,
max_gen_steps: int = 16,
beam_batch_size: int = 10,
**kwargs,
) -> None:
"""Initializes the MEGAN model wrapper.
Expand All @@ -41,6 +42,7 @@ def __init__(
- `model_dir/model_best.pt` is the model checkpoint
- `model_dir/{featurizer_key}` contains files needed to build MEGAN's featurizer
"""
super().__init__(*args, **kwargs)

import gin
import megan
Expand All @@ -64,23 +66,22 @@ def __init__(
self.beam_batch_size = beam_batch_size

# Get the model config using `gin`.
gin.parse_config_file(get_unique_file_in_dir(model_dir, pattern="*.gin"))
gin.parse_config_file(get_unique_file_in_dir(self.model_dir, pattern="*.gin"))

# Set up the data featurizer.
featurizer_key = gin.query_parameter("train_megan.featurizer_key")
featurizer = get_featurizer(featurizer_key)

# Get the action vocab and masks.
assert isinstance(featurizer, MeganTrainingSamplesFeaturizer)
self.action_vocab = featurizer.get_actions_vocabulary(model_dir)
self.action_vocab = featurizer.get_actions_vocabulary(self.model_dir)
self.base_action_masks = get_base_action_masks(
n_max_atoms + 1, action_vocab=self.action_vocab
)
self.rdkit_cache = RdkitCache(props=self.action_vocab["props"])
self.device = device

# Load the MEGAN model.
checkpoint = load_state_dict(Path(model_dir) / "model_best.pt")
checkpoint = load_state_dict(Path(self.model_dir) / "model_best.pt")
self.model = MeganModel(
n_atom_actions=self.action_vocab["n_atom_actions"],
n_bond_actions=self.action_vocab["n_bond_actions"],
Expand Down
Loading

0 comments on commit fcf0b44

Please sign in to comment.