Implement MultiTaskDataset (pytorch#2019)
Summary:

Introduces a `MultiTaskDataset` for carrying the datasets for individual tasks along with the relevant metadata.

> This is a multi-task dataset that is constructed from the datasets of individual tasks. It offers functionality to combine parts of individual datasets to construct the inputs necessary for the `MultiTaskGP` models.
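A minimal usage sketch, assuming the API introduced in the diff below (dataset sizes and outcome names here are hypothetical):

import torch
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset

# Two single-outcome datasets over the same two features.
online = SupervisedDataset(
    X=torch.rand(10, 2),
    Y=torch.rand(10, 1),
    feature_names=["x0", "x1"],
    outcome_names=["objective"],
)
offline = SupervisedDataset(
    X=torch.rand(50, 2),
    Y=torch.rand(50, 1),
    feature_names=["x0", "x1"],
    outcome_names=["offline_objective"],
)
mt = MultiTaskDataset(datasets=[online, offline], target_outcome_name="objective")
# With no task_feature_index given, mt.X appends a task feature column:
# 0 for the target task, 1 for the auxiliary task, giving mt.X shape (60, 3).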

Differential Revision: D49509321
saitcakmak authored and facebook-github-bot committed Sep 21, 2023
1 parent af38927 commit 86103d7
Showing 2 changed files with 350 additions and 5 deletions.
208 changes: 204 additions & 4 deletions botorch/utils/datasets.py
@@ -9,15 +9,13 @@
from __future__ import annotations

import warnings
-from typing import Any, Iterable, List, Optional, TypeVar, Union
+from typing import Any, Dict, List, Optional, Union

import torch
from botorch.exceptions.errors import InputDataError, UnsupportedError
from botorch.utils.containers import BotorchContainer, SliceContainer
from torch import long, ones, Tensor

-T = TypeVar("T")
-MaybeIterable = Union[T, Iterable[T]]


class SupervisedDataset:
r"""Base class for datasets consisting of labelled pairs `(X, Y)`
@@ -273,3 +271,205 @@ def _validate(self) -> None:

# Same as: torch.where(y_diff == 0, y_incr + 1, 1)
y_incr = y_incr - y_diff + 1


class MultiTaskDataset(SupervisedDataset):
"""This is a multi-task dataset that is constructed from the datasets of
individual tasks. It offers functionality to combine parts of individual
datasets to construct the inputs necessary for the `MultiTaskGP` models.
"""

def __init__(
self,
datasets: List[SupervisedDataset],
target_outcome_name: str,
task_feature_index: Optional[int] = None,
):
"""Construct a `MultiTaskDataset`.
Args:
datasets: A list of the datasets of individual tasks. Each dataset
is expected to contain data for only one outcome.
target_outcome_name: Name of the target outcome to be modeled.
task_feature_index: If the task feature is included in the `X`s of the
individual datasets, this should be used to specify its index.
If omitted, the task feature will be appended while concatenating Xs.
If given, we sanity-check that the names of the task features
match between all datasets.
"""
self.datasets: Dict[str, SupervisedDataset] = {
ds.outcome_names[0]: ds for ds in datasets
}
self.target_outcome_name = target_outcome_name
self.task_feature_index = task_feature_index
self._validate_datasets(datasets=datasets)
self.feature_names = self.datasets[target_outcome_name].feature_names
self.outcome_names = [target_outcome_name]

@classmethod
def from_joint_dataset(
cls,
dataset: SupervisedDataset,
task_feature_index: int,
target_task_value: int,
outcome_names_per_task: Optional[Dict[int, str]] = None,
) -> MultiTaskDataset:
r"""Construct a `MultiTaskDataset` from a joint dataset that includes the
data for all tasks with the task feature index.
This will break down the joint dataset into individual datasets by the value
of the task feature. Each resulting dataset will have its outcome name set
based on `outcome_names_per_task`, with the missing values defaulting to
`task_[task_value]` (except for the target task, which will retain the
original outcome name from the dataset).
Args:
dataset: The joint dataset.
task_feature_index: The column index of the task feature in `dataset.X`.
target_task_value: The value of the task feature for the target task
in the dataset. The data for the target task is filtered according to
`dataset.X[task_feature_index] == target_task_value`.
outcome_names_per_task: Optional dictionary mapping task feature values
to the outcome names for each task. If not provided, the auxiliary
tasks will be named `task_[task_value]` and the target task will
retain the outcome name from the dataset.
Returns:
A `MultiTaskDataset` instance.
"""
if len(dataset.outcome_names) > 1:
raise UnsupportedError(
"Dataset containing more than one outcome is not supported. "
f"Got {dataset.outcome_names=}."
)
outcome_names_per_task = outcome_names_per_task or {}
# Split datasets by task feature.
datasets = []
all_task_features = dataset.X[:, task_feature_index]
for task_value in all_task_features.unique().long().tolist():
default_name = (
dataset.outcome_names[0]
if task_value == target_task_value
else f"task_{task_value}"
)
outcome_name = outcome_names_per_task.get(task_value, default_name)
filter_mask = all_task_features == task_value
new_dataset = SupervisedDataset(
X=dataset.X[filter_mask],
Y=dataset.Y[filter_mask],
Yvar=dataset.Yvar[filter_mask] if dataset.Yvar is not None else None,
feature_names=dataset.feature_names,
outcome_names=[outcome_name],
)
datasets.append(new_dataset)
# Return the new `MultiTaskDataset`.
return cls(
datasets=datasets,
target_outcome_name=outcome_names_per_task.get(
target_task_value, dataset.outcome_names[0]
),
task_feature_index=task_feature_index,
)

def _validate_datasets(self, datasets: List[SupervisedDataset]) -> None:
"""Validates that:
* Each dataset models only one outcome;
* Each outcome is modeled by only one dataset;
* The target outcome is included in the datasets;
* The datasets do not model batched inputs;
* The task feature names of the datasets all match;
* Either all or none of the datasets specify Yvar.
"""
if any(len(ds.outcome_names) > 1 for ds in datasets):
raise UnsupportedError(
"Datasets containing more than one outcome are not supported."
)
if len(self.datasets) != len(datasets):
raise UnsupportedError(
"Received multiple datasets for the same outcome. Each dataset "
"must contain data for a unique outcome. Got datasets with "
f"outcome names: {(ds.outcome_names for ds in datasets)}."
)
if self.target_outcome_name not in self.datasets:
raise InputDataError(
"Target outcome is not present in the datasets. "
f"Got {self.target_outcome_name=} and datasets for "
f"outcomes {list(self.datasets.keys())}."
)
if any(len(ds.X.shape) > 2 for ds in datasets):
raise UnsupportedError(
"Datasets modeling batched inputs are not supported."
)
if self.task_feature_index is not None:
tf_names = [ds.feature_names[self.task_feature_index] for ds in datasets]
if any(name != tf_names[0] for name in tf_names[1:]):
raise InputDataError(
"Expected the names of the task features to match across all "
f"datasets. Got {tf_names}."
)
all_Yvars = [ds.Yvar for ds in datasets]
is_none = [yvar is None for yvar in all_Yvars]
# Check that either all or none of the Yvars exist.
if not all(is_none) and any(is_none):
raise UnsupportedError(
"Expected either all or none of the datasets to have a Yvar. "
"Only subset of datasets define Yvar, which is unsupported. "
)

@property
def X(self) -> Tensor:
"""Appends task features, if needed, and concatenates the Xs of datasets to
produce the `train_X` expected by `MultiTaskGP` and subclasses.
If appending the task features, 0 is reserved for the target task and the
remaining tasks are populated with 1, 2, ..., len(datasets) - 1.
"""
all_Xs = []
next_task = 1
for outcome, ds in self.datasets.items():
if self.task_feature_index is None:
# Append a task feature column.
if outcome == self.target_outcome_name:
task_feature = 0
else:
task_feature = next_task
next_task = next_task + 1
all_Xs.append(torch.nn.functional.pad(ds.X, (0, 1), value=task_feature))
else:
all_Xs.append(ds.X)
return torch.cat(all_Xs, dim=0)

@property
def Y(self) -> Tensor:
"""Concatenates Ys of the datasets."""
return torch.cat([ds.Y for ds in self.datasets.values()], dim=0)

@property
def Yvar(self) -> Optional[Tensor]:
"""Concatenates Yvars of the datasets if they exist."""
all_Yvars = [ds.Yvar for ds in self.datasets.values()]
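# _validate_datasets ensures all-or-none Yvars, so checking the
# first entry is sufficient.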
return None if all_Yvars[0] is None else torch.cat(all_Yvars, dim=0)

def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDataset:
"""A helper for extracting the child datasets with their task features removed.
If the task feature index is `None`, the dataset will be returned as is.
Args:
outcome_name: The outcome name for the dataset to extract.
Returns:
The dataset without the task feature.
"""
if self.task_feature_index is None:
return self.datasets[outcome_name]
dataset = self.datasets[outcome_name]
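# Drop the task feature column from the index list before slicing.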
indices = list(range(len(self.feature_names)))
indices.pop(self.task_feature_index)
return SupervisedDataset(
X=dataset.X[..., indices],
Y=dataset.Y,
Yvar=dataset.Yvar,
feature_names=[
fn for i, fn in enumerate(dataset.feature_names) if i in indices
],
outcome_names=[outcome_name],
)
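For reference, a hedged sketch of the `from_joint_dataset` round trip added above (names and shapes hypothetical):

import torch
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset

# A joint dataset whose last column is the task feature, with values 0 and 1.
task_col = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unsqueeze(-1)
joint = SupervisedDataset(
    X=torch.cat([torch.rand(6, 2), task_col], dim=-1),
    Y=torch.rand(6, 1),
    feature_names=["x0", "x1", "task"],
    outcome_names=["objective"],
)
mt = MultiTaskDataset.from_joint_dataset(
    dataset=joint, task_feature_index=-1, target_task_value=0
)
# The target task keeps the original outcome name; task 1 is named "task_1".
target_ds = mt.get_dataset_without_task_feature(outcome_name="objective")
# target_ds.X has shape (3, 2) and feature_names ["x0", "x1"].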
147 changes: 146 additions & 1 deletion test/utils/test_datasets.py
@@ -4,13 +4,42 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch
from botorch.exceptions.errors import InputDataError, UnsupportedError
from botorch.utils.containers import DenseContainer, SliceContainer
-from botorch.utils.datasets import FixedNoiseDataset, RankingDataset, SupervisedDataset
+from botorch.utils.datasets import (
+    FixedNoiseDataset,
+    MultiTaskDataset,
+    RankingDataset,
+    SupervisedDataset,
+)
from botorch.utils.testing import BotorchTestCase
from torch import rand, randperm, Size, stack, Tensor, tensor


def make_dataset(
num_samples: int = 3,
d: int = 2,
m: int = 1,
has_yvar: bool = False,
feature_names: Optional[List[str]] = None,
outcome_names: Optional[List[str]] = None,
batch_shape: Optional[torch.Size] = None,
) -> SupervisedDataset:
feature_names = feature_names or [f"x{i}" for i in range(d)]
outcome_names = outcome_names or [f"y{i}" for i in range(m)]
batch_shape = batch_shape or torch.Size()
return SupervisedDataset(
X=rand(*batch_shape, num_samples, d),
Y=rand(*batch_shape, num_samples, m),
Yvar=rand(*batch_shape, num_samples, m) if has_yvar else None,
feature_names=feature_names,
outcome_names=outcome_names,
)


class TestDatasets(BotorchTestCase):
def test_supervised(self):
# Generate some data
@@ -190,3 +219,119 @@ def test_ranking(self):
feature_names=feature_names,
outcome_names=outcome_names,
)

def test_multi_task(self):
dataset_1 = make_dataset(outcome_names=["y"])
dataset_2 = make_dataset(outcome_names=["z"])
dataset_3 = make_dataset(has_yvar=True, outcome_names=["z"])
dataset_4 = make_dataset(has_yvar=True, outcome_names=["y"])
# Test validation.
with self.assertRaisesRegex(
UnsupportedError, "containing more than one outcome"
):
MultiTaskDataset(datasets=[make_dataset(m=2)], target_outcome_name="y0")
with self.assertRaisesRegex(
UnsupportedError, "multiple datasets for the same outcome"
):
MultiTaskDataset(datasets=[dataset_1, dataset_1], target_outcome_name="y")
with self.assertRaisesRegex(InputDataError, "Target outcome is not present"):
MultiTaskDataset(datasets=[dataset_1], target_outcome_name="z")
with self.assertRaisesRegex(UnsupportedError, "modeling batched inputs"):
MultiTaskDataset(
datasets=[make_dataset(batch_shape=torch.Size([2]))],
target_outcome_name="y0",
)
with self.assertRaisesRegex(InputDataError, "names of the task features"):
MultiTaskDataset(
datasets=[
dataset_1,
make_dataset(feature_names=["x1", "x3"], outcome_names=["z"]),
],
target_outcome_name="z",
task_feature_index=1,
)
with self.assertRaisesRegex(
UnsupportedError, "all or none of the datasets to have a Yvar."
):
MultiTaskDataset(datasets=[dataset_1, dataset_3], target_outcome_name="z")

# Test correct construction.
mt_dataset = MultiTaskDataset(
datasets=[dataset_1, dataset_2],
target_outcome_name="z",
)
self.assertEqual(len(mt_dataset.datasets), 2)
self.assertIsNone(mt_dataset.task_feature_index)
self.assertIs(mt_dataset.datasets["y"], dataset_1)
self.assertIs(mt_dataset.datasets["z"], dataset_2)
self.assertIsNone(mt_dataset.Yvar)
expected_X = torch.cat(
[
torch.cat([dataset_1.X, torch.ones(3, 1)], dim=-1),
torch.cat([dataset_2.X, torch.zeros(3, 1)], dim=-1),
],
dim=0,
)
expected_Y = torch.cat([ds.Y for ds in [dataset_1, dataset_2]], dim=0)
self.assertTrue(torch.equal(expected_X, mt_dataset.X))
self.assertTrue(torch.equal(expected_Y, mt_dataset.Y))
self.assertIs(
mt_dataset.get_dataset_without_task_feature(outcome_name="y"), dataset_1
)

# Test with Yvar and target_feature_index.
mt_dataset = MultiTaskDataset(
datasets=[dataset_3, dataset_4],
target_outcome_name="z",
task_feature_index=1,
)
self.assertEqual(mt_dataset.task_feature_index, 1)
expected_X_2 = torch.cat([dataset_3.X, dataset_4.X], dim=0)
expected_Yvar_2 = torch.cat([dataset_3.Yvar, dataset_4.Yvar], dim=0)
self.assertTrue(torch.equal(expected_X_2, mt_dataset.X))
self.assertTrue(torch.equal(expected_Yvar_2, mt_dataset.Yvar))
# Check that the task feature is removed correctly.
ds_3_no_task = mt_dataset.get_dataset_without_task_feature(outcome_name="z")
self.assertTrue(torch.equal(ds_3_no_task.X, dataset_3.X[:, :1]))
self.assertTrue(torch.equal(ds_3_no_task.Y, dataset_3.Y))
self.assertTrue(torch.equal(ds_3_no_task.Yvar, dataset_3.Yvar))
self.assertEqual(ds_3_no_task.feature_names, dataset_3.feature_names[:1])
self.assertEqual(ds_3_no_task.outcome_names, dataset_3.outcome_names)

# Test from_joint_dataset.
sort_idcs = [3, 4, 5, 0, 1, 2] # X & Y will get sorted based on task feature.
for outcome_names_per_task in [None, {0: "x", 1: "y"}]:
joint_dataset = SupervisedDataset(
X=expected_X,
Y=expected_Y,
feature_names=["x0", "x1", "task"],
outcome_names=["z"],
)
mt_dataset = MultiTaskDataset.from_joint_dataset(
dataset=joint_dataset,
task_feature_index=-1,
target_task_value=0,
outcome_names_per_task=outcome_names_per_task,
)
self.assertEqual(len(mt_dataset.datasets), 2)
if outcome_names_per_task is None:
self.assertEqual(list(mt_dataset.datasets.keys()), ["z", "task_1"])
self.assertEqual(mt_dataset.target_outcome_name, "z")
else:
self.assertEqual(list(mt_dataset.datasets.keys()), ["x", "y"])
self.assertEqual(mt_dataset.target_outcome_name, "x")

self.assertTrue(torch.equal(mt_dataset.X, expected_X[sort_idcs]))
self.assertTrue(torch.equal(mt_dataset.Y, expected_Y[sort_idcs]))
self.assertTrue(
torch.equal(
mt_dataset.datasets[mt_dataset.target_outcome_name].Y, dataset_2.Y
)
)
self.assertIsNone(mt_dataset.Yvar)
with self.assertRaisesRegex(UnsupportedError, "more than one outcome"):
MultiTaskDataset.from_joint_dataset(
dataset=make_dataset(m=2),
task_feature_index=-1,
target_task_value=0,
)
