Skip to content

Commit

Permalink
Return model predictions as dataclasses instead of pydantic models (#47)
Browse files Browse the repository at this point in the history
The `Prediction` and `PredictionList` objects were originally built on
`pydantic`, anticipating that this would make it easier to, e.g., expose a
model through an API using a framework such as FastAPI. However, this meant
introducing `pydantic` into the core of the library in order to
facilitate a feature that most users may not care about. On the other
hand, FastAPI can deal with non-`pydantic` classes too (`pydantic` is
only needed for some extras like validation). Thus, once we want to
integrate the server code properly, we should do it in such a way that
`pydantic` (and further dependencies like `fastapi`) remain optional
extras rather than part of the core of the library.

For now, in this PR I make the core prediction objects plain
dataclasses, and drop `pydantic` from dependencies. This required a few
changes (e.g. dropping special handling of prediction classes in
`dictify`, using `default_factory` for the dict-valued fields), as well
as fixing some of the tests, which did not provide all the fields needed
to construct the prediction objects (an omission that, for some reason,
was accepted by `pydantic`).
  • Loading branch information
kmaziarz authored Dec 14, 2023
1 parent fc8a4b0 commit 2310e89
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

- Simplify single-step model setup ([#41](https://github.com/microsoft/syntheseus/pull/41)) ([@kmaziarz])
- Refactor single-step evaluation script and move it to cli/ ([#43](https://github.com/microsoft/syntheseus/pull/43)) ([@kmaziarz])
- Return model predictions as dataclasses instead of pydantic models ([#47](https://github.com/microsoft/syntheseus/pull/47)) ([@kmaziarz])

## [0.2.0] - 2023-11-21

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ dependencies = [
"networkx", # search
"numpy", # reaction_prediction, search
"omegaconf", # reaction_prediction
"pydantic>=1.10.5,<2", # reaction_prediction
"rdkit", # reaction_prediction, search
"tqdm", # reaction_prediction
]
Expand Down
48 changes: 18 additions & 30 deletions syntheseus/interface/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,20 @@

import math
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Generic, List, Optional, TypeVar

from pydantic import root_validator
from pydantic.generics import GenericModel

from syntheseus.interface.bag import Bag
from syntheseus.interface.molecule import Molecule

InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")


class Prediction(GenericModel, Generic[InputType, OutputType]):
@dataclass(frozen=True, order=False)
class Prediction(Generic[InputType, OutputType]):
"""Reaction prediction from a model, either a forward or a backward one."""

# Make `pydantic` accept custom types such as `Molecule` or `Bag`.
class Config:
arbitrary_types_allowed = True

# The molecule that the prediction is for and the predicted output:
input: InputType
output: OutputType
Expand All @@ -32,9 +27,14 @@ class Config:
reaction: Optional[str] = None # Reaction smiles.
rxnid: Optional[int] = None # Template id, if applicable.

# Dictionary to hold additional metadata. Note that we use a mutable default value here, which
# could be a problem in plain Python, but is handled correctly by `pydantic`.
metadata: Dict[str, Any] = {}
# Dictionary to hold additional metadata.
metadata: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
if self.probability is not None and self.log_prob is not None:
raise ValueError(
"Probability can be stored as probability or log probability, not both"
)

def get_prob(self) -> float:
if self.probability is not None:
Expand All @@ -52,33 +52,21 @@ def get_log_prob(self) -> float:
else:
raise ValueError("Prediction does not have associated log prob or probability value.")

@root_validator()
def check_at_most_one_source_prob(cls, values):
if values.get("probability") is not None and values.get("log_prob") is not None:
raise ValueError(
"Probability can be stored as probability or log probability, not both"
)
return values


class PredictionList(GenericModel, Generic[InputType, OutputType]):
@dataclass(frozen=True, order=False)
class PredictionList(Generic[InputType, OutputType]):
"""Several possible predictions."""

# Make `pydantic` accept custom types such as `Molecule` or `Bag`.
class Config:
arbitrary_types_allowed = True

input: InputType
predictions: List[Prediction[InputType, OutputType]]

# Dictionary to hold additional metadata (see note above regarding the mutable default value).
metadata: Dict[str, Any] = {}
# Dictionary to hold additional metadata.
metadata: Dict[str, Any] = field(default_factory=dict)

def truncated(self, num_results: int) -> PredictionList[InputType, OutputType]:
fields = self.dict()
fields["predictions"] = fields["predictions"][:num_results]

return PredictionList(**fields)
return PredictionList(
input=self.input, predictions=self.predictions[:num_results], metadata=self.metadata
)


class ReactionModel(Generic[InputType, OutputType]):
Expand Down
3 changes: 0 additions & 3 deletions syntheseus/reaction_prediction/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np

from syntheseus.interface.bag import Bag
from syntheseus.interface.models import Prediction, PredictionList
from syntheseus.interface.molecule import Molecule


Expand Down Expand Up @@ -48,8 +47,6 @@ def dictify(data: Any) -> Any:
elif isinstance(data, (List, tuple, Bag)):
# Captures possible lists of `Prediction`s and lists of `PredictionList`s
return [dictify(x) for x in data]
elif isinstance(data, (PredictionList, Prediction)):
return dictify(dict(data))
elif isinstance(data, dict):
return {k: dictify(v) for k, v in data.items()}
elif is_dataclass(data):
Expand Down
13 changes: 11 additions & 2 deletions syntheseus/tests/interface/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

import numpy as np

from syntheseus.interface.models import Prediction
from syntheseus.interface.models import Prediction, PredictionList
from syntheseus.interface.molecule import Molecule

# TODO(kmaziarz): Currently this test mostly checks that importing from `models.py` works, and that
# a `Prediction` object can be instantiated. We should extend it later.


def test_prediction():
prediction = Prediction(probability=0.5)
prediction = Prediction(input=Molecule("C"), output=Molecule("CC"), probability=0.5)
assert np.isclose(prediction.get_prob(), 0.5)
assert np.isclose(prediction.get_log_prob(), math.log(0.5))

other_prediction = Prediction(input=Molecule("N"), output=Molecule("NC=O"), probability=0.5)
prediction_list = PredictionList(
input=Molecule("C"), predictions=[prediction, other_prediction]
)

assert prediction_list.predictions == [prediction, other_prediction]
assert prediction_list.truncated(num_results=1).predictions == [prediction]
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __call__(
self, inputs: List[Molecule], num_results: int
) -> List[PredictionList[Molecule, Bag[Molecule]]]:
return [
PredictionList(predictions=[Prediction(input=mol, output=Bag([mol]))]) for mol in inputs
PredictionList(input=mol, predictions=[Prediction(input=mol, output=Bag([mol]))])
for mol in inputs
]

def get_parameters(self):
Expand Down

0 comments on commit 2310e89

Please sign in to comment.