Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return model predictions as dataclasses instead of pydantic models #47

Merged
merged 5 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

- Simplify single-step model setup ([#41](https://github.com/microsoft/syntheseus/pull/41)) ([@kmaziarz])
- Refactor single-step evaluation script and move it to cli/ ([#43](https://github.com/microsoft/syntheseus/pull/43)) ([@kmaziarz])
- Return model predictions as dataclasses instead of pydantic models ([#47](https://github.com/microsoft/syntheseus/pull/47)) ([@kmaziarz])

## [0.2.0] - 2023-11-21

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ dependencies = [
"networkx", # search
"numpy", # reaction_prediction, search
"omegaconf", # reaction_prediction
"pydantic>=1.10.5,<2", # reaction_prediction
"rdkit", # reaction_prediction, search
"tqdm", # reaction_prediction
]
Expand Down
48 changes: 18 additions & 30 deletions syntheseus/interface/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@

import math
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, TypeVar

from pydantic import root_validator
from pydantic.generics import GenericModel

from syntheseus.interface.bag import Bag
from syntheseus.interface.molecule import Molecule

InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")


class Prediction(GenericModel, Generic[InputType, OutputType]):
@dataclass(frozen=True, order=False)
class Prediction(Generic[InputType, OutputType]):
"""Reaction prediction from a model, either a forward or a backward one."""

# Make `pydantic` accept custom types such as `Molecule` or `Bag`.
class Config:
arbitrary_types_allowed = True

# The molecule that the prediction is for and the predicted output:
input: InputType
output: OutputType
Expand All @@ -32,9 +27,14 @@ class Config:
reaction: Optional[str] = None # Reaction smiles.
rxnid: Optional[int] = None # Template id, if applicable.

# Dictionary to hold additional metadata. Note that we use a mutable default value here, which
# could be a problem in plain Python, but is handled correctly by `pydantic`.
metadata: Dict[str, Any] = {}
# Dictionary to hold additional metadata.
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if self.probability is not None and self.log_prob is not None:
raise ValueError(
"Probability can be stored as probability or log probability, not both"
)

def get_prob(self) -> float:
if self.probability is not None:
Expand All @@ -52,33 +52,21 @@ def get_log_prob(self) -> float:
else:
raise ValueError("Prediction does not have associated log prob or probability value.")

@root_validator()
def check_at_most_one_source_prob(cls, values):
if values.get("probability") is not None and values.get("log_prob") is not None:
raise ValueError(
"Probability can be stored as probability or log probability, not both"
)
return values


class PredictionList(GenericModel, Generic[InputType, OutputType]):
@dataclass(frozen=True, order=False)
class PredictionList(Generic[InputType, OutputType]):
"""Several possible predictions."""

# Make `pydantic` accept custom types such as `Molecule` or `Bag`.
class Config:
arbitrary_types_allowed = True

input: InputType
predictions: List[Prediction[InputType, OutputType]]

# Dictionary to hold additional metadata (see note above regarding the mutable default value).
metadata: Dict[str, Any] = {}
# Dictionary to hold additional metadata.
metadata: Dict[str, Any] = field(default_factory=dict)

def truncated(self, num_results: int) -> PredictionList[InputType, OutputType]:
fields = self.dict()
fields["predictions"] = fields["predictions"][:num_results]

return PredictionList(**fields)
return PredictionList(
input=self.input, predictions=self.predictions[:num_results], metadata=self.metadata
)


class ReactionModel(Generic[InputType, OutputType]):
Expand Down
3 changes: 0 additions & 3 deletions syntheseus/reaction_prediction/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np

from syntheseus.interface.bag import Bag
from syntheseus.interface.models import Prediction, PredictionList
from syntheseus.interface.molecule import Molecule


Expand Down Expand Up @@ -48,8 +47,6 @@ def dictify(data: Any) -> Any:
elif isinstance(data, (List, tuple, Bag)):
# Captures possible lists of `Prediction`s and lists of `PredictionList`s
return [dictify(x) for x in data]
elif isinstance(data, (PredictionList, Prediction)):
return dictify(dict(data))
elif isinstance(data, dict):
return {k: dictify(v) for k, v in data.items()}
elif is_dataclass(data):
Expand Down
13 changes: 11 additions & 2 deletions syntheseus/tests/interface/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

import numpy as np

from syntheseus.interface.models import Prediction
from syntheseus.interface.models import Prediction, PredictionList
from syntheseus.interface.molecule import Molecule

# TODO(kmaziarz): Currently this test mostly checks that importing from `models.py` works, and that
# a `Prediction` object can be instantiated. We should extend it later.


def test_prediction():
prediction = Prediction(probability=0.5)
prediction = Prediction(input=Molecule("C"), output=Molecule("CC"), probability=0.5)
assert np.isclose(prediction.get_prob(), 0.5)
assert np.isclose(prediction.get_log_prob(), math.log(0.5))

other_prediction = Prediction(input=Molecule("N"), output=Molecule("NC=O"), probability=0.5)
prediction_list = PredictionList(
input=Molecule("C"), predictions=[prediction, other_prediction]
)

assert prediction_list.predictions == [prediction, other_prediction]
assert prediction_list.truncated(num_results=1).predictions == [prediction]
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __call__(
self, inputs: List[Molecule], num_results: int
) -> List[PredictionList[Molecule, Bag[Molecule]]]:
return [
PredictionList(predictions=[Prediction(input=mol, output=Bag([mol]))]) for mol in inputs
PredictionList(input=mol, predictions=[Prediction(input=mol, output=Bag([mol]))])
for mol in inputs
]

def get_parameters(self):
Expand Down
Loading