Skip to content

Commit

Permalink
Implement MultiTaskDataset.__eq__ (pytorch#2594)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2594

Previously, this would fall back to `SupervisedDataset.__eq__`, which uses `self.X` for comparison. If the underlying datasets have heterogeneous feature sets, `self.X` errors out.

The new `MultiTaskDataset.__eq__` resolves this issue by comparing the underlying datasets one by one.

Reviewed By: Balandat

Differential Revision: D64911436

fbshipit-source-id: ecb7343d86c4526d06f61725c1663e50f1f1902f
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 24, 2024
1 parent e7539db commit 9d37e90
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
10 changes: 9 additions & 1 deletion botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,14 @@ def get_dataset_without_task_feature(self, outcome_name: str) -> SupervisedDatas
outcome_names=[outcome_name],
)

def __eq__(self, other: Any) -> bool:
    """Check whether ``other`` is equal to this dataset.

    The comparison is made on ``datasets``, ``target_outcome_name`` and
    ``task_feature_index`` only; ``self.X`` is never accessed here.

    Args:
        other: The object to compare against.

    Returns:
        True if ``other`` has exactly the same type as ``self`` and all
        three attributes compare equal; False otherwise.
    """
    # Strict type identity: an instance of a subclass is not considered equal.
    if type(other) is not type(self):
        return False
    # Checks run in order and stop at the first mismatch, so the later
    # attributes of ``other`` are only read once the earlier ones match.
    if not (self.datasets == other.datasets):
        return False
    if not (self.target_outcome_name == other.target_outcome_name):
        return False
    return self.task_feature_index == other.task_feature_index


class ContextualDataset(SupervisedDataset):
"""This is a contextual dataset that is constructed from either a single
Expand Down Expand Up @@ -548,7 +556,7 @@ def Y(self) -> Tensor:
return torch.cat(Ys, dim=-1)

@property
def Yvar(self) -> Tensor:
def Yvar(self) -> Tensor | None:
"""Concatenates the Yvars from the child datasets to create the Y expected
by LCEM model if there are multiple datasets; Or return the Yvar expected
by LCEA model if there is only one dataset.
Expand Down
11 changes: 11 additions & 0 deletions test/utils/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,17 @@ def test_multi_task(self):
):
mt_dataset.X

# Test equality.
self.assertEqual(mt_dataset, mt_dataset)
self.assertNotEqual(mt_dataset, dataset_5)
self.assertNotEqual(
mt_dataset, MultiTaskDataset(datasets=[dataset_1], target_outcome_name="y")
)
self.assertNotEqual(
mt_dataset,
MultiTaskDataset(datasets=[dataset_1, dataset_5], target_outcome_name="z"),
)

def test_contextual_datasets(self):
num_contexts = 3
feature_names = [f"x_c{i}" for i in range(num_contexts)]
Expand Down

0 comments on commit 9d37e90

Please sign in to comment.