Merge SupervisedDataset & FixedNoiseDataset (#1945)
Summary:
Pull Request resolved: #1945

This diff deprecates `FixedNoiseDataset` and merges it into `SupervisedDataset`, with `Yvar` becoming an optional field. It also simplifies the class hierarchy a bit, removing the `SupervisedDatasetMeta` metaclass in favor of an explicit `__init__` method.
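As a rough sketch of the merged API (the tensor shapes below are made up for illustration), known observation noise is now passed directly to `SupervisedDataset` instead of going through a separate class:

import torch
from botorch.utils.datasets import SupervisedDataset

X = torch.rand(10, 2)  # 10 training points, 2 input features
Y = torch.rand(10, 1)  # 10 observed outcomes
Yvar = 0.1 * torch.ones(10, 1)  # known observation variances

inferred_noise = SupervisedDataset(X, Y)  # Yvar simply defaults to None
fixed_noise = SupervisedDataset(X, Y, Yvar)  # replaces FixedNoiseDataset(X, Y, Yvar)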

I plan to follow up on this by adding optional metric names to datasets and introducing a MultiTaskDataset, which will simplify some of the planned work in Ax MBM.

Reviewed By: esantorella

Differential Revision: D47729430

fbshipit-source-id: 551cd78a02755505573b10ea1f075aa21f838ab7
saitcakmak authored and facebook-github-bot committed Aug 9, 2023
1 parent 633d9c0 commit 3506538
Showing 11 changed files with 121 additions and 187 deletions.
6 changes: 3 additions & 3 deletions botorch/acquisition/input_constructors.py
@@ -99,7 +99,7 @@
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.containers import BotorchContainer
-from botorch.utils.datasets import BotorchDataset, SupervisedDataset
+from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
NondominatedPartitioning,
@@ -114,7 +114,7 @@


def _field_is_shared(
-    datasets: Union[Iterable[BotorchDataset], Dict[Hashable, BotorchDataset]],
+    datasets: Union[Iterable[SupervisedDataset], Dict[Hashable, SupervisedDataset]],
fieldname: Hashable,
) -> bool:
r"""Determines whether or not a given field is shared by all datasets."""
@@ -136,7 +136,7 @@ def _field_is_shared(


def _get_dataset_field(
-    dataset: MaybeDict[BotorchDataset],
+    dataset: MaybeDict[SupervisedDataset],
fieldname: str,
transform: Optional[Callable[[BotorchContainer], Any]] = None,
join_rule: Optional[Callable[[Sequence[Any]], Any]] = None,
2 changes: 1 addition & 1 deletion botorch/models/gp_regression_mixed.py
@@ -185,7 +185,7 @@ def construct_inputs(
likelihood: Optional[Likelihood] = None,
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct `Model` keyword arguments from a dict of `BotorchDataset`.
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
Args:
training_data: A `SupervisedDataset` containing the training data.
6 changes: 3 additions & 3 deletions botorch/models/model.py
@@ -38,7 +38,7 @@
from botorch.posteriors import Posterior, PosteriorList
from botorch.sampling.base import MCSampler
from botorch.sampling.list_sampler import ListSampler
-from botorch.utils.datasets import BotorchDataset
+from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import is_fully_bayesian
from torch import Tensor
from torch.nn import Module, ModuleDict, ModuleList
@@ -169,10 +169,10 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model
@classmethod
def construct_inputs(
cls,
-        training_data: Union[BotorchDataset, Dict[Hashable, BotorchDataset]],
+        training_data: Union[SupervisedDataset, Dict[Hashable, SupervisedDataset]],
**kwargs: Any,
) -> Dict[str, Any]:
r"""Construct `Model` keyword arguments from a dict of `BotorchDataset`."""
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`."""
from botorch.models.utils.parse_training_data import parse_training_data

return parse_training_data(cls, training_data, **kwargs)
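For context, `construct_inputs` remains the bridge from a dataset to model constructor arguments; a minimal sketch, assuming a `SingleTaskGP` consumer (the model choice is illustrative, not part of this diff):

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

training_data = SupervisedDataset(X=torch.rand(10, 2), Y=torch.rand(10, 1))
model_kwargs = SingleTaskGP.construct_inputs(training_data)  # delegates to parse_training_data
model = SingleTaskGP(**model_kwargs)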
31 changes: 9 additions & 22 deletions botorch/models/utils/parse_training_data.py
@@ -16,12 +16,7 @@
from botorch.models.model import Model
from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
-from botorch.utils.datasets import (
-    BotorchDataset,
-    FixedNoiseDataset,
-    RankingDataset,
-    SupervisedDataset,
-)
+from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from torch import cat, Tensor
from torch.nn.functional import pad
@@ -37,13 +32,13 @@ def _encoder(arg: Any) -> Type:

def parse_training_data(
consumer: Any,
-    training_data: Union[BotorchDataset, Dict[Hashable, BotorchDataset]],
+    training_data: Union[SupervisedDataset, Dict[Hashable, SupervisedDataset]],
**kwargs: Any,
) -> Dict[Hashable, Tensor]:
r"""Prepares a (collection of) datasets for consumption by a given object.
Args:
-        training_data: A `BotorchDataset` or dictionary thereof.
+        training_data: A `SupervisedDataset` or dictionary thereof.
consumer: The object that will consume the parsed data, or type thereof.
Returns:
@@ -56,18 +51,10 @@ def parse_training_data(
def _parse_model_supervised(
consumer: Model, dataset: SupervisedDataset, **ignore: Any
) -> Dict[Hashable, Tensor]:
-    return {"train_X": dataset.X(), "train_Y": dataset.Y()}


-@dispatcher.register(Model, FixedNoiseDataset)
-def _parse_model_fixedNoise(
-    consumer: Model, dataset: FixedNoiseDataset, **ignore: Any
-) -> Dict[Hashable, Tensor]:
-    return {
-        "train_X": dataset.X(),
-        "train_Y": dataset.Y(),
-        "train_Yvar": dataset.Yvar(),
-    }
+    parsed_data = {"train_X": dataset.X(), "train_Y": dataset.Y()}
+    if dataset.Yvar is not None:
+        parsed_data["train_Yvar"] = dataset.Yvar()
+    return parsed_data


@dispatcher.register(PairwiseGP, RankingDataset)
Expand All @@ -88,7 +75,7 @@ def _parse_pairwiseGP_ranking(
@dispatcher.register(Model, dict)
def _parse_model_dict(
consumer: Model,
training_data: Dict[Hashable, BotorchDataset],
-    training_data: Dict[Hashable, BotorchDataset],
+    training_data: Dict[Hashable, SupervisedDataset],
**kwargs: Any,
) -> Dict[Hashable, Tensor]:
if len(training_data) != 1:
@@ -102,7 +89,7 @@ def _parse_model_dict(
@dispatcher.register((MultiTaskGP, FixedNoiseMultiTaskGP), dict)
def _parse_multitask_dict(
consumer: Model,
-    training_data: Dict[Hashable, BotorchDataset],
+    training_data: Dict[Hashable, SupervisedDataset],
training_data: Dict[Hashable, SupervisedDataset],
*,
task_feature: int = 0,
task_feature_container: Hashable = "train_X",
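The practical effect of the merged parser above: `train_Yvar` appears in the parsed dictionary only when the dataset actually carries a `Yvar`. A small sketch (`SingleTaskGP` is just an illustrative consumer):

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.utils.parse_training_data import parse_training_data
from botorch.utils.datasets import SupervisedDataset

X, Y, Yvar = torch.rand(10, 2), torch.rand(10, 1), 0.1 * torch.ones(10, 1)
parse_training_data(SingleTaskGP, SupervisedDataset(X, Y, Yvar))
# -> {"train_X": ..., "train_Y": ..., "train_Yvar": ...}
parse_training_data(SingleTaskGP, SupervisedDataset(X, Y))
# -> {"train_X": ..., "train_Y": ...}  (no "train_Yvar" key)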
163 changes: 87 additions & 76 deletions botorch/utils/datasets.py
@@ -8,67 +8,24 @@

from __future__ import annotations

-from dataclasses import dataclass, fields, MISSING
-from itertools import chain, count, repeat
+import warnings
+from itertools import count, repeat
from typing import Any, Dict, Hashable, Iterable, Optional, TypeVar, Union

from botorch.utils.containers import BotorchContainer, DenseContainer, SliceContainer
from torch import long, ones, Tensor
-from typing_extensions import get_type_hints

T = TypeVar("T")
ContainerLike = Union[BotorchContainer, Tensor]
MaybeIterable = Union[T, Iterable[T]]


-@dataclass
-class BotorchDataset:
-    # TODO: Once v3.10 becomes standard, expose `validate_init` as a kw_only InitVar
-    def __post_init__(self, validate_init: bool = True) -> None:
-        if validate_init:
-            self._validate()
+class SupervisedDataset:
+    r"""Base class for datasets consisting of labelled pairs `(X, Y)`
+    and an optional `Yvar` that stipulates observation variances so
+    that `Y[i] ~ N(f(X[i]), Yvar[i])`.
-    def _validate(self) -> None:
-        pass


-class SupervisedDatasetMeta(type):
-    def __call__(cls, *args: Any, **kwargs: Any):
-        r"""Converts Tensor-valued fields to DenseContainer under the assumption
-        that said fields house collections of feature vectors."""
-        hints = get_type_hints(cls)
-        fields_iter = (item for item in fields(cls) if item.init is not None)
-        f_dict = {}
-        for value, field in chain(
-            zip(args, fields_iter),
-            ((kwargs.pop(field.name, MISSING), field) for field in fields_iter),
-        ):
-            if value is MISSING:
-                if field.default is not MISSING:
-                    value = field.default
-                elif field.default_factory is not MISSING:
-                    value = field.default_factory()
-                else:
-                    raise RuntimeError(f"Missing required field `{field.name}`.")

-            if issubclass(hints[field.name], BotorchContainer):
-                if isinstance(value, Tensor):
-                    value = DenseContainer(value, event_shape=value.shape[-1:])
-                elif not isinstance(value, BotorchContainer):
-                    raise TypeError(
-                        "Expected <BotorchContainer | Tensor> for field "
-                        f"`{field.name}` but was {type(value)}."
-                    )
-            f_dict[field.name] = value

-        return super().__call__(**f_dict, **kwargs)


-@dataclass
-class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
-    r"""Base class for datasets consisting of labelled pairs `(x, y)`.
-    This class object's `__call__` method converts Tensors `src` to
+    This class object's `__init__` method converts Tensors `src` to
DenseContainers under the assumption that `event_shape=src.shape[-1:]`.
Example:
@@ -87,6 +44,29 @@ class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):

X: BotorchContainer
Y: BotorchContainer
+    Yvar: Optional[BotorchContainer]

+    def __init__(
+        self,
+        X: ContainerLike,
+        Y: ContainerLike,
+        Yvar: Optional[ContainerLike] = None,
+        validate_init: bool = True,
+    ) -> None:
+        r"""Constructs a `SupervisedDataset`.
+        Args:
+            X: A `Tensor` or `BotorchContainer` representing the input features.
+            Y: A `Tensor` or `BotorchContainer` representing the outcomes.
+            Yvar: An optional `Tensor` or `BotorchContainer` representing
+                the observation noise.
+            validate_init: If `True`, validates the input shapes.
+        """
+        self.X = _containerize(X)
+        self.Y = _containerize(Y)
+        self.Yvar = None if Yvar is None else _containerize(Yvar)
+        if validate_init:
+            self._validate()

def _validate(self) -> None:
shape_X = self.X.shape
@@ -95,12 +75,15 @@ def _validate(self) -> None:
shape_Y = shape_Y[: len(shape_Y) - len(self.Y.event_shape)]
if shape_X != shape_Y:
raise ValueError("Batch dimensions of `X` and `Y` are incompatible.")
+        if self.Yvar is not None and self.Yvar.shape != self.Y.shape:
+            raise ValueError("Shapes of `Y` and `Yvar` are incompatible.")

@classmethod
def dict_from_iter(
cls,
X: MaybeIterable[ContainerLike],
Y: MaybeIterable[ContainerLike],
+        Yvar: Optional[MaybeIterable[ContainerLike]] = None,
*,
keys: Optional[Iterable[Hashable]] = None,
) -> Dict[Hashable, SupervisedDataset]:
@@ -111,40 +94,46 @@ def dict_from_iter(
X = (X,) if single_Y else repeat(X)
if single_Y:
Y = (Y,) if single_X else repeat(Y)
-        return {key: cls(x, y) for key, x, y in zip(keys or count(), X, Y)}
+        Yvar = repeat(Yvar) if isinstance(Yvar, (Tensor, BotorchContainer)) else Yvar

+        # Pass in Yvar only if it is not None.
+        iterables = (X, Y) if Yvar is None else (X, Y, Yvar)
+        return {
+            elements[0]: cls(*elements[1:])
+            for elements in zip(keys or count(), *iterables)
+        }

+    def __eq__(self, other: Any) -> bool:
+        return (
+            type(other) is type(self)
+            and self.X == other.X
+            and self.Y == other.Y
+            and self.Yvar == other.Yvar
+        )


-@dataclass
class FixedNoiseDataset(SupervisedDataset):
    r"""A SupervisedDataset with an additional field `Yvar` that stipulates
-    observation variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`."""
+    observation variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`.
-    X: BotorchContainer
-    Y: BotorchContainer
-    Yvar: BotorchContainer

-    @classmethod
-    def dict_from_iter(
-        cls,
-        X: MaybeIterable[ContainerLike],
-        Y: MaybeIterable[ContainerLike],
-        Yvar: Optional[MaybeIterable[ContainerLike]] = None,
-        *,
-        keys: Optional[Iterable[Hashable]] = None,
-    ) -> Dict[Hashable, SupervisedDataset]:
-        r"""Returns a dictionary of `FixedNoiseDataset` from iterables."""
-        single_X = isinstance(X, (Tensor, BotorchContainer))
-        single_Y = isinstance(Y, (Tensor, BotorchContainer))
-        if single_X:
-            X = (X,) if single_Y else repeat(X)
-        if single_Y:
-            Y = (Y,) if single_X else repeat(Y)
+    NOTE: This is deprecated. Use `SupervisedDataset` instead.
+    """

-        Yvar = repeat(Yvar) if isinstance(Yvar, (Tensor, BotorchContainer)) else Yvar
-        return {key: cls(x, y, c) for key, x, y, c in zip(keys or count(), X, Y, Yvar)}
+    def __init__(
+        self,
+        X: ContainerLike,
+        Y: ContainerLike,
+        Yvar: ContainerLike,
+        validate_init: bool = True,
+    ) -> None:
+        r"""Initialize a `FixedNoiseDataset` -- deprecated!"""
+        warnings.warn(
+            "`FixedNoiseDataset` is deprecated. Use `SupervisedDataset` instead.",
+            DeprecationWarning,
+        )
+        super().__init__(X=X, Y=Y, Yvar=Yvar, validate_init=validate_init)


-@dataclass
class RankingDataset(SupervisedDataset):
r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations
`x ∈ Z^{m}` of elements from a ground set `Z = (z_1, ...)` and ranking vectors
@@ -173,6 +162,18 @@ class RankingDataset(SupervisedDataset):
X: SliceContainer
Y: BotorchContainer

+    def __init__(
+        self, X: SliceContainer, Y: ContainerLike, validate_init: bool = True
+    ) -> None:
+        r"""Construct a `RankingDataset`.
+        Args:
+            X: A `SliceContainer` representing the input features being ranked.
+            Y: A `Tensor` or `BotorchContainer` representing the rankings.
+            validate_init: If `True`, validates the input shapes.
+        """
+        super().__init__(X=X, Y=Y, Yvar=None, validate_init=validate_init)

def _validate(self) -> None:
super()._validate()

@@ -201,3 +202,13 @@ def _validate(self) -> None:

# Same as: torch.where(y_diff == 0, y_incr + 1, 1)
y_incr = y_incr - y_diff + 1


+def _containerize(value: ContainerLike) -> BotorchContainer:
+    r"""Converts Tensor-valued arguments to DenseContainer under the assumption
+    that said arguments house collections of feature vectors.
+    """
+    if isinstance(value, Tensor):
+        return DenseContainer(value, event_shape=value.shape[-1:])
+    else:
+        return value
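Two usage notes on the updated module, sketched under hypothetical shapes and keys: `dict_from_iter` broadcasts a single `Yvar` across all datasets, and the legacy constructor now emits a `DeprecationWarning` while still producing a working dataset.

import warnings
import torch
from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset

datasets = SupervisedDataset.dict_from_iter(
    X=[torch.rand(5, 2), torch.rand(8, 2)],
    Y=[torch.rand(5, 1), torch.rand(8, 1)],
    keys=["objective", "constraint"],
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = FixedNoiseDataset(torch.rand(5, 2), torch.rand(5, 1), torch.ones(5, 1))
assert any(issubclass(w.category, DeprecationWarning) for w in caught)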
7 changes: 2 additions & 5 deletions test/models/test_fully_bayesian.py
@@ -47,7 +47,7 @@
from botorch.models.transforms import Normalize, Standardize
from botorch.posteriors.fully_bayesian import batched_bisect, FullyBayesianPosterior
from botorch.sampling.get_sampler import get_sampler
-from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
+from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
NondominatedPartitioning,
)
@@ -550,10 +550,7 @@ def test_construct_inputs(self):
X, Y, Yvar, model = self._get_data_and_model(
infer_noise=infer_noise, **tkwargs
)
-            if infer_noise:
-                training_data = SupervisedDataset(X, Y)
-            else:
-                training_data = FixedNoiseDataset(X, Y, Yvar)
+            training_data = SupervisedDataset(X, Y, Yvar)

data_dict = model.construct_inputs(training_data)
self.assertTrue(X.equal(data_dict["train_X"]))
4 changes: 2 additions & 2 deletions test/models/test_gp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from botorch.models.utils import add_output_dim
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling import SobolQMCNormalSampler
-from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
+from botorch.utils.datasets import SupervisedDataset
from botorch.utils.sampling import manual_seed
from botorch.utils.testing import _get_random_data, BotorchTestCase
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
@@ -450,7 +450,7 @@ def test_construct_inputs(self):
X = model_kwargs["train_X"]
Y = model_kwargs["train_Y"]
Yvar = model_kwargs["train_Yvar"]
-            training_data = FixedNoiseDataset(X, Y, Yvar)
+            training_data = SupervisedDataset(X, Y, Yvar)
data_dict = model.construct_inputs(training_data)
self.assertTrue(X.equal(data_dict["train_X"]))
self.assertTrue(Y.equal(data_dict["train_Y"]))
(4 more changed files not shown.)
