diff --git a/aepsych/benchmark/problem.py b/aepsych/benchmark/problem.py index 2aea5ec45..966431c19 100644 --- a/aepsych/benchmark/problem.py +++ b/aepsych/benchmark/problem.py @@ -9,7 +9,7 @@ import numpy as np import torch -from aepsych.models.model_protocol import ModelProtocol +from aepsych.models.base import AEPsychModelMixin from aepsych.models.utils import p_below_threshold from aepsych.strategy import SequentialStrategy from aepsych.utils import make_scaled_sobol @@ -78,11 +78,11 @@ def sample_y( """ return bernoulli.rvs(self.p(x)) - def f_hat(self, model: ModelProtocol) -> torch.Tensor: + def f_hat(self, model: AEPsychModelMixin) -> torch.Tensor: """Generate mean predictions from the model over the evaluation grid. Args: - model (TensoModelProtocolr): Model to evaluate. + model (AEPsychModelMixin): Model to evaluate. Returns: torch.Tensor: Posterior mean from underlying model over the evaluation grid. @@ -109,11 +109,11 @@ def p_true(self) -> torch.Tensor: normal_dist = torch.distributions.Normal(0, 1) return normal_dist.cdf(self.f_true) - def p_hat(self, model: ModelProtocol) -> torch.Tensor: + def p_hat(self, model: AEPsychModelMixin) -> torch.Tensor: """Generate mean predictions from the model over the evaluation grid. Args: - model (TensoModelProtocolr): Model to evaluate. + model (AEPsychModelMixin): Model to evaluate. Returns: torch.Tensor: Posterior mean from underlying model over the evaluation grid. @@ -171,9 +171,9 @@ def evaluate( # eval in samp-based expectation over posterior instead of just mean fsamps = model.sample(self.eval_grid, num_samples=1000) try: - psamps = ( - model.sample(self.eval_grid, num_samples=1000, probability_space=True) # type: ignore - ) + psamps = model.sample( + self.eval_grid, num_samples=1000, probability_space=True + ) # type: ignore except ( TypeError ): # vanilla models don't have proba_space samps, TODO maybe we should add them diff --git a/aepsych/generators/acqf_grid_search_generator.py b/aepsych/generators/acqf_grid_search_generator.py index eef1c0dba..446d9033b 100644 --- a/aepsych/generators/acqf_grid_search_generator.py +++ b/aepsych/generators/acqf_grid_search_generator.py @@ -10,12 +10,11 @@ import numpy as np import torch -from aepsych.models.model_protocol import ModelProtocol +from aepsych.generators.grid_eval_acqf_generator import GridEvalAcqfGenerator +from aepsych.models.base import AEPsychModelMixin from aepsych.utils_logging import getLogger from numpy.random import choice -from .grid_eval_acqf_generator import GridEvalAcqfGenerator - logger = getLogger() @@ -25,7 +24,7 @@ class AcqfGridSearchGenerator(GridEvalAcqfGenerator): def _gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **gen_options, ) -> torch.Tensor: @@ -34,7 +33,7 @@ def _gen( Args: num_points (int): The number of points to query. - model (ModelProtocol): The fitted model used to evaluate the acquisition function. + model (AEPsychModelMixin): The fitted model used to evaluate the acquisition function. fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values. gen_options (dict): Additional options for generating points, including: - "seed": Random seed for reproducibility. diff --git a/aepsych/generators/acqf_thompson_sampler_generator.py b/aepsych/generators/acqf_thompson_sampler_generator.py index 5dbc4afa5..7c4ae35d6 100644 --- a/aepsych/generators/acqf_thompson_sampler_generator.py +++ b/aepsych/generators/acqf_thompson_sampler_generator.py @@ -10,7 +10,7 @@ import numpy as np import torch -from aepsych.models.model_protocol import ModelProtocol +from aepsych.models.base import AEPsychModelMixin from aepsych.utils_logging import getLogger from numpy.random import choice @@ -25,7 +25,7 @@ class AcqfThompsonSamplerGenerator(GridEvalAcqfGenerator): def _gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **gen_options, ) -> torch.Tensor: @@ -34,7 +34,7 @@ def _gen( Args: num_points (int): The number of points to query. - model (ModelProtocol): The fitted model used to evaluate the acquisition function. + model (AEPsychModelMixin): The fitted model used to evaluate the acquisition function. fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values. gen_options (dict): Additional options for generating points, including: - "seed": Random seed for reproducibility. diff --git a/aepsych/generators/base.py b/aepsych/generators/base.py index 8bbb2a5dc..9aeb29984 100644 --- a/aepsych/generators/base.py +++ b/aepsych/generators/base.py @@ -21,8 +21,6 @@ ) from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption -from ..models.model_protocol import ModelProtocol - AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychModelMixin) @@ -166,12 +164,14 @@ def _get_acqf_options( return extra_acqf_args - def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction: + def _instantiate_acquisition_fn( + self, model: AEPsychModelMixin + ) -> AcquisitionFunction: """ Instantiates the acquisition function with the specified model and additional arguments. Args: - model (ModelProtocol): The model to use with the acquisition function. + model (AEPsychModelMixin): The model to use with the acquisition function. Returns: AcquisitionFunction: Configured acquisition function. @@ -193,6 +193,8 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFuncti self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device) if self.acqf in self.baseline_requiring_acqfs: + if model.train_inputs is None: + raise ValueError(f"model needs data as a baseline for {self.acqf}") return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs) else: return self.acqf(model=model, **self.acqf_kwargs) diff --git a/aepsych/generators/epsilon_greedy_generator.py b/aepsych/generators/epsilon_greedy_generator.py index a35b9d95c..2a26a4b6b 100644 --- a/aepsych/generators/epsilon_greedy_generator.py +++ b/aepsych/generators/epsilon_greedy_generator.py @@ -9,8 +9,8 @@ import numpy as np import torch +from aepsych.models.base import AEPsychModelMixin -from ..models.model_protocol import ModelProtocol from .base import AEPsychGenerator from .optimize_acqf_generator import OptimizeAcqfGenerator @@ -65,7 +65,7 @@ def get_config_options( def gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **kwargs, ) -> torch.Tensor: @@ -73,7 +73,7 @@ def gen( Args: num_points (int): Number of points to query. - model (ModelProtocol): Model to use for generating points. + model (AEPsychModelMixin): Model to use for generating points. fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values. **kwargs: Passed to subgenerator if not exploring """ diff --git a/aepsych/generators/grid_eval_acqf_generator.py b/aepsych/generators/grid_eval_acqf_generator.py index 16a1aadad..215187f51 100644 --- a/aepsych/generators/grid_eval_acqf_generator.py +++ b/aepsych/generators/grid_eval_acqf_generator.py @@ -11,7 +11,7 @@ from aepsych.config import Config from aepsych.generators.base import AcqfGenerator, AEPsychGenerator from aepsych.generators.sobol_generator import SobolGenerator -from aepsych.models.model_protocol import ModelProtocol +from aepsych.models.base import AEPsychModelMixin from aepsych.utils_logging import getLogger from botorch.acquisition import AcquisitionFunction @@ -53,14 +53,14 @@ def __init__( def gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **gen_options, ) -> torch.Tensor: """Query next point(s) to run by optimizing the acquisition function. Args: num_points (int): Number of points to query. - model (ModelProtocol): Fitted model of the data. + model (AEPsychModelMixin): Fitted model of the data. Returns: torch.Tensor: Next set of point(s) to evaluate, [num_points x dim]. """ @@ -89,7 +89,7 @@ def gen( def _gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **gen_options, ) -> torch.Tensor: @@ -98,7 +98,7 @@ def _gen( def _eval_acqf( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **gen_options, ) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/aepsych/generators/optimize_acqf_generator.py b/aepsych/generators/optimize_acqf_generator.py index fd708b652..6d19897e1 100644 --- a/aepsych/generators/optimize_acqf_generator.py +++ b/aepsych/generators/optimize_acqf_generator.py @@ -14,7 +14,7 @@ from aepsych.acquisition.lookahead import LookaheadAcquisitionFunction from aepsych.config import Config from aepsych.generators.base import AcqfGenerator -from aepsych.models.model_protocol import ModelProtocol +from aepsych.models.base import AEPsychModelMixin from aepsych.utils_logging import getLogger from botorch.acquisition import AcquisitionFunction from botorch.optim import optimize_acqf @@ -60,14 +60,14 @@ def __init__( def gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, fixed_features: Optional[Dict[int, float]] = None, **gen_options, ) -> torch.Tensor: """Query next point(s) to run by optimizing the acquisition function. Args: num_points (int): Number of points to query. - model (ModelProtocol): Fitted model of the data. + model (AEPsychModelMixin): Fitted model of the data. fixed_features (Dict[int, float], optional): The values where the specified parameters should be at when generating. Should be a dictionary where the keys are the indices of the parameters to fix and the values are the @@ -116,7 +116,7 @@ def gen( def _gen( self, num_points: int, - model: ModelProtocol, + model: AEPsychModelMixin, acqf: AcquisitionFunction, fixed_features: Optional[Dict[int, float]] = None, **gen_options: Dict[str, Any], @@ -126,7 +126,7 @@ def _gen( Args: num_points (int): Number of points to query. - model (ModelProtocol): Fitted model of the data. + model (AEPsychModelMixin): Fitted model of the data. acqf (AcquisitionFunction): Acquisition function. fixed_features (Dict[int, float], optional): The values where the specified parameters should be at when generating. Should be a dictionary where diff --git a/aepsych/models/base.py b/aepsych/models/base.py index e0b353fcf..1a62f0fa6 100644 --- a/aepsych/models/base.py +++ b/aepsych/models/base.py @@ -28,9 +28,21 @@ class AEPsychModelMixin(GPyTorchModel, ConfigurableMixin): extremum_solver = "Nelder-Mead" outcome_types: List[str] = [] + stimuli_per_trial: int = 1 + + dim: int _train_inputs: Optional[Tuple[torch.Tensor]] _train_targets: Optional[torch.Tensor] - stimuli_per_trial: int = 1 + + def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any) -> None: + """Fit underlying model. Must be overriden by subclasses. + + Args: + train_x (torch.Tensor): Inputs. + train_y (torch.LongTensor): Responses. + **kwargs: Extra kwargs for fitting the model. + """ + raise NotImplementedError @property def device(self) -> torch.device: @@ -192,7 +204,7 @@ def predict( **kwargs: Keyword arguments for model-specific predict kwargs. Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ with torch.no_grad(): x = x.to(self.device) @@ -217,7 +229,7 @@ def predict_transform( transformation is applied. Returns: - Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at query points. """ if transformed_posterior_cls is None: return self.predict(x) diff --git a/aepsych/models/gp_classification.py b/aepsych/models/gp_classification.py index 04c952bfd..fa6b8f934 100644 --- a/aepsych/models/gp_classification.py +++ b/aepsych/models/gp_classification.py @@ -90,7 +90,7 @@ def predict( response probability instead of latent function value. Defaults to False. Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ if not probability_space: @@ -117,7 +117,7 @@ def predict_transform( transformation is applied. Returns: - Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at query points. """ return super().predict_transform( @@ -131,6 +131,6 @@ def predict_probability(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens x (torch.Tensor): Points at which to predict from the model. Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ return self.predict(x, probability_space=True) diff --git a/aepsych/models/model_protocol.py b/aepsych/models/model_protocol.py deleted file mode 100644 index c08206da0..000000000 --- a/aepsych/models/model_protocol.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Optional, Protocol - -import torch -from botorch.posteriors import Posterior -from gpytorch.likelihoods import Likelihood - -from .transformed_posteriors import TransformedPosterior - - -class ModelProtocol(Protocol): - @property - def _num_outputs(self) -> int: - pass - - @property - def outcome_type(self) -> str: - pass - - @property - def extremum_solver(self) -> str: - pass - - @property - def train_inputs(self) -> torch.Tensor: - pass - - @property - def dim(self) -> int: - pass - - @property - def device(self) -> torch.device: - pass - - def posterior(self, X: torch.Tensor) -> Posterior: - pass - - def predict(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - pass - - def predict_probability(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - pass - - def predict_transform( - self, - x: torch.Tensor, - transformed_posterior_cls: Optional[type[TransformedPosterior]] = None, - **transform_kwargs, - ): - pass - - @property - def stimuli_per_trial(self) -> int: - pass - - @property - def likelihood(self) -> Likelihood: - pass - - def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor: - pass - - def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any) -> None: - pass - - def update( - self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any - ) -> None: - pass diff --git a/aepsych/models/pairwise_probit.py b/aepsych/models/pairwise_probit.py index a88a7d77f..f4eddf441 100644 --- a/aepsych/models/pairwise_probit.py +++ b/aepsych/models/pairwise_probit.py @@ -184,7 +184,7 @@ def predict( rereference (str): How to sample. Options are "x_min", "x_max", "f_min", "f_max". Defaults to "x_min". Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ if rereference is not None: samps = self.sample(x, num_samples, rereference) @@ -217,7 +217,7 @@ def predict_probability( rereference (str): How to sample. Options are "x_min", "x_max", "f_min", "f_max". Defaults to "x_min". Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ return self.predict( x, probability_space=True, num_samples=num_samples, rereference=rereference diff --git a/aepsych/models/semi_p.py b/aepsych/models/semi_p.py index b2ac1991d..8ab7c37cd 100644 --- a/aepsych/models/semi_p.py +++ b/aepsych/models/semi_p.py @@ -606,7 +606,7 @@ def predict( response probability instead of latent function value. Defaults to False. Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ if probability_space: if hasattr(self.likelihood, "objective"): diff --git a/aepsych/models/utils.py b/aepsych/models/utils.py index b29787d5c..9ad89c529 100644 --- a/aepsych/models/utils.py +++ b/aepsych/models/utils.py @@ -11,7 +11,7 @@ import numpy as np import torch -from aepsych.models.model_protocol import ModelProtocol +from aepsych.models.base import AEPsychModelMixin from aepsych.utils import dim_grid, get_jnd_multid, promote_0d from botorch.acquisition import PosteriorMean from botorch.acquisition.objective import ( @@ -159,7 +159,7 @@ def get_extremum( def get_min( - model: ModelProtocol, + model: AEPsychModelMixin, bounds: torch.Tensor, locked_dims: Optional[Mapping[int, float]] = None, probability_space: bool = False, @@ -168,7 +168,7 @@ def get_min( ) -> Tuple[float, torch.Tensor]: """Return the minimum of the modeled function, subject to constraints Args: - model (ModelProtocol): AEPsychModel to get the minimum of. + model (AEPsychModelMixin): AEPsychModel to get the minimum of. bounds (torch.Tensor): Bounds of the space to find the minimum. locked_dims (Mapping[int, float], optional): Dimensions to fix, so that the inverse is along a slice of the full surface. @@ -193,7 +193,7 @@ def get_min( def get_max( - model: ModelProtocol, + model: AEPsychModelMixin, bounds: torch.Tensor, locked_dims: Optional[Mapping[int, float]] = None, probability_space: bool = False, @@ -203,7 +203,7 @@ def get_max( """Return the maximum of the modeled function, subject to constraints Args: - model (ModelProtocol): AEPsychModel to get the maximum of. + model (AEPsychModelMixin): AEPsychModel to get the maximum of. bounds (torch.Tensor): Bounds of the space to find the maximum. locked_dims (Mapping[int, float], optional): Dimensions to fix, so that the inverse is along a slice of the full surface. Defaults to None. @@ -228,7 +228,7 @@ def get_max( def inv_query( - model: ModelProtocol, + model: AEPsychModelMixin, y: Union[float, torch.Tensor], bounds: torch.Tensor, locked_dims: Optional[Mapping[int, float]] = None, @@ -241,7 +241,7 @@ def inv_query( Return nearest x such that f(x) = queried y, and also return the value of f at that point. Args: - model (ModelProtocol): AEPsychModel to get the find the inverse from y. + model (AEPsychModelMixin): AEPsychModel to get the find the inverse from y. y (Union[float, torch.Tensor]): Points at which to find the inverse. bounds (torch.Tensor): Lower and upper bounds of the search space. locked_dims (Mapping[int, float], optional): Dimensions to fix, so that the @@ -288,7 +288,7 @@ def inv_query( def get_jnd( - model: ModelProtocol, + model: AEPsychModelMixin, lb: torch.Tensor, ub: torch.Tensor, dim: int, @@ -311,7 +311,7 @@ def get_jnd( Both definitions are equivalent for linear psychometric functions. Args: - model (ModelProtocol): Model to use for prediction. + model (AEPsychModelMixin): Model to use for prediction. lb (torch.Tensor): Lower bounds of the input space. ub (torch.Tensor): Upper bounds of the input space. dim (int): Dimensionality of the input space. @@ -389,7 +389,7 @@ def get_jnd( def p_below_threshold( - model: ModelProtocol, x: torch.Tensor, f_thresh: torch.Tensor + model: AEPsychModelMixin, x: torch.Tensor, f_thresh: torch.Tensor ) -> torch.Tensor: """Compute the probability that the latent function is below a threshold. @@ -417,7 +417,7 @@ def bernoulli_probit_prob_transform(mean: torch.Tensor, var: torch.Tensor): mean (torch.Tensor): The latent variance of a Bernoulli-probit model evaluated at a set of query points. Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points in probability space. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points in probability space. """ fmean = mean.squeeze() fvar = var.squeeze() diff --git a/aepsych/strategy/strategy.py b/aepsych/strategy/strategy.py index 097a2f44c..07a33edbd 100644 --- a/aepsych/strategy/strategy.py +++ b/aepsych/strategy/strategy.py @@ -69,7 +69,7 @@ def __init__( of lb and ub. min_total_tells (int): The minimum number of total observations needed to complete this strategy. min_asks (int): The minimum number of points that should be generated from this strategy. - model (ModelProtocol, optional): The AEPsych model of the data. + model (AEPsychModelMixin, optional): The AEPsych model of the data. use_gpu_modeling (bool): Whether to move the model to GPU fitting/predictions, defaults to False. use_gpu_generating (bool): Whether to use the GPU for generating points, defaults to False. refit_every (int): How often to refit the model from scratch. @@ -381,7 +381,7 @@ def predict( probability_space (bool): Whether to return the output in probability space. Defaults to False. Returns: - Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ assert self.model is not None, "model is None! Cannot predict without a model!" self.model.to(self.model_device) diff --git a/aepsych/transforms/parameters.py b/aepsych/transforms/parameters.py index ce1b54a12..3e4f6e9dd 100644 --- a/aepsych/transforms/parameters.py +++ b/aepsych/transforms/parameters.py @@ -16,7 +16,6 @@ from aepsych.config import Config, ConfigurableMixin from aepsych.generators.base import AcqfGenerator, AEPsychGenerator from aepsych.models.base import AEPsychModelMixin -from aepsych.models.model_protocol import ModelProtocol from aepsych.transforms.ops import Fixed, Log10Plus, NormalizeScale, Round from aepsych.transforms.ops.base import Transform from aepsych.utils import get_bounds @@ -524,11 +523,11 @@ class ParameterTransformedModel(ParameterTransformWrapper, ConfigurableMixin): untransforms any outputs from the model back to raw parameter space. """ - _base_obj: ModelProtocol + _base_obj: AEPsychModelMixin def __init__( self, - model: Union[Type, ModelProtocol], + model: Union[Type, AEPsychModelMixin], transforms: ChainedInputTransform = ChainedInputTransform(**{}), **kwargs: Any, ) -> None: @@ -547,7 +546,7 @@ def __init__( The object's name will be ParameterTransformed. Args: - model (Union[Type, ModelProtocol]): Model to wrap, this could either be a + model (Union[Type, AEPsychModelMixin]): Model to wrap, this could either be a completely initialized model or just the model class. An initialized model is expected to have been initialized in the transformed parameter space (i.e., bounds are transformed). If a model class is @@ -597,9 +596,7 @@ def wrapper(self, *args, **kwargs) -> torch.Tensor: return wrapper @_promote_1d - def predict( - self, x: torch.Tensor, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + def predict(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: """Query the model on its posterior given transformed x. Args: @@ -608,7 +605,7 @@ def predict( **kwargs: Keyword arguments to pass to the model.predict() call. Returns: - Union[Tensor, Tuple[Tensor]]: At least one Tensor will be returned. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ x = self.transforms.transform(x) return self._base_obj.predict(x, **kwargs) @@ -616,7 +613,7 @@ def predict( @_promote_1d def predict_probability( self, x: torch.Tensor, **kwargs - ) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Query the model on its posterior given transformed x and return units in response probability space. @@ -626,7 +623,7 @@ def predict_probability( **kwargs: Keyword arguments to pass to the model.predict() call. Returns: - Union[Tensor, Tuple[Tensor]]: At least one Tensor will be returned. + Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points. """ x = self.transforms.transform(x) return self._base_obj.predict_probability(x, **kwargs)