From c27a5f805ba7f9625600d002d2cb730050708b51 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Mon, 4 Mar 2024 07:49:34 -0800 Subject: [PATCH] add output_tasks to MTGP in MBM (#2241) Summary: see title. output_tasks were not set for MTGP in MBM. Differential Revision: D54453253 --- ax/models/torch/botorch_modular/surrogate.py | 29 ++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 27f9ac9bde0..189ba16a193 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -49,6 +49,7 @@ ) from botorch.models.model import Model from botorch.models.model_list_gp_regression import ModelListGP +from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import ( ChainedInputTransform, @@ -59,6 +60,7 @@ from botorch.utils.containers import SliceContainer from botorch.utils.datasets import RankingDataset, SupervisedDataset from botorch.utils.dispatcher import Dispatcher +from botorch_fb.models.heterogeneous_mtgp import HeterogeneousMTGP from gpytorch.kernels import Kernel from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood @@ -776,6 +778,10 @@ def outcomes(self, value: List[str]) -> None: ) +# Note: although `HeterogeneousMTGP`` is a subclass of `MultiTaskGP`, +# it takes different inputs than a `MultiTaskGP`` and it's +# datasets can be heterogeneous, which means `dataset.X` +# will fail since it tries to concatenate heterogeneous datasets. @submodel_input_constructor.register(Model) def _submodel_input_constructor_base( botorch_model_class: Type[Model], @@ -841,3 +847,26 @@ def _submodel_input_constructor_base( botorch_model_class_args=botorch_model_class_args, ) return formatted_model_inputs + + +@submodel_input_constructor.register(MultiTaskGP) +def _submodel_input_constructor_mtgp( + botorch_model_class: Type[Model], + dataset: SupervisedDataset, + search_space_digest: SearchSpaceDigest, + surrogate: Surrogate, +) -> Dict[str, Any]: + if len(dataset.outcome_names) > 1: + raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") + formatted_model_inputs = _submodel_input_constructor_base( + botorch_model_class=botorch_model_class, + dataset=dataset, + search_space_digest=search_space_digest, + surrogate=surrogate, + ) + # specify output tasks so that model.num_outputs = 1 + # since the model only models a single outcome + formatted_model_inputs["output_tasks"] = dataset.X[ + :1, formatted_model_inputs["task_feature"] + ].tolist() + return formatted_model_inputs