Skip to content

Commit

Permalink
introduce trial_indices argument to SupervisedDataset (pytorch#2595)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2960


Adds optional `trial_indices` to SupervisedDataset, whose dimensionality should correspond 1:1 with the first few dimensions of X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)).

Reviewed By: Balandat

Differential Revision: D64764019
  • Loading branch information
bernardbeckerman authored and facebook-github-bot committed Oct 25, 2024
1 parent 9d37e90 commit 932c56f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
29 changes: 28 additions & 1 deletion botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from botorch.exceptions.errors import InputDataError, UnsupportedError
from botorch.utils.containers import BotorchContainer, SliceContainer
from pyre_extensions import none_throws
from torch import long, ones, Tensor


Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
outcome_names: list[str],
Yvar: BotorchContainer | Tensor | None = None,
validate_init: bool = True,
trial_indices: Tensor | None = None,
) -> None:
r"""Constructs a `SupervisedDataset`.
Expand All @@ -65,12 +67,16 @@ def __init__(
Yvar: An optional `Tensor` or `BotorchContainer` representing
the observation noise.
validate_init: If `True`, validates the input shapes.
trial_indices: A `Tensor` representing the trial indices of X and Y. This is
used to support learning-curve-based modeling. If provided, it must
have compatible shape with X and Y.
"""
self._X = X
self._Y = Y
self._Yvar = Yvar
self.feature_names = feature_names
self.outcome_names = outcome_names
self.trial_indices = trial_indices
if validate_init:
self._validate()

Expand All @@ -96,6 +102,7 @@ def _validate(
self,
validate_feature_names: bool = True,
validate_outcome_names: bool = True,
validate_trial_indices: bool = True,
) -> None:
r"""Checks that the shapes of the inputs are compatible with each other.
Expand All @@ -108,6 +115,8 @@ def _validate(
`outcomes_names` matches the # of columns of `self.Y`. If a
particular dataset, e.g., `RankingDataset`, is known to violate
this assumption, this can be set to `False`.
validate_trial_indices: By default, we validate that the shape of
`trial_indices` matches the shape of X and Y.
"""
shape_X = self.X.shape
if isinstance(self._X, BotorchContainer):
Expand All @@ -133,8 +142,19 @@ def _validate(
"`Y` must have the same number of columns as the number of "
"outcomes in `outcome_names`."
)
if validate_trial_indices and self.trial_indices is not None:
if self.trial_indices.shape != shape_X:
raise ValueError(
f"{shape_X=} must have the same shape as {none_throws(self.trial_indices).shape=}."
)

def __eq__(self, other: Any) -> bool:
if self.trial_indices is None and other.trial_indices is None:
trial_indices_equal = True
elif self.trial_indices is None or other.trial_indices is None:
trial_indices_equal = False
else:
trial_indices_equal = torch.equal(self.trial_indices, other.trial_indices)
return (
type(other) is type(self)
and torch.equal(self.X, other.X)
Expand All @@ -146,6 +166,7 @@ def __eq__(self, other: Any) -> bool:
)
and self.feature_names == other.feature_names
and self.outcome_names == other.outcome_names
and trial_indices_equal
)


Expand Down Expand Up @@ -241,7 +262,11 @@ def __init__(
)

def _validate(self) -> None:
super()._validate(validate_feature_names=False, validate_outcome_names=False)
super()._validate(
validate_feature_names=False,
validate_outcome_names=False,
validate_trial_indices=False,
)
if len(self.feature_names) != self._X.values.shape[-1]:
raise ValueError(
"The `values` field of `X` must have the same number of columns as "
Expand Down Expand Up @@ -316,6 +341,7 @@ def __init__(
self.has_heterogeneous_features = any(
datasets[0].feature_names != ds.feature_names for ds in datasets[1:]
)
self.trial_indices = None

@classmethod
def from_joint_dataset(
Expand Down Expand Up @@ -538,6 +564,7 @@ def __init__(
c: [self.feature_names.index(i) for i in parameter_decomposition[c]]
for c in self.context_buckets
}
self.trial_indices = None

@property
def X(self) -> Tensor:
Expand Down
22 changes: 19 additions & 3 deletions test/utils/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,35 @@ def make_dataset(
class TestDatasets(BotorchTestCase):
def test_supervised(self):
# Generate some data
X = rand(3, 2)
Y = rand(3, 1)
n_rows = 3
X = rand(n_rows, 2)
Y = rand(n_rows, 1)
feature_names = ["x1", "x2"]
outcome_names = ["y"]
trial_indices = tensor(range(n_rows))

# Test `__init__`
dataset = SupervisedDataset(
X=X, Y=Y, feature_names=feature_names, outcome_names=outcome_names
X=X,
Y=Y,
feature_names=feature_names,
outcome_names=outcome_names,
trial_indices=trial_indices,
)
self.assertIsInstance(dataset.X, Tensor)
self.assertIsInstance(dataset._X, Tensor)
self.assertIsInstance(dataset.Y, Tensor)
self.assertIsInstance(dataset._Y, Tensor)
self.assertEqual(dataset.feature_names, feature_names)
self.assertEqual(dataset.outcome_names, outcome_names)
self.assertTrue(torch.equal(dataset.trial_indices, trial_indices))

dataset2 = SupervisedDataset(
X=DenseContainer(X, X.shape[-1:]),
Y=DenseContainer(Y, Y.shape[-1:]),
feature_names=feature_names,
outcome_names=outcome_names,
trial_indices=trial_indices,
)
self.assertIsInstance(dataset2.X, Tensor)
self.assertIsInstance(dataset2._X, DenseContainer)
Expand Down Expand Up @@ -101,6 +109,14 @@ def test_supervised(self):
feature_names=feature_names,
outcome_names=[],
)
with self.assertRaisesRegex(ValueError, "trial_indices"):
SupervisedDataset(
X=rand(2, 2),
Y=rand(2, 1),
feature_names=feature_names,
outcome_names=outcome_names,
trial_indices=tensor(range(n_rows + 1)),
)

# Test with Yvar.
dataset = SupervisedDataset(
Expand Down

0 comments on commit 932c56f

Please sign in to comment.