Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove model protocol #633

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aepsych/acquisition/lookahead_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def posterior_at_xstar_xq(
- Sigma_sq: (b x m) covariance between Xstar and each point in Xq.
"""
# Evaluate posterior and extract needed components
Xq = Xq.to(Xstar)
Xext = torch.cat((Xstar, Xq), dim=-2)
posterior = model.posterior(Xext, posterior_transform=posterior_transform)
mu = posterior.mean[..., :, 0]
Expand Down
16 changes: 8 additions & 8 deletions aepsych/benchmark/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np
import torch
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.models.utils import p_below_threshold
from aepsych.strategy import SequentialStrategy
from aepsych.utils import make_scaled_sobol
Expand Down Expand Up @@ -78,11 +78,11 @@ def sample_y(
"""
return bernoulli.rvs(self.p(x))

def f_hat(self, model: ModelProtocol) -> torch.Tensor:
def f_hat(self, model: AEPsychModelMixin) -> torch.Tensor:
"""Generate mean predictions from the model over the evaluation grid.

Args:
model (TensoModelProtocolr): Model to evaluate.
model (AEPsychModelMixin): Model to evaluate.

Returns:
torch.Tensor: Posterior mean from underlying model over the evaluation grid.
Expand All @@ -109,11 +109,11 @@ def p_true(self) -> torch.Tensor:
normal_dist = torch.distributions.Normal(0, 1)
return normal_dist.cdf(self.f_true)

def p_hat(self, model: ModelProtocol) -> torch.Tensor:
def p_hat(self, model: AEPsychModelMixin) -> torch.Tensor:
"""Generate mean predictions from the model over the evaluation grid.

Args:
model (TensoModelProtocolr): Model to evaluate.
model (AEPsychModelMixin): Model to evaluate.

Returns:
torch.Tensor: Posterior mean from underlying model over the evaluation grid.
Expand Down Expand Up @@ -171,9 +171,9 @@ def evaluate(
# eval in samp-based expectation over posterior instead of just mean
fsamps = model.sample(self.eval_grid, num_samples=1000)
try:
psamps = (
model.sample(self.eval_grid, num_samples=1000, probability_space=True) # type: ignore
)
psamps = model.sample(
self.eval_grid, num_samples=1000, probability_space=True
) # type: ignore
except (
TypeError
): # vanilla models don't have proba_space samps, TODO maybe we should add them
Expand Down
9 changes: 4 additions & 5 deletions aepsych/generators/acqf_grid_search_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

import numpy as np
import torch
from aepsych.models.model_protocol import ModelProtocol
from aepsych.generators.grid_eval_acqf_generator import GridEvalAcqfGenerator
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from numpy.random import choice

from .grid_eval_acqf_generator import GridEvalAcqfGenerator

logger = getLogger()


Expand All @@ -25,7 +24,7 @@ class AcqfGridSearchGenerator(GridEvalAcqfGenerator):
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
Expand All @@ -34,7 +33,7 @@ def _gen(

Args:
num_points (int): The number of points to query.
model (ModelProtocol): The fitted model used to evaluate the acquisition function.
model (AEPsychModelMixin): The fitted model used to evaluate the acquisition function.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
gen_options (dict): Additional options for generating points, including:
- "seed": Random seed for reproducibility.
Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/acqf_thompson_sampler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import torch
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from numpy.random import choice

Expand All @@ -25,7 +25,7 @@ class AcqfThompsonSamplerGenerator(GridEvalAcqfGenerator):
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
Expand All @@ -34,7 +34,7 @@ def _gen(

Args:
num_points (int): The number of points to query.
model (ModelProtocol): The fitted model used to evaluate the acquisition function.
model (AEPsychModelMixin): The fitted model used to evaluate the acquisition function.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
gen_options (dict): Additional options for generating points, including:
- "seed": Random seed for reproducibility.
Expand Down
14 changes: 8 additions & 6 deletions aepsych/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch
from aepsych.config import Config, ConfigurableMixin
from aepsych.models.base import AEPsychMixin
from aepsych.models.base import AEPsychModelMixin
from botorch.acquisition import (
AcquisitionFunction,
LogNoisyExpectedImprovement,
Expand All @@ -21,9 +21,7 @@
)
from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption

from ..models.model_protocol import ModelProtocol

AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychMixin)
AEPsychModelType = TypeVar("AEPsychModelType", bound=AEPsychModelMixin)


@runtime_checkable
Expand Down Expand Up @@ -166,12 +164,14 @@ def _get_acqf_options(

return extra_acqf_args

def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFunction:
def _instantiate_acquisition_fn(
self, model: AEPsychModelMixin
) -> AcquisitionFunction:
"""
Instantiates the acquisition function with the specified model and additional arguments.

Args:
model (ModelProtocol): The model to use with the acquisition function.
model (AEPsychModelMixin): The model to use with the acquisition function.

Returns:
AcquisitionFunction: Configured acquisition function.
Expand All @@ -193,6 +193,8 @@ def _instantiate_acquisition_fn(self, model: ModelProtocol) -> AcquisitionFuncti
self.acqf_kwargs["ub"] = self.acqf_kwargs["ub"].to(model.device)

if self.acqf in self.baseline_requiring_acqfs:
if model.train_inputs is None:
raise ValueError(f"model needs data as a baseline for {self.acqf}")
return self.acqf(model, model.train_inputs[0], **self.acqf_kwargs)
else:
return self.acqf(model=model, **self.acqf_kwargs)
Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/epsilon_greedy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import numpy as np
import torch
from aepsych.models.base import AEPsychModelMixin

from ..models.model_protocol import ModelProtocol
from .base import AEPsychGenerator
from .optimize_acqf_generator import OptimizeAcqfGenerator

Expand Down Expand Up @@ -65,15 +65,15 @@ def get_config_options(
def gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by sampling from the subgenerator with probability 1-epsilon, and randomly otherwise.

Args:
num_points (int): Number of points to query.
model (ModelProtocol): Model to use for generating points.
model (AEPsychModelMixin): Model to use for generating points.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
**kwargs: Passed to subgenerator if not exploring
"""
Expand Down
10 changes: 5 additions & 5 deletions aepsych/generators/grid_eval_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aepsych.config import Config
from aepsych.generators.base import AcqfGenerator, AEPsychGenerator
from aepsych.generators.sobol_generator import SobolGenerator
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from botorch.acquisition import AcquisitionFunction

Expand Down Expand Up @@ -53,14 +53,14 @@ def __init__(
def gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
model (AEPsychModelMixin): Fitted model of the data.
Returns:
torch.Tensor: Next set of point(s) to evaluate, [num_points x dim].
"""
Expand Down Expand Up @@ -89,7 +89,7 @@ def gen(
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
Expand All @@ -98,7 +98,7 @@ def _gen(
def _eval_acqf(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/manual_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils import _process_bounds
from torch.quasirandom import SobolEngine

Expand Down Expand Up @@ -53,14 +53,14 @@ def __init__(
def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility
model: Optional[AEPsychModelMixin] = None, # included for API compatibility
fixed_features: Optional[Dict[int, float]] = None,
**kwargs, # Ignored
) -> torch.Tensor:
"""Query next point(s) to run by quasi-randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Defaults to 1.
model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
model (AEPsychModelMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
fixed_features (Dict[int, float], optional): Ignored, kept for consistent
API.
**kwargs: Ignored, API compatibility
Expand Down
10 changes: 5 additions & 5 deletions aepsych/generators/optimize_acqf_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from aepsych.acquisition.lookahead import LookaheadAcquisitionFunction
from aepsych.config import Config
from aepsych.generators.base import AcqfGenerator
from aepsych.models.model_protocol import ModelProtocol
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils_logging import getLogger
from botorch.acquisition import AcquisitionFunction
from botorch.optim import optimize_acqf
Expand Down Expand Up @@ -60,14 +60,14 @@ def __init__(
def gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options,
) -> torch.Tensor:
"""Query next point(s) to run by optimizing the acquisition function.
Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
model (AEPsychModelMixin): Fitted model of the data.
fixed_features (Dict[int, float], optional): The values where the specified
parameters should be at when generating. Should be a dictionary where
the keys are the indices of the parameters to fix and the values are the
Expand Down Expand Up @@ -116,7 +116,7 @@ def gen(
def _gen(
self,
num_points: int,
model: ModelProtocol,
model: AEPsychModelMixin,
acqf: AcquisitionFunction,
fixed_features: Optional[Dict[int, float]] = None,
**gen_options: Dict[str, Any],
Expand All @@ -126,7 +126,7 @@ def _gen(

Args:
num_points (int): Number of points to query.
model (ModelProtocol): Fitted model of the data.
model (AEPsychModelMixin): Fitted model of the data.
acqf (AcquisitionFunction): Acquisition function.
fixed_features (Dict[int, float], optional): The values where the specified
parameters should be at when generating. Should be a dictionary where
Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/random_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils import _process_bounds


Expand Down Expand Up @@ -38,14 +38,14 @@ def __init__(
def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility.
model: Optional[AEPsychModelMixin] = None, # included for API compatibility.
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Currently, only 1 point can be queried at a time.
model (AEPsychMixin, optional): Model to use for generating points. Not used in this generator.
model (AEPsychModelMixin, optional): Model to use for generating points. Not used in this generator.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
**kwargs: Ignored, API compatibility

Expand Down
6 changes: 3 additions & 3 deletions aepsych/generators/sobol_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from aepsych.config import Config
from aepsych.generators.base import AEPsychGenerator
from aepsych.models.base import AEPsychMixin
from aepsych.models.base import AEPsychModelMixin
from aepsych.utils import _process_bounds
from torch.quasirandom import SobolEngine

Expand Down Expand Up @@ -49,14 +49,14 @@ def __init__(
def gen(
self,
num_points: int = 1,
model: Optional[AEPsychMixin] = None, # included for API compatibility
model: Optional[AEPsychModelMixin] = None, # included for API compatibility
fixed_features: Optional[Dict[int, float]] = None,
**kwargs,
) -> torch.Tensor:
"""Query next point(s) to run by quasi-randomly sampling the parameter space.
Args:
num_points (int): Number of points to query. Defaults to 1.
moodel (AEPsychMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
moodel (AEPsychModelMixin, optional): Model to use for generating points. Not used in this generator. Defaults to None.
fixed_features: (Dict[int, float], optional): Parameters that are fixed to specific values.
**kwargs: Ignored, API compatibility
Returns:
Expand Down
Loading