Skip to content

Commit

Permalink
add output_tasks to MTGP in MBM (#2241)
Browse files Browse the repository at this point in the history
Summary:

see title. output_tasks were not set for MTGP in MBM.

Differential Revision: D54453253
  • Loading branch information
sdaulton authored and facebook-github-bot committed Mar 4, 2024
1 parent b98f3be commit c27a5f8
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import (
ChainedInputTransform,
Expand All @@ -59,6 +60,7 @@
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from botorch_fb.models.heterogeneous_mtgp import HeterogeneousMTGP
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
Expand Down Expand Up @@ -776,6 +778,10 @@ def outcomes(self, value: List[str]) -> None:
)


# Note: although `HeterogeneousMTGP`` is a subclass of `MultiTaskGP`,
# it takes different inputs than a `MultiTaskGP`` and it's
# datasets can be heterogeneous, which means `dataset.X`
# will fail since it tries to concatenate heterogeneous datasets.
@submodel_input_constructor.register(Model)
def _submodel_input_constructor_base(
botorch_model_class: Type[Model],
Expand Down Expand Up @@ -841,3 +847,26 @@ def _submodel_input_constructor_base(
botorch_model_class_args=botorch_model_class_args,
)
return formatted_model_inputs


@submodel_input_constructor.register(MultiTaskGP)
def _submodel_input_constructor_mtgp(
botorch_model_class: Type[Model],
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
surrogate: Surrogate,
) -> Dict[str, Any]:
if len(dataset.outcome_names) > 1:
raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
formatted_model_inputs = _submodel_input_constructor_base(
botorch_model_class=botorch_model_class,
dataset=dataset,
search_space_digest=search_space_digest,
surrogate=surrogate,
)
# specify output tasks so that model.num_outputs = 1
# since the model only models a single outcome
formatted_model_inputs["output_tasks"] = dataset.X[
:1, formatted_model_inputs["task_feature"]
].tolist()
return formatted_model_inputs

0 comments on commit c27a5f8

Please sign in to comment.