Skip to content

Commit

Permalink
Remove deprecated args from base MCSampler
Browse files Browse the repository at this point in the history
Summary: These have been deprecated since 0.8.0, time to let them go.

Differential Revision: D54280379
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 27, 2024
1 parent b6b8f9c commit 23ee878
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 44 deletions.
36 changes: 5 additions & 31 deletions botorch/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
from typing import Optional, Tuple

import torch
from botorch.exceptions.errors import InputDataError
Expand Down Expand Up @@ -49,45 +48,20 @@ def __init__(
self,
sample_shape: torch.Size,
seed: Optional[int] = None,
**kwargs: Any,
) -> None:
r"""Abstract base class for samplers.
Args:
sample_shape: The `sample_shape` of the samples to generate. The full shape
of the samples is given by `posterior._extended_shape(sample_shape)`.
seed: An optional seed to use for sampling.
**kwargs: Catch-all for deprecated kwargs.
"""
super().__init__()
if not isinstance(sample_shape, torch.Size):
if isinstance(sample_shape, int):
sample_shape = torch.Size([sample_shape])
warnings.warn(
"The first positional argument of samplers, `num_samples`, has "
"been deprecated and replaced with `sample_shape`, which expects "
"a `torch.Size` object.",
DeprecationWarning,
)
else:
raise InputDataError(
"Expected `sample_shape` to be a `torch.Size` object, "
f"got {sample_shape}."
)
for k, v in kwargs.items():
if k == "resample":
if v is True:
raise RuntimeError(KWARG_ERR_MSG.format(k, "StochasticSampler"))
else:
warnings.warn(KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning)
elif k == "collapse_batch_dims":
if v is False:
raise RuntimeError(KWARG_ERR_MSG.format(k, "ForkedRNGSampler"))
else:
warnings.warn(KWARGS_DEPRECATED_MSG.format(k), DeprecationWarning)
else:
raise RuntimeError(f"Recevied an unknown argument {k}: {v}.")

raise InputDataError(
"Expected `sample_shape` to be a `torch.Size` object, "
f"got {sample_shape}."
)
self.sample_shape = sample_shape
self.seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
self.register_buffer("base_samples", None)
Expand Down
14 changes: 1 addition & 13 deletions test/sampling/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,9 @@ def test_init(self):
# Default seed.
sampler = NonAbstractSampler(sample_shape=torch.Size([4]))
self.assertIsInstance(sampler.seed, int)
# Deprecated args & error handling.
with self.assertWarnsRegex(DeprecationWarning, "positional argument"):
NonAbstractSampler(4)
# Error handling.
with self.assertRaisesRegex(InputDataError, "sample_shape"):
NonAbstractSampler(4.5)
with self.assertWarnsRegex(DeprecationWarning, "resample"):
NonAbstractSampler(sample_shape=torch.Size([4]), resample=False)
with self.assertRaisesRegex(RuntimeError, "StochasticSampler"):
NonAbstractSampler(sample_shape=torch.Size([4]), resample=True)
with self.assertWarnsRegex(DeprecationWarning, "collapse_batch"):
NonAbstractSampler(sample_shape=torch.Size([4]), collapse_batch_dims=True)
with self.assertRaisesRegex(RuntimeError, "ForkedRNGSampler"):
NonAbstractSampler(sample_shape=torch.Size([4]), collapse_batch_dims=False)
with self.assertRaisesRegex(RuntimeError, "unknown argument"):
NonAbstractSampler(sample_shape=torch.Size([4]), dummy_arg=True)

def test_batch_range(self):
posterior = MockPosterior()
Expand Down

0 comments on commit 23ee878

Please sign in to comment.