Skip to content

Commit

Permalink
Remove model protocol (facebookresearch#633)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#633

The extra model protocol is no longer necessary with a consistent base aepsych model class.

Differential Revision: D69216298
  • Loading branch information
JasonKChow authored and facebook-github-bot committed Feb 6, 2025
1 parent aa08b65 commit 30b6475
Show file tree
Hide file tree
Showing 15 changed files with 75 additions and 140 deletions.
16 changes: 8 additions & 8 deletions aepsych/benchmark/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
import torch
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.models.utils import p_below_threshold
from aepsych.strategy import SequentialStrategy
from aepsych.utils import make_scaled_sobol
Expand Down Expand Up @@ -78,11 +78,11 @@ def sample_y(
"""
return bernoulli.rvs(self.p(x))

def f_hat(self, model: ModelProtocol) -> torch.Tensor:
def f_hat(self, model: AEPsychModelMixin) -> torch.Tensor:
"""Generate mean predictions from the model over the evaluation grid.
Args:
model (TensoModelProtocolr): Model to evaluate.
model (AEPsychModelMixin): Model to evaluate.
Returns:
torch.Tensor: Posterior mean from underlying model over the evaluation grid.
Expand All @@ -109,11 +109,11 @@ def p_true(self) -> torch.Tensor:
normal_dist = torch.distributions.Normal(0, 1)
return normal_dist.cdf(self.f_true)

def p_hat(self, model: ModelProtocol) -> torch.Tensor:
def p_hat(self, model: AEPsychModelMixin) -> torch.Tensor:
"""Generate mean predictions from the model over the evaluation grid.
Args:
model (TensoModelProtocolr): Model to evaluate.
model (AEPsychModelMixin): Model to evaluate.
Returns:
torch.Tensor: Posterior mean from underlying model over the evaluation grid.
Expand Down Expand Up @@ -171,9 +171,9 @@ def evaluate(
# eval in samp-based expectation over posterior instead of just mean
fsamps = model.sample(self.eval_grid, num_samples=1000)
try:
psamps = (
model.sample(self.eval_grid, num_samples=1000, probability_space=True) # type: ignore
)
psamps = model.sample(
self.eval_grid, num_samples=1000, probability_space=True
) # type: ignore
except (
TypeError
): # vanilla models don't have proba_space samps, TODO maybe we should add them
Expand Down
9 changes: 4 additions & 5 deletions aepsych/generators/acqf_grid_search_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

import numpy as np
import torch
from aepsych.models.model_protocol import ModelProtocol
from aepsych.generators.grid_eval_acqf_generator import GridEvalAcqfGenerator
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from numpy.random import choice

from .grid_eval_acqf_generator import GridEvalAcqfGenerator

logger = getLogger()


Expand All @@ -25,7 +24,7 @@ class AcqfGridSearchGenerator(GridEvalAcqfGenerator):
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
Expand All @@ -34,7 +33,7 @@ def _gen(
Args:
num_points (int): The number of points to query.
model (ModelProtocol): The fitted model used to evaluate the acquisition function.
model (AEPsychModelMixin): The fitted model used to evaluate the acquisition function.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
gen_options (dict): Additional options for generating points, including:
- "seed": Random seed for reproducibility.
Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/acqf_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import torch
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from numpy.random import choice

Expand All @@ -25,7 +25,7 @@ class AcqfThompsonSamplerGenerator(GridEvalAcqfGenerator):
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
Expand All @@ -34,7 +34,7 @@ def _gen(
Args:
num_points (int): The number of points to query.
model (ModelProtocol): The fitted model used to evaluate the acquisition function.
model (AEPsychModelMixin): The fitted model used to evaluate the acquisition function.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
gen_options (dict): Additional options for generating points, including:
- "seed": Random seed for reproducibility.
Expand Down
10 changes: 6 additions & 4 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
)
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption

from ..models.model_protocol import ModelProtocol

AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychModelMixin)


Expand Down Expand Up @@ -166,12 +164,14 @@ def _get_acqf_options(

return extra_acqf_args

def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction:
def _instantiate_acquisition_fn(
self, model: AEPsychModelMixin
) -> AcquisitionFunction:
"""
Instantiates the acquisition function with the specified model and additional arguments.
Args:
model (ModelProtocol): The model to use with the acquisition function.
model (AEPsychModelMixin): The model to use with the acquisition function.
Returns:
AcquisitionFunction: Configured acquisition function.
Expand All @@ -193,6 +193,8 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFuncti
self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device)

if self.acqf in self.baseline_requiring_acqfs:
if model.train_inputs is None:
raise ValueError(f"model needs data as a baseline for {self.acqf}")
return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs)
else:
return self.acqf(model=model, **self.acqf_kwargs)
Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/epsilon_greedy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import numpy as np
import torch
from aepsych.models.base import AEPsychModelMixin

