From 23ee87882c6f3887b0f7f26c674b7c107193f15a Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 27 Feb 2024 15:03:39 -0800 Subject: [PATCH] Remove deprecated args from base MCSampler Summary: These have been deprecated since 0.8.0, time to let them go. Differential Revision: D54280379 --- botorch/sampling/base.py | 36 +++++------------------------------- test/sampling/test_base.py | 14 +------------- 2 files changed, 6 insertions(+), 44 deletions(-) diff --git a/botorch/sampling/base.py b/botorch/sampling/base.py index becfdb4814..cf31e3d7d9 100644 --- a/botorch/sampling/base.py +++ b/botorch/sampling/base.py @@ -10,9 +10,8 @@ from __future__ import annotations -import warnings from abc import ABC, abstractmethod -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import torch from botorch.exceptions.errors import InputDataError @@ -49,7 +48,6 @@ def __init__( self, sample_shape: torch.Size, seed: Optional[int] = None, - **kwargs: Any, ) -> None: r"""Abstract base class for samplers. @@ -57,37 +55,13 @@ def __init__( sample_shape: The `sample_shape` of the samples to generate. The full shape of the samples is given by `posterior._extended_shape(sample_shape)`. seed: An optional seed to use for sampling. - **kwargs: Catch-all for deprecated kwargs. """ super().__init__() if not isinstance(sample_shape, torch.Size): - if isinstance(sample_shape, int): - sample_shape = torch.Size([sample_shape]) - warnings.warn( - "The first positional argument of samplers, `num_samples`, has " - "been deprecated and replaced with `sample_shape`, which expects " - "a `torch.Size` object.", - DeprecationWarning, - ) - else: - raise InputDataError( - "Expected `sample_shape` to be a `torch.Size` object, " - f"got {sample_shape}." - ) - for k, v in kwargs.items(): - if k == "resample": - if v is True: - raise RuntimeError(KWARG_ERR_MSG.format(k, "StochasticSampler")) - else: - warnings.warn(KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning) - elif k == "collapse_batch_dims": - if v is False: - raise RuntimeError(KWARG_ERR_MSG.format(k, "ForkedRNGSampler")) - else: - warnings.warn(KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning) - else: - raise RuntimeError(f"Recevied an unknown argument {k}: {v}.") - + raise InputDataError( + "Expected `sample_shape` to be a `torch.Size` object, " + f"got {sample_shape}." + ) self.sample_shape = sample_shape self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item() self.register_buffer("base_samples", None) diff --git a/test/sampling/test_base.py b/test/sampling/test_base.py index c4cd385043..a3373e3bd9 100644 --- a/test/sampling/test_base.py +++ b/test/sampling/test_base.py @@ -36,21 +36,9 @@ def test_init(self): # Default seed. sampler = NonAbstractSampler(sample_shape=torch.Size([4])) self.assertIsInstance(sampler.seed, int) - # Deprecated args & error handling. - with self.assertWarnsRegex(DeprecationWarning, "positional argument"): - NonAbstractSampler(4) + # Error handling. with self.assertRaisesRegex(InputDataError, "sample_shape"): NonAbstractSampler(4.5) - with self.assertWarnsRegex(DeprecationWarning, "resample"): - NonAbstractSampler(sample_shape=torch.Size([4]), resample=False) - with self.assertRaisesRegex(RuntimeError, "StochasticSampler"): - NonAbstractSampler(sample_shape=torch.Size([4]), resample=True) - with self.assertWarnsRegex(DeprecationWarning, "collapse_batch"): - NonAbstractSampler(sample_shape=torch.Size([4]), collapse_batch_dims=True) - with self.assertRaisesRegex(RuntimeError, "ForkedRNGSampler"): - NonAbstractSampler(sample_shape=torch.Size([4]), collapse_batch_dims=False) - with self.assertRaisesRegex(RuntimeError, "unknown argument"): - NonAbstractSampler(sample_shape=torch.Size([4]), dummy_arg=True) def test_batch_range(self): posterior = MockPosterior()