More removal of unused ** arguments (#2336)
Summary:
Pull Request resolved: #2336

I am auditing cases in which BoTorch functions accept `**` arguments and then don't use them. In many cases it was easier to fix the issue than to write it up; these are the easy cases.

Note on inheritance: type-checkers require that if a method accepts an argument, every method that overrides it must accept that argument as well.
* I removed `**kwargs` from abstract methods to remove the implication that their subclasses must also support `**kwargs` even when they don't use them.
* In cases where the base method accepts `**kwargs` and is defined in GPyTorch, I chose an "inconsistent override" of the GPyTorch signature rather than keeping ignored `**kwargs` in BoTorch. This was the case for overrides of `Module.forward` (in an ExactGP), `Kernel.forward`, and `Likelihood.forward`.
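
To make the constraint concrete, here is a minimal sketch (hypothetical class names, not BoTorch code) of why declaring `**kwargs` on an abstract method forces every override to accept `**kwargs` as well:

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class SamplerBase(ABC):
    @abstractmethod
    def postprocess(self, samples: Dict[str, float], **kwargs: Any) -> Dict[str, float]:
        """Declaring **kwargs here promises that every override accepts
        arbitrary keyword arguments, even if no subclass ever uses them."""
        ...  # pragma: no cover


class Sampler(SamplerBase):
    # A type-checker flags this as an inconsistent override: the narrower
    # signature drops the **kwargs that the base method promised to accept.
    def postprocess(self, samples: Dict[str, float]) -> Dict[str, float]:
        return samples
```

Dropping `**kwargs` from the abstract method removes that constraint; when the base method lives in GPyTorch and cannot be edited here, the override is narrowed anyway and the resulting inconsistency is accepted.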

Changes:
* Small correctness fix in error-catching (see the sketch after this list)
* Some small typing fixes
* Removed `**kwargs` in many cases.
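
On the error-catching fix: `except RuntimeError or InputDataError` evaluates the `or` expression first, and since exception classes are truthy it reduces to `except RuntimeError`, so `InputDataError` was never actually caught. A small sketch of the behavior (assuming `InputDataError` is importable from `botorch.exceptions`):

```python
from botorch.exceptions import InputDataError

# `RuntimeError or InputDataError` is an ordinary boolean expression; classes
# are truthy, so it evaluates to RuntimeError before `except` ever sees it.
assert (RuntimeError or InputDataError) is RuntimeError

# The corrected form passes a tuple of exception types, which catches either.
try:
    raise InputDataError("bad input")
except (RuntimeError, InputDataError):
    print("caught InputDataError as intended")
```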

Reviewed By: saitcakmak

Differential Revision: D56849296

fbshipit-source-id: b059148e018608fac9691ee255cd73118d2e52b1
esantorella authored and facebook-github-bot committed May 10, 2024
1 parent 6600655 commit faf8d8d
Showing 14 changed files with 92 additions and 62 deletions.
@@ -891,7 +891,7 @@ def _safe_update_omega(
check_no_nans(omega_f_nat_cov_new)
return omega_f_nat_mean_new, omega_f_nat_cov_new

except RuntimeError or InputDataError:
except (RuntimeError, InputDataError):
return omega_f_nat_mean, omega_f_nat_cov


@@ -1070,7 +1070,7 @@ def _update_damping_when_converged(
damping_factor: Tensor,
iteration: Tensor,
threshold: float = 1e-3,
) -> Tensor:
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Set the damping factor to 0 once converged. Convergence is determined by the
relative change in the entries of the mean and covariance matrix.
@@ -1087,8 +1087,10 @@ def _update_damping_when_converged(
damping_factor: A `batch_shape`-dim Tensor containing the damping factor.
Returns:
A `batch_shape x param_shape`-dim Tensor containing the updated damping
- A `batch_shape x param_shape`-dim Tensor containing the updated damping
factor.
- Difference between `mean_new` and `mean_old`
- Difference between `cov_new` and `cov_old`
"""
df = damping_factor.clone()
delta_mean = mean_new - mean_old
3 changes: 2 additions & 1 deletion botorch/models/fully_bayesian.py
@@ -133,7 +133,8 @@ def sample(self) -> None:

@abstractmethod
def postprocess_mcmc_samples(
self, mcmc_samples: Dict[str, Tensor], **kwargs: Any
self,
mcmc_samples: Dict[str, Tensor],
) -> Dict[str, Tensor]:
"""Post-process the final MCMC samples."""
pass # pragma: no cover
15 changes: 8 additions & 7 deletions botorch/models/gpytorch.py
@@ -204,7 +204,9 @@ def posterior(
return posterior_transform(posterior)
return posterior

def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
def condition_on_observations(
self, X: Tensor, Y: Tensor, noise: Optional[Tensor] = None, **kwargs: Any
) -> Model:
r"""Condition the model on new observations.
Args:
@@ -219,6 +221,9 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
standard broadcasting semantics. If `Y` has fewer batch dimensions
than `X`, it is assumed that the missing batch dimensions are
the same for all `Y`.
noise: If not `None`, a tensor of the same shape as `Y` representing
the associated noise variance.
kwargs: Passed to `self.get_fantasy_model`.
Returns:
A `Model` object of the same type, representing the original model
@@ -233,14 +238,14 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
>>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
>>> model = model.condition_on_observations(X=new_X, Y=new_Y)
"""
Yvar = kwargs.pop("noise", None)
Yvar = noise

if hasattr(self, "outcome_transform"):
# pass the transformed data to get_fantasy_model below
# (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
if not isinstance(self, BatchedMultiOutputGPyTorchModel):
# `noise` is assumed to already be outcome-transformed.
Y, _ = self.outcome_transform(Y, Yvar)
Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar)
# validate using strict=False, since we cannot tell if Y has an explicit
# output dimension
self._validate_tensor_args(X=X, Y=Y, Yvar=Yvar, strict=False)
@@ -356,7 +361,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> Union[GPyTorchPosterior, TransformedPosterior]:
r"""Computes the posterior over model outputs at the provided points.
@@ -609,7 +613,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> Union[GPyTorchPosterior, PosteriorList]:
r"""Computes the posterior over model outputs at the provided points.
If any model returns a MultitaskMultivariateNormal posterior, then that
@@ -661,7 +664,6 @@ def posterior(
X=X,
output_indices=output_indices,
observation_noise=observation_noise,
**kwargs,
)
if not returns_untransformed:
mvns = [p.distribution for p in posterior.posteriors]
@@ -756,7 +758,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> Union[GPyTorchPosterior, TransformedPosterior]:
r"""Computes the posterior over model outputs at the provided points.
13 changes: 7 additions & 6 deletions botorch/models/higher_order_gp.py
@@ -61,7 +61,7 @@ class FlattenedStandardize(Standardize):
def __init__(
self,
output_shape: torch.Size,
batch_shape: torch.Size = None,
batch_shape: Optional[torch.Size] = None,
min_stdv: float = 1e-8,
):
r"""
@@ -385,7 +385,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
return super().get_fantasy_model(inputs, reshaped_targets, **kwargs)

def condition_on_observations(
self, X: Tensor, Y: Tensor, **kwargs: Any
self, X: Tensor, Y: Tensor, noise: Optional[torch.Tensor] = None, **kwargs: Any
) -> HigherOrderGP:
r"""Condition the model on new observations.
@@ -401,17 +401,19 @@ def condition_on_observations(
standard broadcasting semantics. If `Y` has fewer batch dimensions
than `X`, it is assumed that the missing batch dimensions are
the same for all `Y`.
noise: If not None, a tensor of the same shape as `Y` representing
the noise variance associated with each observation.
kwargs: Passed to `condition_on_observations`.
Returns:
A `BatchedMultiOutputGPyTorchModel` object of the same type with
`n + n'` training examples, representing the original model
conditioned on the new observations `(X, Y)` (and possibly noise
observations passed in via kwargs).
"""
noise = kwargs.get("noise")
if hasattr(self, "outcome_transform"):
# we need to apply transforms before shifting batch indices around
Y, noise = self.outcome_transform(Y, noise)
Y, noise = self.outcome_transform(Y=Y, Yvar=noise)
self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)

# we don't need to do un-squeezing because Y already is batched
@@ -420,7 +422,7 @@ def condition_on_observations(
# kwargs.update({"noise": noise})
fantasy_model = super(
BatchedMultiOutputGPyTorchModel, self
).condition_on_observations(X=X, Y=Y, **kwargs)
).condition_on_observations(X=X, Y=Y, noise=noise, **kwargs)
fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
: (-1 if self._num_outputs == 1 else -2)
]
@@ -433,7 +435,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> GPyTorchPosterior:
self.eval() # make sure we're calling a posterior

1 change: 0 additions & 1 deletion botorch/models/kernels/categorical.py
@@ -28,7 +28,6 @@ def forward(
x2: Tensor,
diag: bool = False,
last_dim_is_batch: bool = False,
**kwargs,
) -> Tensor:
delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
dists = delta / self.lengthscale.unsqueeze(-2)
4 changes: 2 additions & 2 deletions botorch/models/likelihoods/pairwise.py
@@ -12,7 +12,7 @@

import math
from abc import ABC, abstractmethod
from typing import Any, Tuple
from typing import Tuple

import torch
from botorch.utils.probability.utils import (
@@ -41,7 +41,7 @@ def __init__(self, max_plate_nesting: int = 1):
"""
super().__init__(max_plate_nesting)

def forward(self, utility: Tensor, D: Tensor, **kwargs: Any) -> Bernoulli:
def forward(self, utility: Tensor, D: Tensor) -> Bernoulli:
"""Given the difference in (estimated) utility util_diff = f(v) - f(u),
return a Bernoulli distribution object representing the likelihood of
the user prefer v over u.
7 changes: 3 additions & 4 deletions botorch/models/model.py
@@ -95,7 +95,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> Posterior:
r"""Computes the posterior over model outputs at the provided points.
@@ -301,7 +300,9 @@ def __init__(self, args):

@abstractmethod
def condition_on_observations(
self: TFantasizeMixin, X: Tensor, Y: Tensor, **kwargs: Any
self: TFantasizeMixin,
X: Tensor,
Y: Tensor,
) -> TFantasizeMixin:
"""
Classes that inherit from `FantasizeMixin` must implement
@@ -314,7 +315,6 @@ def posterior(
X: Tensor,
*args,
observation_noise: bool = False,
**kwargs: Any,
) -> Posterior:
"""
Classes that inherit from `FantasizeMixin` must implement
@@ -474,7 +474,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[Callable[[PosteriorList], Posterior]] = None,
**kwargs: Any,
) -> Posterior:
r"""Computes the posterior over model outputs at the provided points.
1 change: 0 additions & 1 deletion botorch/models/multitask.py
@@ -572,7 +572,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: Union[bool, Tensor] = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> MultitaskGPPosterior:
self.eval()

5 changes: 2 additions & 3 deletions botorch/models/pairwise_gp.py
@@ -1070,7 +1070,6 @@ def posterior(
output_indices: Optional[List[int]] = None,
observation_noise: bool = False,
posterior_transform: Optional[PosteriorTransform] = None,
**kwargs: Any,
) -> Posterior:
r"""Computes the posterior over model outputs at the provided points.
@@ -1100,11 +1099,11 @@ def posterior(
return posterior_transform(posterior)
return posterior

def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
def condition_on_observations(self, X: Tensor, Y: Tensor) -> Model:
r"""Condition the model on new observations.
Note that unlike other BoTorch models, PairwiseGP requires Y to be
pairwise comparisons
pairwise comparisons.
Args:
X: A `batch_shape x n x d` dimension tensor X
14 changes: 7 additions & 7 deletions botorch/optim/closures/model_closures.py
@@ -104,8 +104,8 @@ def get_loss_closure_with_grads(
@GetLossClosureWithGrads.register(object, object, object, object)
def _get_loss_closure_with_grads_fallback(
mll: MarginalLogLikelihood,
_: object,
__: object,
_likelihood_type: object,
_model_type: object,
data_loader: Optional[DataLoader],
parameters: Dict[str, Tensor],
reducer: Callable[[Tensor], Tensor] = Tensor.sum,
@@ -127,8 +127,8 @@ def _get_loss_closure_with_grads_fallback(
@GetLossClosure.register(MarginalLogLikelihood, object, object, DataLoader)
def _get_loss_closure_fallback_external(
mll: MarginalLogLikelihood,
_: object,
__: object,
_likelihood_type: object,
_model_type: object,
data_loader: DataLoader,
**ignore: Any,
) -> Callable[[], Tensor]:
@@ -153,7 +153,7 @@ def closure(**kwargs: Any) -> Tensor:

@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_fallback_internal(
mll: MarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
mll: MarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
) -> Callable[[], Tensor]:
r"""Fallback loss closure with internally managed data."""

@@ -167,7 +167,7 @@ def closure(**kwargs: Any) -> Tensor:

@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_exact_internal(
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
) -> Callable[[], Tensor]:
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""

@@ -183,7 +183,7 @@ def closure(**kwargs: Any) -> Tensor:

@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_sum_internal(
mll: SumMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
mll: SumMarginalLogLikelihood, _: object, __: object, ___: None, **ignore: Any
) -> Callable[[], Tensor]:
r"""SumMarginalLogLikelihood loss closure with internally managed data."""

