Skip to content

Commit

Permalink
Python 3.11 compliance for dataclasses: Don't use mutable defaults (#…
Browse files Browse the repository at this point in the history
…1927)

Summary:
## Motivation

See error here: https://github.com/pytorch/botorch/actions/runs/5508336054/jobs/10039481686?pr=1924
* Tested with a tensor instead of a `DenseContainer` since we needed something non-mutable
* Added a test that the tensor is properly cast to a `DenseContainer`
* made code more readable (for me anyway)

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1927

Test Plan:
Units

## Related PRs

Unblocks #1924

Reviewed By: saitcakmak

Differential Revision: D47341881

Pulled By: esantorella

fbshipit-source-id: 64c497504bffde8cc227dff8eaaa85e17aa36b95
  • Loading branch information
esantorella authored and facebook-github-bot committed Jul 10, 2023
1 parent 1d077c9 commit 56198cd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
33 changes: 17 additions & 16 deletions botorch/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,29 @@ def __call__(cls, *args: Any, **kwargs: Any):
r"""Converts Tensor-valued fields to DenseContainer under the assumption
that said fields house collections of feature vectors."""
hints = get_type_hints(cls)
f_iter = filter(lambda f: f.init, fields(cls))
fields_iter = (item for item in fields(cls) if item.init is not None)
f_dict = {}
for obj, f in chain(
zip(args, f_iter), ((kwargs.pop(f.name, MISSING), f) for f in f_iter)
for value, field in chain(
zip(args, fields_iter),
((kwargs.pop(field.name, MISSING), field) for field in fields_iter),
):
if obj is MISSING:
if f.default is not MISSING:
obj = f.default
elif f.default_factory is not MISSING:
obj = f.default_factory()
if value is MISSING:
if field.default is not MISSING:
value = field.default
elif field.default_factory is not MISSING:
value = field.default_factory()
else:
raise RuntimeError(f"Missing required field `{f.name}`.")
raise RuntimeError(f"Missing required field `{field.name}`.")

if issubclass(hints[f.name], BotorchContainer):
if isinstance(obj, Tensor):
obj = DenseContainer(obj, event_shape=obj.shape[-1:])
elif not isinstance(obj, BotorchContainer):
if issubclass(hints[field.name], BotorchContainer):
if isinstance(value, Tensor):
value = DenseContainer(value, event_shape=value.shape[-1:])
elif not isinstance(value, BotorchContainer):
raise TypeError(
f"Expected <BotorchContainer | Tensor> for field `{f.name}` "
f"but was {type(obj)}."
"Expected <BotorchContainer | Tensor> for field "
f"`{field.name}` but was {type(value)}."
)
f_dict[f.name] = obj
f_dict[field.name] = value

return super().__call__(**f_dict, **kwargs)

Expand Down
6 changes: 4 additions & 2 deletions test/utils/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ def test_base(self):
def test_supervised_meta(self):
X = rand(3, 2)
Y = rand(3, 1)
A = DenseContainer(rand(3, 5), event_shape=Size([5]))
t = rand(3, 5)
A = DenseContainer(t, event_shape=Size([5]))
B = rand(2, 1)

SupervisedDatasetWithDefaults = make_dataclass(
cls_name="SupervisedDatasetWithDefaults",
bases=(SupervisedDataset,),
fields=[
("default", DenseContainer, field(default=A)),
("default", DenseContainer, field(default=t)),
("factory", DenseContainer, field(default_factory=lambda: A)),
("other", Tensor, field(default_factory=lambda: B)),
],
Expand All @@ -55,6 +56,7 @@ def test_supervised_meta(self):

# Check handling of default values and factories
dataset = SupervisedDatasetWithDefaults(X=X, Y=Y)
self.assertIsInstance(dataset.default, DenseContainer)
self.assertEqual(dataset.default, A)
self.assertEqual(dataset.factory, A)
self.assertTrue(dataset.other is B)
Expand Down

0 comments on commit 56198cd

Please sign in to comment.