From 4ec8d2fd05ff6063da314220563841e7a8cac677 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Tue, 27 Feb 2024 15:37:51 -0800 Subject: [PATCH] Remove deprecated args from base MCSampler Summary: These have been deprecated since 0.8.0, time to let them go. Reviewed By: Balandat Differential Revision: D54280379 --- botorch/sampling/base.py | 36 +++---------------- test/sampling/test_base.py | 14 +------- .../Multi_objective_multi_fidelity_BO.ipynb | 2 -- ...k_averse_bo_with_input_perturbations.ipynb | 2 +- 4 files changed, 7 insertions(+), 47 deletions(-) diff --git a/botorch/sampling/base.py b/botorch/sampling/base.py index becfdb4814..cf31e3d7d9 100644 --- a/botorch/sampling/base.py +++ b/botorch/sampling/base.py @@ -10,9 +10,8 @@ from __future__ import annotations -import warnings from abc import ABC, abstractmethod -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import torch from botorch.exceptions.errors import InputDataError @@ -49,7 +48,6 @@ def __init__( self, sample_shape: torch.Size, seed: Optional[int] = None, - **kwargs: Any, ) -> None: r"""Abstract base class for samplers. @@ -57,37 +55,13 @@ def __init__( sample_shape: The `sample_shape` of the samples to generate. The full shape of the samples is given by `posterior._extended_shape(sample_shape)`. seed: An optional seed to use for sampling. - **kwargs: Catch-all for deprecated kwargs. """ super().__init__() if not isinstance(sample_shape, torch.Size): - if isinstance(sample_shape, int): - sample_shape = torch.Size([sample_shape]) - warnings.warn( - "The first positional argument of samplers, `num_samples`, has " - "been deprecated and replaced with `sample_shape`, which expects " - "a `torch.Size` object.", - DeprecationWarning, - ) - else: - raise InputDataError( - "Expected `sample_shape` to be a `torch.Size` object, " - f"got {sample_shape}." - ) - for k, v in kwargs.items(): - if k == "resample": - if v is True: - raise RuntimeError(KWARG_ERR_MSG.format(k, "StochasticSampler")) - else: - warnings.warn(KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning) - elif k == "collapse_batch_dims": - if v is False: - raise RuntimeError(KWARG_ERR_MSG.format(k, "ForkedRNGSampler")) - else: - warnings.warn(KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning) - else: - raise RuntimeError(f"Recevied an unknown argument {k}: {v}.") - + raise InputDataError( + "Expected `sample_shape` to be a `torch.Size` object, " + f"got {sample_shape}." + ) self.sample_shape = sample_shape self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item() self.register_buffer("base_samples", None) diff --git a/test/sampling/test_base.py b/test/sampling/test_base.py index c4cd385043..a3373e3bd9 100644 --- a/test/sampling/test_base.py +++ b/test/sampling/test_base.py @@ -36,21 +36,9 @@ def test_init(self): # Default seed. sampler = NonAbstractSampler(sample_shape=torch.Size([4])) self.assertIsInstance(sampler.seed, int) - # Deprecated args & error handling. - with self.assertWarnsRegex(DeprecationWarning, "positional argument"): - NonAbstractSampler(4) + # Error handling. with self.assertRaisesRegex(InputDataError, "sample_shape"): NonAbstractSampler(4.5) - with self.assertWarnsRegex(DeprecationWarning, "resample"): - NonAbstractSampler(sample_shape=torch.Size([4]), resample=False) - with self.assertRaisesRegex(RuntimeError, "StochasticSampler"): - NonAbstractSampler(sample_shape=torch.Size([4]), resample=True) - with self.assertWarnsRegex(DeprecationWarning, "collapse_batch"): - NonAbstractSampler(sample_shape=torch.Size([4]), collapse_batch_dims=True) - with self.assertRaisesRegex(RuntimeError, "ForkedRNGSampler"): - NonAbstractSampler(sample_shape=torch.Size([4]), collapse_batch_dims=False) - with self.assertRaisesRegex(RuntimeError, "unknown argument"): - NonAbstractSampler(sample_shape=torch.Size([4]), dummy_arg=True) def test_batch_range(self): posterior = MockPosterior() diff --git a/tutorials/Multi_objective_multi_fidelity_BO.ipynb b/tutorials/Multi_objective_multi_fidelity_BO.ipynb index 8829c6d79d..82f55c9c7c 100644 --- a/tutorials/Multi_objective_multi_fidelity_BO.ipynb +++ b/tutorials/Multi_objective_multi_fidelity_BO.ipynb @@ -501,8 +501,6 @@ " ref_point=ref_point,\n", " sampler=SobolQMCNormalSampler(\n", " sample_shape=torch.Size([NUM_INNER_MC_SAMPLES]),\n", - " resample=False,\n", - " collapse_batch_dims=True,\n", " ),\n", " use_posterior_mean=True,\n", " ),\n", diff --git a/tutorials/risk_averse_bo_with_input_perturbations.ipynb b/tutorials/risk_averse_bo_with_input_perturbations.ipynb index 89f7498182..6346adc3b2 100644 --- a/tutorials/risk_averse_bo_with_input_perturbations.ipynb +++ b/tutorials/risk_averse_bo_with_input_perturbations.ipynb @@ -224,7 +224,7 @@ " acqf = qNoisyExpectedImprovement(\n", " model=model,\n", " X_baseline=train_X,\n", - " sampler=SobolQMCNormalSampler(128),\n", + " sampler=SobolQMCNormalSampler(sample_shape=torch.Size([128])),\n", " objective=risk_measure,\n", " prune_baseline=True,\n", " )\n",