Automatically download a default checkpoint if it is not provided (#48)
This PR extends the single-step model classes to automatically download the default model checkpoint (trained on USPTO-50K) if no model directory is provided. This makes it even easier to get started with `syntheseus`, as one does not have to manually download and unzip the files. Once downloaded, the checkpoints are cached.

I also extend the integration CI to run a new single-step model test that checks that the supported models run and return reasonable predictions for a simple input. Locally the test works for all 6 `pip`-installable models, but on GitHub no GPU is available, and setting `device="cpu"` uncovers that some of the models appear to only work on GPU. Thus, for now the test only covers Chemformer, LocalRetro and MEGAN; the other models will be added in the future.

When running those tests, I found that the links we provided for LocalRetro and RetroKNN were not quite correct: the initial upload missed some of the files, which I later corrected, but the links were still pointing to the first upload rather than the corrected one. Moreover, I found that RetroKNN was zipped using a method that `python` could not deal with, so I re-generated the file and re-uploaded it (again fixing the link).
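The download-once-then-cache behaviour described above can be sketched roughly as follows. This is a minimal illustration using only the standard library, not the code from this commit: the function `get_or_download`, its `fetch` parameter, and the cache layout (one directory per model name) are all assumptions made for the example.

```python
import tempfile
import zipfile
from pathlib import Path
from typing import Callable
from urllib.request import urlretrieve


def _fetch_over_http(url: str, destination: Path) -> None:
    """Default fetcher: download `url` to `destination`."""
    urlretrieve(url, destination)


def get_or_download(
    name: str,
    url: str,
    cache_root: Path,
    fetch: Callable[[str, Path], None] = _fetch_over_http,
) -> Path:
    """Return the cached directory for `name`, downloading and unzipping on first use.

    Hypothetical helper: the name, signature and cache layout are assumptions.
    """
    target = cache_root / name

    # Cache hit: a non-empty directory is taken to mean the checkpoint is there.
    if target.is_dir() and any(target.iterdir()):
        return target

    # Cache miss: download the archive to a temporary location and unpack it.
    with tempfile.TemporaryDirectory() as tmp:
        archive = Path(tmp) / f"{name}.zip"
        fetch(url, archive)
        target.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(target)

    return target
```

Passing the fetcher in as a parameter keeps the caching logic testable without network access. An archive written with a compression method that `zipfile` does not support would fail at the extraction step of such a helper, which is consistent with the RetroKNN archive having to be re-generated.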
Showing 15 changed files with 275 additions and 72 deletions.
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import Optional, Union

from syntheseus.interface.models import (
    BackwardReactionModel,
    ForwardReactionModel,
    InputType,
    OutputType,
    ReactionModel,
)
from syntheseus.reaction_prediction.utils.downloading import get_default_model_dir_from_cache


class ExternalReactionModel(ReactionModel[InputType, OutputType]):
    """Base class for the external reaction models, abstracting out common functionality."""

    def __init__(
        self, model_dir: Optional[Union[str, Path]] = None, device: Optional[str] = None
    ) -> None:
        import torch

        self.model_dir = Path(model_dir or self.get_default_model_dir())
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_default_model_dir(self) -> Path:
        model_dir = get_default_model_dir_from_cache(self.name, is_forward=self.is_forward())

        if model_dir is None:
            raise ValueError(
                f"Could not obtain a default checkpoint for model {self.name}, "
                "please provide an explicit value for `model_dir`"
            )

        return model_dir

    @property
    def name(self) -> str:
        return self.__class__.__name__.removesuffix("Model")


class ExternalBackwardReactionModel(ExternalReactionModel, BackwardReactionModel):
    pass


class ExternalForwardReactionModel(ExternalReactionModel, ForwardReactionModel):
    pass
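To make the fallback logic in the base class easier to follow, here is a small stand-alone sketch that mimics it without `torch` or `syntheseus`. Everything in it is hypothetical: `SketchModel`, the `known_defaults` registry (standing in for `get_default_model_dir_from_cache`) and the example paths are assumptions for illustration only. Since `str.removesuffix` exists only from Python 3.9 onwards, the sketch uses an equivalent helper that also runs on older versions.

```python
from pathlib import Path
from typing import Dict, Optional, Union


def strip_model_suffix(class_name: str) -> str:
    """Behaves like class_name.removesuffix("Model"), but also works before Python 3.9."""
    suffix = "Model"
    return class_name[: -len(suffix)] if class_name.endswith(suffix) else class_name


class SketchModel:
    """Hypothetical stand-in for the base class in the diff above."""

    # Assumed registry replacing the real cache lookup; the paths are made up.
    known_defaults: Dict[str, Path] = {"MEGAN": Path("/tmp/cache/MEGAN")}

    def __init__(self, model_dir: Optional[Union[str, Path]] = None) -> None:
        # An explicit model_dir wins; otherwise fall back to the default checkpoint.
        self.model_dir = Path(model_dir or self.get_default_model_dir())

    @property
    def name(self) -> str:
        # The model name is derived from the class name, e.g. MEGANModel -> "MEGAN".
        return strip_model_suffix(type(self).__name__)

    def get_default_model_dir(self) -> Path:
        model_dir = self.known_defaults.get(self.name)
        if model_dir is None:
            raise ValueError(
                f"Could not obtain a default checkpoint for model {self.name}, "
                "please provide an explicit value for `model_dir`"
            )
        return model_dir


class MEGANModel(SketchModel):
    pass


class UnknownModel(SketchModel):
    pass
```

With these definitions, `MEGANModel()` resolves to the registered default directory, `MEGANModel("my_dir")` uses the explicit path, and `UnknownModel()` raises `ValueError` because no default is registered under the derived name `"Unknown"`.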
10 changes: 10 additions & 0 deletions
syntheseus/reaction_prediction/inference/default_checkpoint_links.yml
@@ -0,0 +1,10 @@
backward:
  Chemformer: https://figshare.com/ndownloader/files/42009888
  GLN: https://figshare.com/ndownloader/files/42012720
  LocalRetro: https://figshare.com/ndownloader/files/42287319
  MEGAN: https://figshare.com/ndownloader/files/42012732
  MHNreact: https://figshare.com/ndownloader/files/42012777
  RetroKNN: https://figshare.com/ndownloader/files/43636584
  RootAligned: https://figshare.com/ndownloader/files/42012792
forward:
  Chemformer: https://figshare.com/ndownloader/files/42012708
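A lookup over this table, keyed first by direction and then by model name, might look like the sketch below. The function `get_default_link` is a hypothetical name, and how the commit actually loads and queries the YAML file is not shown in this part of the diff. To keep the example free of a YAML parser dependency, part of the mapping is inlined as a plain dict.

```python
from typing import Dict, Optional

# Subset of the link table above, inlined as a dict (assumption: the real code
# reads default_checkpoint_links.yml instead).
DEFAULT_LINKS: Dict[str, Dict[str, str]] = {
    "backward": {
        "Chemformer": "https://figshare.com/ndownloader/files/42009888",
        "LocalRetro": "https://figshare.com/ndownloader/files/42287319",
        "MEGAN": "https://figshare.com/ndownloader/files/42012732",
    },
    "forward": {
        "Chemformer": "https://figshare.com/ndownloader/files/42012708",
    },
}


def get_default_link(model_name: str, is_forward: bool) -> Optional[str]:
    """Return the default checkpoint URL, or None if no default exists.

    Hypothetical helper: name and signature are assumptions for illustration.
    """
    direction = "forward" if is_forward else "backward"
    return DEFAULT_LINKS.get(direction, {}).get(model_name)
```

Returning `None` for a missing entry matches the base class shown earlier, which turns a missing default into a `ValueError` asking for an explicit `model_dir`. Note that the same model name can map to different checkpoints per direction, as Chemformer does here.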