from ..models.model_protocol import ModelProtocol
from .base import AEPsychGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator

Expand Down Expand Up @@ -65,15 +65,15 @@ def get_config_options(
def gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by sampling from the subgenerator with probability 1-epsilon, and randomly otherwise.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Model to use for generating points.
model (AEPsychModelMixin): Model to use for generating points.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
**kwargs: Passed to subgenerator if not exploring
"""
Expand Down
10 changes: 5 additions & 5 deletions aepsych/generators/grid_eval_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aepsych.config import Config
from aepsych.generators.base import AcqfGenerator, AEPsychGenerator
from aepsych.generators.sobol_generator import SobolGenerator
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from botorch.acquisition import AcquisitionFunction

Expand Down Expand Up @@ -53,14 +53,14 @@ def __init__(
def gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
model (AEPsychModelMixin): Fitted model of the data.
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand Down Expand Up @@ -89,7 +89,7 @@ def gen(
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
Expand All @@ -98,7 +98,7 @@ def _gen(
def _eval_acqf(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down
10 changes: 5 additions & 5 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aepsych.acquisition.lookahead import LookaheadAcquisitionFunction
from aepsych.config import Config
from aepsych.generators.base import AcqfGenerator
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from botorch.acquisition import AcquisitionFunction
from botorch.optim import optimize_acqf
Expand Down Expand Up @@ -60,14 +60,14 @@ def __init__(
def gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
model (AEPsychModelMixin): Fitted model of the data.
fixed_features (Dict[int, float], optional): The values where the specified
parameters should be at when generating. Should be a dictionary where
the keys are the indices of the parameters to fix and the values are the
Expand Down Expand Up @@ -116,7 +116,7 @@ def gen(
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
acqf: AcquisitionFunction,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options: Dict[str, Any],
Expand All @@ -126,7 +126,7 @@ def _gen(
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
model (AEPsychModelMixin): Fitted model of the data.
acqf (AcquisitionFunction): Acquisition function.
fixed_features (Dict[int, float], optional): The values where the specified
parameters should be at when generating. Should be a dictionary where
Expand Down
18 changes: 15 additions & 3 deletions aepsych/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,21 @@ class AEPsychModelMixin(GPyTorchModel, ConfigurableMixin):

extremum_solver = "Nelder-Mead"
outcome_types: List[str] = []
stimuli_per_trial: int = 1

dim: int
_train_inputs: Optional[Tuple[torch.Tensor]]
_train_targets: Optional[torch.Tensor]
stimuli_per_trial: int = 1

def fit(self, train_x: torch.Tensor, train_y: torch.Tensor, **kwargs: Any) -> None:
"""Fit underlying model. Must be overridden by subclasses.
Args:
train_x (torch.Tensor): Inputs.
train_y (torch.Tensor): Responses.
**kwargs: Extra kwargs for fitting the model.
"""
raise NotImplementedError

@property
def device(self) -> torch.device:
Expand Down Expand Up @@ -192,7 +204,7 @@ def predict(
**kwargs: Keyword arguments for model-specific predict kwargs.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points.
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points.
"""
with torch.no_grad():
x = x.to(self.device)
Expand All @@ -217,7 +229,7 @@ def predict_transform(
transformation is applied.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at queries points.
Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at query points.
"""
if transformed_posterior_cls is None:
return self.predict(x)
Expand Down
6 changes: 3 additions & 3 deletions aepsych/models/gp_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def predict(
response probability instead of latent function value. Defaults to False.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points.
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points.
"""

if not probability_space:
Expand All @@ -117,7 +117,7 @@ def predict_transform(
transformation is applied.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at queries points.
Tuple[torch.Tensor, torch.Tensor]: Transformed posterior mean and variance at query points.
"""

return super().predict_transform(
Expand All @@ -131,6 +131,6 @@ def predict_probability(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
x (torch.Tensor): Points at which to predict from the model.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at queries points.
Tuple[torch.Tensor, torch.Tensor]: Posterior mean and variance at query points.
"""
return self.predict(x, probability_space=True)
Loading

0 comments on commit 30b6475

Please sign in to comment.