diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py
index 99f5d6de10..ae000201f1 100644
--- a/botorch/acquisition/input_constructors.py
+++ b/botorch/acquisition/input_constructors.py
@@ -99,7 +99,7 @@
 from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from botorch.utils.containers import BotorchContainer
-from botorch.utils.datasets import BotorchDataset, SupervisedDataset
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.multi_objective.box_decompositions.non_dominated import (
     FastNondominatedPartitioning,
     NondominatedPartitioning,
@@ -114,7 +114,7 @@
 def _field_is_shared(
-    datasets: Union[Iterable[BotorchDataset], Dict[Hashable, BotorchDataset]],
+    datasets: Union[Iterable[SupervisedDataset], Dict[Hashable, SupervisedDataset]],
     fieldname: Hashable,
 ) -> bool:
     r"""Determines whether or not a given field is shared by all datasets."""
@@ -136,7 +136,7 @@ def _field_is_shared(
 def _get_dataset_field(
-    dataset: MaybeDict[BotorchDataset],
+    dataset: MaybeDict[SupervisedDataset],
     fieldname: str,
     transform: Optional[Callable[[BotorchContainer], Any]] = None,
     join_rule: Optional[Callable[[Sequence[Any]], Any]] = None,
diff --git a/botorch/models/gp_regression_mixed.py b/botorch/models/gp_regression_mixed.py
index 668035541f..52fdb6c849 100644
--- a/botorch/models/gp_regression_mixed.py
+++ b/botorch/models/gp_regression_mixed.py
@@ -185,7 +185,7 @@ def construct_inputs(
         likelihood: Optional[Likelihood] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        r"""Construct `Model` keyword arguments from a dict of `BotorchDataset`.
+        r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.

         Args:
             training_data: A `SupervisedDataset` containing the training data.
diff --git a/botorch/models/model.py b/botorch/models/model.py
index de614efc0a..179ba66c3e 100644
--- a/botorch/models/model.py
+++ b/botorch/models/model.py
@@ -38,7 +38,7 @@
 from botorch.posteriors import Posterior, PosteriorList
 from botorch.sampling.base import MCSampler
 from botorch.sampling.list_sampler import ListSampler
-from botorch.utils.datasets import BotorchDataset
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.transforms import is_fully_bayesian
 from torch import Tensor
 from torch.nn import Module, ModuleDict, ModuleList
@@ -169,10 +169,10 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
     @classmethod
     def construct_inputs(
         cls,
-        training_data: Union[BotorchDataset, Dict[Hashable, BotorchDataset]],
+        training_data: Union[SupervisedDataset, Dict[Hashable, SupervisedDataset]],
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        r"""Construct `Model` keyword arguments from a dict of `BotorchDataset`."""
+        r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`."""
        from botorch.models.utils.parse_training_data import parse_training_data

         return parse_training_data(cls, training_data, **kwargs)
diff --git a/botorch/models/utils/parse_training_data.py b/botorch/models/utils/parse_training_data.py
index 0988ba3380..6194e122ef 100644
--- a/botorch/models/utils/parse_training_data.py
+++ b/botorch/models/utils/parse_training_data.py
@@ -16,12 +16,7 @@
 from botorch.models.model import Model
 from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
 from botorch.models.pairwise_gp import PairwiseGP
-from botorch.utils.datasets import (
-    BotorchDataset,
-    FixedNoiseDataset,
-    RankingDataset,
-    SupervisedDataset,
-)
+from botorch.utils.datasets import RankingDataset, SupervisedDataset
 from botorch.utils.dispatcher import Dispatcher
 from torch import cat, Tensor
 from torch.nn.functional import pad
@@ -37,13 +32,13 @@ def _encoder(arg: Any) -> Type:
 def parse_training_data(
     consumer: Any,
-    training_data: Union[BotorchDataset, Dict[Hashable, BotorchDataset]],
+    training_data: Union[SupervisedDataset, Dict[Hashable, SupervisedDataset]],
     **kwargs: Any,
 ) -> Dict[Hashable, Tensor]:
     r"""Prepares a (collection of) datasets for consumption by a given object.

     Args:
-        training_datas: A BoTorchDataset or dictionary thereof.
+        training_data: A SupervisedDataset or dictionary thereof.
         consumer: The object that will consume the parsed data, or type thereof.

     Returns:
@@ -56,18 +51,10 @@ def parse_training_data(
 def _parse_model_supervised(
     consumer: Model, dataset: SupervisedDataset, **ignore: Any
 ) -> Dict[Hashable, Tensor]:
-    return {"train_X": dataset.X(), "train_Y": dataset.Y()}
-
-
-@dispatcher.register(Model, FixedNoiseDataset)
-def _parse_model_fixedNoise(
-    consumer: Model, dataset: FixedNoiseDataset, **ignore: Any
-) -> Dict[Hashable, Tensor]:
-    return {
-        "train_X": dataset.X(),
-        "train_Y": dataset.Y(),
-        "train_Yvar": dataset.Yvar(),
-    }
+    parsed_data = {"train_X": dataset.X(), "train_Y": dataset.Y()}
+    if dataset.Yvar is not None:
+        parsed_data["train_Yvar"] = dataset.Yvar()
+    return parsed_data


 @dispatcher.register(PairwiseGP, RankingDataset)
@@ -88,7 +75,7 @@ def _parse_pairwiseGP_ranking(
 @dispatcher.register(Model, dict)
 def _parse_model_dict(
     consumer: Model,
-    training_data: Dict[Hashable, BotorchDataset],
+    training_data: Dict[Hashable, SupervisedDataset],
     **kwargs: Any,
 ) -> Dict[Hashable, Tensor]:
     if len(training_data) != 1:
@@ -102,7 +89,7 @@ def _parse_model_dict(
 @dispatcher.register((MultiTaskGP, FixedNoiseMultiTaskGP), dict)
 def _parse_multitask_dict(
     consumer: Model,
-    training_data: Dict[Hashable, BotorchDataset],
+    training_data: Dict[Hashable, SupervisedDataset],
     *,
     task_feature: int = 0,
     task_feature_container: Hashable = "train_X",
diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py
index bfe12aa047..60621fb8ab 100644
--- a/botorch/utils/datasets.py
+++ b/botorch/utils/datasets.py
@@ -8,67 +8,24 @@
 from __future__ import annotations

-from dataclasses import dataclass, fields, MISSING
-from itertools import chain, count, repeat
+import warnings
+from itertools import count, repeat
 from typing import Any, Dict, Hashable, Iterable, Optional, TypeVar, Union

 from botorch.utils.containers import BotorchContainer, DenseContainer, SliceContainer
 from torch import long, ones, Tensor
-from typing_extensions import get_type_hints

 T = TypeVar("T")
 ContainerLike = Union[BotorchContainer, Tensor]
 MaybeIterable = Union[T, Iterable[T]]


-@dataclass
-class BotorchDataset:
-    # TODO: Once v3.10 becomes standard, expose `validate_init` as a kw_only InitVar
-    def __post_init__(self, validate_init: bool = True) -> None:
-        if validate_init:
-            self._validate()
+class SupervisedDataset:
+    r"""Base class for datasets consisting of labelled pairs `(X, Y)`
+    and an optional `Yvar` that stipulates observation variances so
+    that `Y[i] ~ N(f(X[i]), Yvar[i])`.

-    def _validate(self) -> None:
-        pass
-
-
-class SupervisedDatasetMeta(type):
-    def __call__(cls, *args: Any, **kwargs: Any):
-        r"""Converts Tensor-valued fields to DenseContainer under the assumption
-        that said fields house collections of feature vectors."""
-        hints = get_type_hints(cls)
-        fields_iter = (item for item in fields(cls) if item.init is not None)
-        f_dict = {}
-        for value, field in chain(
-            zip(args, fields_iter),
-            ((kwargs.pop(field.name, MISSING), field) for field in fields_iter),
-        ):
-            if value is MISSING:
-                if field.default is not MISSING:
-                    value = field.default
-                elif field.default_factory is not MISSING:
-                    value = field.default_factory()
-                else:
-                    raise RuntimeError(f"Missing required field `{field.name}`.")
-
-            if issubclass(hints[field.name], BotorchContainer):
-                if isinstance(value, Tensor):
-                    value = DenseContainer(value, event_shape=value.shape[-1:])
-                elif not isinstance(value, BotorchContainer):
-                    raise TypeError(
-                        "Expected <BotorchContainer | Tensor> for field "
-                        f"`{field.name}` but was {type(value)}."
-                    )
-            f_dict[field.name] = value
-
-        return super().__call__(**f_dict, **kwargs)
-
-
-@dataclass
-class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
-    r"""Base class for datasets consisting of labelled pairs `(x, y)`.
-
-    This class object's `__call__` method converts Tensors `src` to
+    This class object's `__init__` method converts Tensors `src` to
     DenseContainers under the assumption that `event_shape=src.shape[-1:]`.

     Example:
@@ -87,6 +44,29 @@ class SupervisedDataset(BotorchDataset, metaclass=SupervisedDatasetMeta):
     X: BotorchContainer
     Y: BotorchContainer
+    Yvar: Optional[BotorchContainer]
+
+    def __init__(
+        self,
+        X: ContainerLike,
+        Y: ContainerLike,
+        Yvar: Optional[ContainerLike] = None,
+        validate_init: bool = True,
+    ) -> None:
+        r"""Constructs a `SupervisedDataset`.
+
+        Args:
+            X: A `Tensor` or `BotorchContainer` representing the input features.
+            Y: A `Tensor` or `BotorchContainer` representing the outcomes.
+            Yvar: An optional `Tensor` or `BotorchContainer` representing
+                the observation noise.
+            validate_init: If `True`, validates the input shapes.
+        """
+        self.X = _containerize(X)
+        self.Y = _containerize(Y)
+        self.Yvar = None if Yvar is None else _containerize(Yvar)
+        if validate_init:
+            self._validate()

     def _validate(self) -> None:
         shape_X = self.X.shape
@@ -95,12 +75,15 @@ def _validate(self) -> None:
         shape_Y = shape_Y[: len(shape_Y) - len(self.Y.event_shape)]
         if shape_X != shape_Y:
             raise ValueError("Batch dimensions of `X` and `Y` are incompatible.")
+        if self.Yvar is not None and self.Yvar.shape != self.Y.shape:
+            raise ValueError("Shapes of `Y` and `Yvar` are incompatible.")

     @classmethod
     def dict_from_iter(
         cls,
         X: MaybeIterable[ContainerLike],
         Y: MaybeIterable[ContainerLike],
+        Yvar: Optional[MaybeIterable[ContainerLike]] = None,
         *,
         keys: Optional[Iterable[Hashable]] = None,
     ) -> Dict[Hashable, SupervisedDataset]:
@@ -111,40 +94,46 @@ def dict_from_iter(
             X = (X,) if single_Y else repeat(X)
         if single_Y:
             Y = (Y,) if single_X else repeat(Y)
-        return {key: cls(x, y) for key, x, y in zip(keys or count(), X, Y)}
+        Yvar = repeat(Yvar) if isinstance(Yvar, (Tensor, BotorchContainer)) else Yvar
+
+        # Pass in Yvar only if it is not None.
+        iterables = (X, Y) if Yvar is None else (X, Y, Yvar)
+        return {
+            elements[0]: cls(*elements[1:])
+            for elements in zip(keys or count(), *iterables)
+        }
+
+    def __eq__(self, other: Any) -> bool:
+        return (
+            type(other) is type(self)
+            and self.X == other.X
+            and self.Y == other.Y
+            and self.Yvar == other.Yvar
+        )


-@dataclass
 class FixedNoiseDataset(SupervisedDataset):
     r"""A SupervisedDataset with an additional field `Yvar` that stipulates
-    observations variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`."""
+    observation variances so that `Y[i] ~ N(f(X[i]), Yvar[i])`.

-    X: BotorchContainer
-    Y: BotorchContainer
-    Yvar: BotorchContainer
-
-    @classmethod
-    def dict_from_iter(
-        cls,
-        X: MaybeIterable[ContainerLike],
-        Y: MaybeIterable[ContainerLike],
-        Yvar: Optional[MaybeIterable[ContainerLike]] = None,
-        *,
-        keys: Optional[Iterable[Hashable]] = None,
-    ) -> Dict[Hashable, SupervisedDataset]:
-        r"""Returns a dictionary of `FixedNoiseDataset` from iterables."""
-        single_X = isinstance(X, (Tensor, BotorchContainer))
-        single_Y = isinstance(Y, (Tensor, BotorchContainer))
-        if single_X:
-            X = (X,) if single_Y else repeat(X)
-        if single_Y:
-            Y = (Y,) if single_X else repeat(Y)
+    NOTE: This is deprecated. Use `SupervisedDataset` instead.
+ """ - Yvar = repeat(Yvar) if isinstance(Yvar, (Tensor, BotorchContainer)) else Yvar - return {key: cls(x, y, c) for key, x, y, c in zip(keys or count(), X, Y, Yvar)} + def __init__( + self, + X: ContainerLike, + Y: ContainerLike, + Yvar: ContainerLike, + validate_init: bool = True, + ) -> None: + r"""Initialize a `FixedNoiseDataset` -- deprecated!""" + warnings.warn( + "`FixedNoiseDataset` is deprecated. Use `SupervisedDataset` instead.", + DeprecationWarning, + ) + super().__init__(X=X, Y=Y, Yvar=Yvar, validate_init=validate_init) -@dataclass class RankingDataset(SupervisedDataset): r"""A SupervisedDataset whose labelled pairs `(x, y)` consist of m-ary combinations `x ∈ Z^{m}` of elements from a ground set `Z = (z_1, ...)` and ranking vectors @@ -173,6 +162,18 @@ class RankingDataset(SupervisedDataset): X: SliceContainer Y: BotorchContainer + def __init__( + self, X: SliceContainer, Y: ContainerLike, validate_init: bool = True + ) -> None: + r"""Construct a `RankingDataset`. + + Args: + X: A `SliceContainer` representing the input features being ranked. + Y: A `Tensor` or `BotorchContainer` representing the rankings. + validate_init: If `True`, validates the input shapes. + """ + super().__init__(X=X, Y=Y, Yvar=None, validate_init=validate_init) + def _validate(self) -> None: super()._validate() @@ -201,3 +202,13 @@ def _validate(self) -> None: # Same as: torch.where(y_diff == 0, y_incr + 1, 1) y_incr = y_incr - y_diff + 1 + + +def _containerize(value: ContainerLike) -> BotorchContainer: + r"""Converts Tensor-valued arguments to DenseContainer under the assumption + that said arguments house collections of feature vectors. + """ + if isinstance(value, Tensor): + return DenseContainer(value, event_shape=value.shape[-1:]) + else: + return value diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py index 2bdb7d6c0f..654babebe3 100644 --- a/test/models/test_fully_bayesian.py +++ b/test/models/test_fully_bayesian.py @@ -47,7 +47,7 @@ from botorch.models.transforms import Normalize, Standardize from botorch.posteriors.fully_bayesian import batched_bisect, FullyBayesianPosterior from botorch.sampling.get_sampler import get_sampler -from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset +from botorch.utils.datasets import SupervisedDataset from botorch.utils.multi_objective.box_decompositions.non_dominated import ( NondominatedPartitioning, ) @@ -550,10 +550,7 @@ def test_construct_inputs(self): X, Y, Yvar, model = self._get_data_and_model( infer_noise=infer_noise, **tkwargs ) - if infer_noise: - training_data = SupervisedDataset(X, Y) - else: - training_data = FixedNoiseDataset(X, Y, Yvar) + training_data = SupervisedDataset(X, Y, Yvar) data_dict = model.construct_inputs(training_data) self.assertTrue(X.equal(data_dict["train_X"])) diff --git a/test/models/test_gp_regression.py b/test/models/test_gp_regression.py index 86a5039b5f..2ac4f8f835 100644 --- a/test/models/test_gp_regression.py +++ b/test/models/test_gp_regression.py @@ -20,7 +20,7 @@ from botorch.models.utils import add_output_dim from botorch.posteriors import GPyTorchPosterior from botorch.sampling import SobolQMCNormalSampler -from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset +from botorch.utils.datasets import SupervisedDataset from botorch.utils.sampling import manual_seed from botorch.utils.testing import _get_random_data, BotorchTestCase from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel @@ -450,7 +450,7 @@ def test_construct_inputs(self): X 
             X = model_kwargs["train_X"]
             Y = model_kwargs["train_Y"]
             Yvar = model_kwargs["train_Yvar"]
-            training_data = FixedNoiseDataset(X, Y, Yvar)
+            training_data = SupervisedDataset(X, Y, Yvar)
             data_dict = model.construct_inputs(training_data)
             self.assertTrue(X.equal(data_dict["train_X"]))
             self.assertTrue(Y.equal(data_dict["train_Y"]))
diff --git a/test/models/test_gp_regression_fidelity.py b/test/models/test_gp_regression_fidelity.py
index 777c45b130..ad5f748dd3 100644
--- a/test/models/test_gp_regression_fidelity.py
+++ b/test/models/test_gp_regression_fidelity.py
@@ -20,7 +20,7 @@
 from botorch.models.transforms import Normalize, Standardize
 from botorch.posteriors import GPyTorchPosterior
 from botorch.sampling import SobolQMCNormalSampler
-from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import _get_random_data, BotorchTestCase
 from gpytorch.kernels.scale_kernel import ScaleKernel
 from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
@@ -487,7 +487,7 @@ def test_construct_inputs(self):
             self.assertTrue("train_Yvar" not in data_dict)

         # len(Xs) == len(Ys) == 1
-        training_data = FixedNoiseDataset(
+        training_data = SupervisedDataset(
             X=kwargs["train_X"],
             Y=kwargs["train_Y"],
             Yvar=torch.full(kwargs["train_Y"].shape[:-1] + (1,), 0.1),
diff --git a/test/models/test_multitask.py b/test/models/test_multitask.py
index ca27c47709..280f8c05ff 100644
--- a/test/models/test_multitask.py
+++ b/test/models/test_multitask.py
@@ -22,7 +22,7 @@
 from botorch.models.transforms.outcome import Standardize
 from botorch.posteriors import GPyTorchPosterior
 from botorch.posteriors.transformed import TransformedPosterior
-from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.kernels import (
@@ -59,7 +59,7 @@ def _gen_datasets(yvar: Optional[float] = None, **tkwargs):
     Yvar1 = torch.full_like(Y1, yvar)
     Yvar2 = torch.full_like(Y2, yvar)
     train_Yvar = torch.cat([Yvar1, Yvar2])
-    datasets = {0: FixedNoiseDataset(X, Y1, Yvar1), 1: FixedNoiseDataset(X, Y2, Yvar2)}
+    datasets = {0: SupervisedDataset(X, Y1, Yvar1), 1: SupervisedDataset(X, Y2, Yvar2)}
     return datasets, (train_X, train_Y, train_Yvar)
diff --git a/test/models/utils/test_parse_training_data.py b/test/models/utils/test_parse_training_data.py
index f53128fdef..cd521664ac 100644
--- a/test/models/utils/test_parse_training_data.py
+++ b/test/models/utils/test_parse_training_data.py
@@ -67,9 +67,9 @@ def test_dict(self):
         with self.assertRaisesRegex(UnsupportedError, "multiple datasets to single"):
             parse_training_data(Model, datasets)

+        _datasets = datasets.copy()
+        _datasets[m] = SupervisedDataset(rand(n, 2), rand(n, 1), rand(n, 1))
         with self.assertRaisesRegex(UnsupportedError, "Cannot combine .* hetero"):
-            _datasets = datasets.copy()
-            _datasets[m] = FixedNoiseDataset(rand(n, 2), rand(n, 1), rand(n, 1))
             parse_training_data(MultiTaskGP, _datasets)

         with self.assertRaisesRegex(ValueError, "Missing required term"):
diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py
index 02c9234f55..fd60509c7c 100644
--- a/test/utils/test_datasets.py
+++ b/test/utils/test_datasets.py
@@ -4,85 +4,19 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import field, make_dataclass
-from unittest.mock import patch
-
 from botorch.utils.containers import DenseContainer, SliceContainer
-from botorch.utils.datasets import (
-    BotorchDataset,
-    FixedNoiseDataset,
-    RankingDataset,
-    SupervisedDataset,
-)
+from botorch.utils.datasets import FixedNoiseDataset, RankingDataset, SupervisedDataset
 from botorch.utils.testing import BotorchTestCase
-from torch import rand, randperm, Size, stack, tensor, Tensor
+from torch import rand, randperm, Size, stack, tensor


 class TestDatasets(BotorchTestCase):
-    def test_base(self):
-        with patch.object(BotorchDataset, "_validate", new=lambda self: 1 / 0):
-            with self.assertRaises(ZeroDivisionError):
-                BotorchDataset()
-
-        dataset = BotorchDataset()
-        self.assertTrue(dataset._validate() is None)
-
-    def test_supervised_meta(self):
-        X = rand(3, 2)
-        Y = rand(3, 1)
-        t = rand(3, 5)
-        A = DenseContainer(t, event_shape=Size([5]))
-        B = rand(2, 1)
-
-        SupervisedDatasetWithDefaults = make_dataclass(
-            cls_name="SupervisedDatasetWithDefaults",
-            bases=(SupervisedDataset,),
-            fields=[
-                ("default", DenseContainer, field(default=t)),
-                ("factory", DenseContainer, field(default_factory=lambda: A)),
-                ("other", Tensor, field(default_factory=lambda: B)),
-            ],
-        )
-
-        # Check that call signature is property enforced
-        with self.assertRaisesRegex(RuntimeError, "Missing .* `X`"):
-            SupervisedDatasetWithDefaults(Y=Y)
-
-        with self.assertRaisesRegex(RuntimeError, "Missing .* `Y`"):
-            SupervisedDatasetWithDefaults(X=X)
-
-        with self.assertRaisesRegex(TypeError, "Expected <BotorchContainer | Tensor>"):
-            SupervisedDatasetWithDefaults(X=X, Y=Y.tolist())
-
-        # Check handling of default values and factories
-        dataset = SupervisedDatasetWithDefaults(X=X, Y=Y)
-        self.assertIsInstance(dataset.default, DenseContainer)
-        self.assertEqual(dataset.default, A)
-        self.assertEqual(dataset.factory, A)
-        self.assertTrue(dataset.other is B)
-
-        # Check type coercion
-        dataset = SupervisedDatasetWithDefaults(X=X, Y=Y, default=X, factory=Y, other=B)
-        self.assertIsInstance(dataset.X, DenseContainer)
-        self.assertIsInstance(dataset.Y, DenseContainer)
-        self.assertEqual(dataset.default, dataset.X)
-        self.assertEqual(dataset.factory, dataset.Y)
-        self.assertTrue(dataset.other is B)
-
-        # Check handling of positional arguments
-        dataset = SupervisedDatasetWithDefaults(X, Y, X, Y, X)
-        self.assertIsInstance(dataset.X, DenseContainer)
-        self.assertIsInstance(dataset.Y, DenseContainer)
-        self.assertEqual(dataset.default, dataset.X)
-        self.assertEqual(dataset.factory, dataset.Y)
-        self.assertTrue(dataset.other is X)
-
     def test_supervised(self):
         # Generate some data
         Xs = rand(4, 3, 2)
         Ys = rand(4, 3, 1)

-        # Test `__post_init__`
+        # Test `__init__`
         dataset = SupervisedDataset(X=Xs[0], Y=Ys[0])
         for name in ("X", "Y"):
             field = getattr(dataset, name)
@@ -133,6 +67,11 @@ def test_fixedNoise(self):
         self.assertTrue(Xs[0].equal(dataset.X()))
         self.assertTrue(Ys[1].equal(dataset.Y()))

+        with self.assertRaisesRegex(
+            ValueError, "`Y` and `Yvar`"
+        ), self.assertWarnsRegex(DeprecationWarning, "SupervisedDataset"):
+            FixedNoiseDataset(X=Xs, Y=Ys, Yvar=Ys_var[0])
+
     def test_ranking(self):
         # Test `_validate`
         X_val = rand(16, 2)
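
Usage sketch (not part of the patch): the net effect of the `botorch/utils/datasets.py` changes is that `Yvar` becomes an optional third field on `SupervisedDataset`, with `FixedNoiseDataset` retained only as a deprecated shim. A minimal example against a checkout that includes this patch; the tensor shapes and values below are illustrative only:

    import warnings

    import torch
    from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset

    X, Y = torch.rand(8, 2), torch.rand(8, 1)
    Yvar = torch.full_like(Y, 0.1)

    # `Yvar` is now an optional third argument on the base class.
    noiseless = SupervisedDataset(X, Y)
    noisy = SupervisedDataset(X, Y, Yvar)
    assert noiseless.Yvar is None

    # Tensor inputs are wrapped as DenseContainers by `_containerize`;
    # calling a container recovers the underlying tensor.
    assert noisy.X().equal(X) and noisy.Yvar().equal(Yvar)

    # The new `__eq__` compares the X, Y, and Yvar fields.
    assert noisy == SupervisedDataset(X, Y, Yvar)

    # `dict_from_iter` broadcasts a shared X across per-outcome Y values.
    datasets = SupervisedDataset.dict_from_iter(X, (Y, Y + 1), keys=("m1", "m2"))
    assert set(datasets) == {"m1", "m2"}

    # The subclass survives only as a deprecated alias that warns on init.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        FixedNoiseDataset(X, Y, Yvar)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)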
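
Downstream, `parse_training_data` no longer needs a dedicated dispatch rule for fixed-noise data: the single `(Model, SupervisedDataset)` parser emits `train_Yvar` conditionally. A sketch of that behavior, again assuming this patch is applied (it mirrors the assertions in the updated tests):

    import torch
    from botorch.models.model import Model
    from botorch.models.utils.parse_training_data import parse_training_data
    from botorch.utils.datasets import SupervisedDataset

    X, Y = torch.rand(8, 2), torch.rand(8, 1)

    # Without Yvar, the parsed kwargs omit `train_Yvar` entirely.
    parsed = parse_training_data(Model, SupervisedDataset(X, Y))
    assert set(parsed) == {"train_X", "train_Y"}

    # With Yvar present, the same dispatch path adds `train_Yvar`.
    noisy = SupervisedDataset(X, Y, torch.full_like(Y, 0.1))
    parsed = parse_training_data(Model, noisy)
    assert set(parsed) == {"train_X", "train_Y", "train_Yvar"}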