diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index c33a040d32..bfe12aa047 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -37,28 +37,29 @@ def __call__(cls, *args: Any, **kwargs: Any): r"""Converts Tensor-valued fields to DenseContainer under the assumption that said fields house collections of feature vectors.""" hints = get_type_hints(cls) - f_iter = filter(lambda f: f.init, fields(cls)) + fields_iter = (item for item in fields(cls) if item.init is not None) f_dict = {} - for obj, f in chain( - zip(args, f_iter), ((kwargs.pop(f.name, MISSING), f) for f in f_iter) + for value, field in chain( + zip(args, fields_iter), + ((kwargs.pop(field.name, MISSING), field) for field in fields_iter), ): - if obj is MISSING: - if f.default is not MISSING: - obj = f.default - elif f.default_factory is not MISSING: - obj = f.default_factory() + if value is MISSING: + if field.default is not MISSING: + value = field.default + elif field.default_factory is not MISSING: + value = field.default_factory() else: - raise RuntimeError(f"Missing required field `{f.name}`.") + raise RuntimeError(f"Missing required field `{field.name}`.") - if issubclass(hints[f.name], BotorchContainer): - if isinstance(obj, Tensor): - obj = DenseContainer(obj, event_shape=obj.shape[-1:]) - elif not isinstance(obj, BotorchContainer): + if issubclass(hints[field.name], BotorchContainer): + if isinstance(value, Tensor): + value = DenseContainer(value, event_shape=value.shape[-1:]) + elif not isinstance(value, BotorchContainer): raise TypeError( - f"Expected for field `{f.name}` " - f"but was {type(obj)}." + "Expected for field " + f"`{field.name}` but was {type(value)}." ) - f_dict[f.name] = obj + f_dict[field.name] = value return super().__call__(**f_dict, **kwargs) diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py index fd550a2846..02c9234f55 100644 --- a/test/utils/test_datasets.py +++ b/test/utils/test_datasets.py @@ -30,14 +30,15 @@ def test_base(self): def test_supervised_meta(self): X = rand(3, 2) Y = rand(3, 1) - A = DenseContainer(rand(3, 5), event_shape=Size([5])) + t = rand(3, 5) + A = DenseContainer(t, event_shape=Size([5])) B = rand(2, 1) SupervisedDatasetWithDefaults = make_dataclass( cls_name="SupervisedDatasetWithDefaults", bases=(SupervisedDataset,), fields=[ - ("default", DenseContainer, field(default=A)), + ("default", DenseContainer, field(default=t)), ("factory", DenseContainer, field(default_factory=lambda: A)), ("other", Tensor, field(default_factory=lambda: B)), ], @@ -55,6 +56,7 @@ def test_supervised_meta(self): # Check handling of default values and factories dataset = SupervisedDatasetWithDefaults(X=X, Y=Y) + self.assertIsInstance(dataset.default, DenseContainer) self.assertEqual(dataset.default, A) self.assertEqual(dataset.factory, A) self.assertTrue(dataset.other is B)