From e93f459cd9559b163bace64e6021a180273da992 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Mon, 26 Aug 2024 04:40:22 -0700
Subject: [PATCH] Run Torchvision problems with Benchmark[Problem/Runner/Metric]; consolidate PyTorchCNN problems (#2688)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2688

Context: Consolidating on common abstractions will make it easier to add functionality to all classes. It will also make the code smaller and easier to navigate. This diff is substantially LOC-negative outside of tests.

This PR:
* Updates functionality in torchvision.py that replaces MNIST datasets with fakes when run in a test environment; the datasets are now realistic enough to be usable.
* Merges the functionality in `pytorch_cnn.py` into `torchvision.py`; it was only ever used to support `torchvision.py`.
* Removes `PyTorchCNNTorchvisionBenchmarkProblem` and its special serialization logic; it is replaced with `BenchmarkProblem`.
* Removes `PyTorchCNNTorchvisionRunner`, `PyTorchCNNBenchmarkProblem`, `PyTorchCNNMetric`, and `PyTorchCNNRunner`.
* Introduces `PyTorchCNNTorchvisionParamBasedProblem`. It needs no special serialization logic because it is a dataclass whose datasets are constructed in `__post_init__`: the datasets are excluded from serialization and are reconstructed when the instance is decoded (see the sketch after this list). Using a dataclass here also allows for an automatic and more rigorous equality check.
* Uses `BenchmarkRunner`; per D61483962, this means that this problem now has a ground truth, which does not change its behavior since no noise is added.
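A minimal sketch of the `InitVar`/`__post_init__` pattern used here (illustrative names only, not the actual Ax classes): attributes declared as `InitVar` are passed to `__post_init__` but are not dataclass fields, so they are excluded both from init-arg serialization and from the generated equality check.

```python
from dataclasses import dataclass, fields, InitVar


def load_dataset(name: str) -> list[int]:
    # Stand-in for downloading a torchvision dataset.
    return [1, 2, 3]


@dataclass(kw_only=True)
class DatasetProblem:
    name: str
    # Not a true field: never serialized and ignored by __eq__.
    train_data: InitVar[list[int] | None] = None

    def __post_init__(self, train_data: list[int] | None) -> None:
        # Reconstruct the heavy object when it isn't supplied, e.g. after
        # decoding a serialized instance.
        self.train_data = (
            train_data if train_data is not None else load_dataset(self.name)
        )


p = DatasetProblem(name="MNIST")
assert [f.name for f in fields(p)] == ["name"]  # only `name` is serialized
assert DatasetProblem(name="MNIST") == p  # equality ignores the data
```

Differential Revision: D61414680

Reviewed By: Balandat
---
 ax/benchmark/problems/hpo/pytorch_cnn.py   | 230 ---------------
 ax/benchmark/problems/hpo/torchvision.py   | 276 ++++++++++++------
 ax/benchmark/problems/registry.py          |   8 +-
 .../tests/problems/hpo/test_torchvision.py | 100 +++++++
 .../tests/problems/test_problem_storage.py |  27 --
 ax/benchmark/tests/test_benchmark.py       |  25 +-
 ax/storage/json_store/decoder.py           |   8 -
 ax/storage/json_store/encoders.py          |  15 -
 ax/storage/json_store/registry.py          |  14 +-
 .../json_store/tests/test_json_store.py    |  11 +
 ax/utils/testing/benchmark_stubs.py        |  27 ++
 11 files changed, 353 insertions(+), 388 deletions(-)
 delete mode 100644 ax/benchmark/problems/hpo/pytorch_cnn.py
 create mode 100644 ax/benchmark/tests/problems/hpo/test_torchvision.py
 delete mode 100644 ax/benchmark/tests/problems/test_problem_storage.py

diff --git a/ax/benchmark/problems/hpo/pytorch_cnn.py b/ax/benchmark/problems/hpo/pytorch_cnn.py
deleted file mode 100644
index bbc89c3ae1c..00000000000
--- a/ax/benchmark/problems/hpo/pytorch_cnn.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.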
- -# pyre-strict - -from collections.abc import Iterable -from typing import Any - -import pandas as pd -import torch -from ax.benchmark.benchmark_problem import BenchmarkProblem -from ax.core.base_trial import BaseTrial, TrialStatus -from ax.core.data import Data -from ax.core.metric import Metric, MetricFetchE, MetricFetchResult -from ax.core.objective import Objective -from ax.core.optimization_config import OptimizationConfig -from ax.core.parameter import ParameterType, RangeParameter -from ax.core.runner import Runner -from ax.core.search_space import SearchSpace -from ax.utils.common.base import Base -from ax.utils.common.equality import equality_typechecker -from ax.utils.common.result import Err, Ok -from torch import nn, optim, Tensor -from torch.nn import functional as F -from torch.utils.data import DataLoader, Dataset - - -class PyTorchCNNBenchmarkProblem(BenchmarkProblem): - @equality_typechecker - def __eq__(self, other: Base) -> bool: - if not isinstance(other, PyTorchCNNBenchmarkProblem): - return False - - # Checking the whole datasets' equality here would be too expensive to be - # worth it; just check names instead - return self.name == other.name - - @classmethod - def from_datasets( - cls, - name: str, - num_trials: int, - train_set: Dataset, - test_set: Dataset, - ) -> "PyTorchCNNBenchmarkProblem": - optimal_value = 1.0 - - search_space = SearchSpace( - parameters=[ - RangeParameter( - name="lr", parameter_type=ParameterType.FLOAT, lower=1e-6, upper=0.4 - ), - RangeParameter( - name="momentum", - parameter_type=ParameterType.FLOAT, - lower=0, - upper=1, - ), - RangeParameter( - name="weight_decay", - parameter_type=ParameterType.FLOAT, - lower=0, - upper=1, - ), - RangeParameter( - name="step_size", - parameter_type=ParameterType.INT, - lower=1, - upper=100, - ), - RangeParameter( - name="gamma", - parameter_type=ParameterType.FLOAT, - lower=0, - upper=1, - ), - ] - ) - optimization_config = OptimizationConfig( - objective=Objective( - metric=PyTorchCNNMetric(), - minimize=False, - ) - ) - - runner = PyTorchCNNRunner(name=name, train_set=train_set, test_set=test_set) - - return cls( - name=f"HPO_PyTorchCNN_{name}", - optimal_value=optimal_value, - search_space=search_space, - optimization_config=optimization_config, - runner=runner, - num_trials=num_trials, - is_noiseless=False, - observe_noise_stds=False, - has_ground_truth=False, - ) - - -class PyTorchCNNMetric(Metric): - def __init__(self) -> None: - super().__init__(name="accuracy") - - def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult: - try: - accuracy = [ - trial.run_metadata["accuracy"][name] - for name, arm in trial.arms_by_name.items() - ] - df = pd.DataFrame( - { - "arm_name": list(trial.arms_by_name.keys()), - "metric_name": self.name, - "mean": accuracy, - "sem": None, - "trial_index": trial.index, - } - ) - - return Ok(value=Data(df=df)) - - except Exception as e: - return Err( - value=MetricFetchE( - message=f"Failed to fetch {self.name} for trial {trial}", - exception=e, - ) - ) - - -class PyTorchCNNRunner(Runner): - def __init__(self, name: str, train_set: Dataset, test_set: Dataset) -> None: - self.name = name - self.train_loader: DataLoader = DataLoader(train_set) - self.test_loader: DataLoader = DataLoader(test_set) - self.results: dict[int, float] = {} - self.statuses: dict[int, TrialStatus] = {} - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - class CNN(nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv1 = 
nn.Conv2d(1, 20, kernel_size=5, stride=1) - self.fc1 = nn.Linear(8 * 8 * 20, 64) - self.fc2 = nn.Linear(64, 10) - - def forward(self, x: Tensor) -> Tensor: - x = F.relu(self.conv1(x)) - x = F.max_pool2d(x, 3, 3) - x = x.view(-1, 8 * 8 * 20) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return F.log_softmax(x, dim=-1) - - def train_and_evaluate( - self, - lr: float, - momentum: float, - weight_decay: float, - step_size: int, - gamma: float, - ) -> float: - net = self.CNN() - net.to(device=self.device) - - # Train - net.train() - criterion = nn.NLLLoss(reduction="sum") - optimizer = optim.SGD( - net.parameters(), - lr=lr, - momentum=momentum, - weight_decay=weight_decay, - ) - - scheduler = optim.lr_scheduler.StepLR( - optimizer, step_size=step_size, gamma=gamma - ) - - for inputs, labels in self.train_loader: - inputs = inputs.to(device=self.device) - labels = labels.to(device=self.device) - - # zero the parameter gradients - optimizer.zero_grad() - - # forward + backward + optimize - outputs = net(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - scheduler.step() - - # Evaluate - net.eval() - correct = 0 - total = 0 - with torch.no_grad(): - for inputs, labels in self.test_loader: - outputs = net(inputs) - _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) - correct += (predicted == labels).sum().item() - - return correct / total - - def run(self, trial: BaseTrial) -> dict[str, Any]: - self.statuses[trial.index] = TrialStatus.RUNNING - - self.statuses[trial.index] = TrialStatus.COMPLETED - return { - "accuracy": { - arm.name: self.train_and_evaluate( - lr=arm.parameters["lr"], # pyre-ignore[6] - momentum=arm.parameters["momentum"], # pyre-ignore[6] - weight_decay=arm.parameters["weight_decay"], # pyre-ignore[6] - step_size=arm.parameters["step_size"], # pyre-ignore[6] - gamma=arm.parameters["gamma"], # pyre-ignore[6] - ) - for arm in trial.arms - } - } - - def poll_trial_status( - self, trials: Iterable[BaseTrial] - ) -> dict[TrialStatus, set[int]]: - return {TrialStatus.COMPLETED: {t.index for t in trials}} diff --git a/ax/benchmark/problems/hpo/torchvision.py b/ax/benchmark/problems/hpo/torchvision.py index a832406e697..993376acfeb 100644 --- a/ax/benchmark/problems/hpo/torchvision.py +++ b/ax/benchmark/problems/hpo/torchvision.py @@ -5,17 +5,25 @@ # pyre-strict -import os -from typing import Any, Optional - -from ax.benchmark.problems.hpo.pytorch_cnn import ( - PyTorchCNNBenchmarkProblem, - PyTorchCNNRunner, +from dataclasses import dataclass, field, InitVar +from functools import lru_cache +from typing import Mapping + +import torch +from ax.benchmark.benchmark_problem import ( + BenchmarkProblem, + get_soo_config_and_outcome_names, +) +from ax.benchmark.runners.botorch_test import ( + ParamBasedTestProblem, + ParamBasedTestProblemRunner, ) +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError -from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry -from ax.utils.common.typeutils import checked_cast -from torch.utils.data import TensorDataset +from torch import nn, optim, Tensor +from torch.nn import functional as F +from torch.utils.data import DataLoader try: # We don't require TorchVision by default. 
from torchvision import datasets, transforms @@ -25,15 +33,6 @@ "FashionMNIST": datasets.FashionMNIST, } - if os.environ.get("TESTENV"): - # If we are in the test environment do not download any torchvision datasets. - # Instead, we use an empty TensorDataset - def get_dummy_dataset(**kwargs: dict[str, Any]) -> TensorDataset: - return TensorDataset() - - # pyre-ignore[9] We are replacing a type with a function - _REGISTRY = {key: get_dummy_dataset for key in _REGISTRY.keys()} - except ModuleNotFoundError: transforms = None @@ -41,19 +40,104 @@ def get_dummy_dataset(**kwargs: dict[str, Any]) -> TensorDataset: _REGISTRY = {} -class PyTorchCNNTorchvisionBenchmarkProblem(PyTorchCNNBenchmarkProblem): - @classmethod - def from_dataset_name( - cls, - name: str, - num_trials: int, - ) -> "PyTorchCNNTorchvisionBenchmarkProblem": - if name not in _REGISTRY: +class CNN(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1) + self.fc1 = nn.Linear(8 * 8 * 20, 64) + self.fc2 = nn.Linear(64, 10) + + def forward(self, x: Tensor) -> Tensor: + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 3, 3) + x = x.view(-1, 8 * 8 * 20) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=-1) + + +@lru_cache +def train_and_evaluate( + lr: float, + momentum: float, + weight_decay: float, + step_size: int, + gamma: float, + device: torch.device, + train_loader: DataLoader, + test_loader: DataLoader, +) -> torch.Tensor: + """Return the fraction of correctly classified test examples.""" + net = CNN() + net.to(device=device) + + # Train + net.train() + criterion = nn.NLLLoss(reduction="sum") + optimizer = optim.SGD( + net.parameters(), + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + ) + + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) + + for inputs, labels in train_loader: + inputs = inputs.to(device=device) + labels = labels.to(device=device) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + scheduler.step() + + # Evaluate + net.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for inputs, labels in test_loader: + outputs = net(inputs) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum() + + return correct / total + + +@dataclass(kw_only=True) +class PyTorchCNNTorchvisionParamBasedProblem(ParamBasedTestProblem): + name: str # The name of the dataset to load -- MNIST or FashionMNIST + num_objectives: int = 1 + optimal_value: float = 1.0 + device: torch.device = field( + default_factory=lambda: torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + ) + negate: bool = False + # Using `InitVar` prevents the DataLoaders from being serialized; instead + # they are reconstructed upon deserialization. + # Pyre doesn't understand InitVars. + # pyre-ignore: Undefined attribute [16]: `typing.Type` has no attribute + # `train_loader` + train_loader: InitVar[DataLoader | None] = None + # pyre-ignore + test_loader: InitVar[DataLoader | None] = None + + def __post_init__(self, train_loader: None, test_loader: None) -> None: + if self.name not in _REGISTRY: raise UserInputError( - f"Unrecognized torchvision dataset {name}. Please ensure it is listed " - "in PyTorchCNNTorchvisionBenchmarkProblem registry." + f"Unrecognized torchvision dataset {self.name}. 
Please ensure it " + "is listed in PyTorchCNNTorchvisionBenchmarkProblem registry." ) - dataset_fn = _REGISTRY[name] + dataset_fn = _REGISTRY[self.name] train_set = dataset_fn( root="./data", @@ -68,66 +152,80 @@ def from_dataset_name( download=True, transform=transforms.ToTensor(), ) - - problem = cls.from_datasets( - name=name, - num_trials=num_trials, - train_set=train_set, - test_set=test_set, - ) - runner = PyTorchCNNTorchvisionRunner( - name=name, train_set=train_set, test_set=test_set - ) - - return cls( - name=f"HPO_PyTorchCNN_Torchvision::{name}", - search_space=problem.search_space, - optimization_config=problem.optimization_config, - runner=runner, - num_trials=num_trials, - is_noiseless=False, - observe_noise_stds=False, - has_ground_truth=False, - optimal_value=problem.optimal_value, - ) - - -class PyTorchCNNTorchvisionRunner(PyTorchCNNRunner): - """ - A subclass to aid in serialization. This allows us to save only the name of the - dataset and reload it from TorchVision at deserialization time. - """ - - @classmethod - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - def serialize_init_args(cls, obj: Any) -> dict[str, Any]: - pytorch_cnn_runner = checked_cast(PyTorchCNNRunner, obj) - - return {"name": pytorch_cnn_runner.name} - - @classmethod - def deserialize_init_args( - cls, - args: dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, - ) -> dict[str, Any]: - name = args["name"] - - dataset_fn = _REGISTRY[name] - - train_set = dataset_fn( - root="./data", - train=True, - download=True, - transform=transforms.ToTensor(), + # pyre-fixme: Undefined attribute [16]: + # `PyTorchCNNTorchvisionParamBasedProblem` has no attribute + # `train_loader`. 
+ self.train_loader = DataLoader(train_set, num_workers=1) + # pyre-fixme + self.test_loader = DataLoader(test_set, num_workers=1) + + # pyre-fixme[14]: Inconsistent override (super class takes a more general + # type, TParameterization) + def evaluate_true(self, params: Mapping[str, int | float]) -> Tensor: + return train_and_evaluate( + **params, + device=self.device, + train_loader=self.train_loader, + test_loader=self.test_loader, ) - test_set = dataset_fn( - root="./data", - train=False, - download=True, - transform=transforms.ToTensor(), - ) - return {"name": name, "train_set": train_set, "test_set": test_set} +def get_pytorch_cnn_torchvision_benchmark_problem( + name: str, + num_trials: int, +) -> BenchmarkProblem: + base_problem = PyTorchCNNTorchvisionParamBasedProblem(name=name) + + search_space = SearchSpace( + parameters=[ + RangeParameter( + name="lr", parameter_type=ParameterType.FLOAT, lower=1e-6, upper=0.4 + ), + RangeParameter( + name="momentum", + parameter_type=ParameterType.FLOAT, + lower=0, + upper=1, + ), + RangeParameter( + name="weight_decay", + parameter_type=ParameterType.FLOAT, + lower=0, + upper=1, + ), + RangeParameter( + name="step_size", + parameter_type=ParameterType.INT, + lower=1, + upper=100, + ), + RangeParameter( + name="gamma", + parameter_type=ParameterType.FLOAT, + lower=0, + upper=1, + ), + ] + ) + optimization_config, outcome_names = get_soo_config_and_outcome_names( + num_constraints=0, + lower_is_better=False, + observe_noise_sd=False, + objective_name="accuracy", + ) + runner = ParamBasedTestProblemRunner( + test_problem_class=PyTorchCNNTorchvisionParamBasedProblem, + test_problem_kwargs={"name": name}, + outcome_names=outcome_names, + ) + return BenchmarkProblem( + name=f"HPO_PyTorchCNN_Torchvision::{name}", + search_space=search_space, + optimization_config=optimization_config, + num_trials=num_trials, + observe_noise_stds=False, + is_noiseless=True, + has_ground_truth=True, + optimal_value=base_problem.optimal_value, + runner=runner, + ) diff --git a/ax/benchmark/problems/registry.py b/ax/benchmark/problems/registry.py index 10575fd1756..651d2244a03 100644 --- a/ax/benchmark/problems/registry.py +++ b/ax/benchmark/problems/registry.py @@ -15,7 +15,9 @@ create_single_objective_problem_from_botorch, ) from ax.benchmark.problems.hd_embedding import embed_higher_dimension -from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem +from ax.benchmark.problems.hpo.torchvision import ( + get_pytorch_cnn_torchvision_benchmark_problem, +) from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem from botorch.test_functions import synthetic from botorch.test_functions.multi_objective import BraninCurrin @@ -113,14 +115,14 @@ class BenchmarkProblemRegistryEntry: factory_kwargs={"n": 30, "num_trials": 25}, ), "hpo_pytorch_cnn_MNIST": BenchmarkProblemRegistryEntry( - factory_fn=PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name, + factory_fn=get_pytorch_cnn_torchvision_benchmark_problem, factory_kwargs={ "name": "MNIST", "num_trials": 20, }, ), "hpo_pytorch_cnn_FashionMNIST": BenchmarkProblemRegistryEntry( - factory_fn=PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name, + factory_fn=get_pytorch_cnn_torchvision_benchmark_problem, factory_kwargs={ "name": "FashionMNIST", "num_trials": 50, diff --git a/ax/benchmark/tests/problems/hpo/test_torchvision.py b/ax/benchmark/tests/problems/hpo/test_torchvision.py new file mode 100644 index 00000000000..bed29b8da0c --- /dev/null +++ 
b/ax/benchmark/tests/problems/hpo/test_torchvision.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from random import choice +from unittest.mock import MagicMock, patch + +from ax.benchmark.benchmark_problem import BenchmarkProblem + +from ax.benchmark.problems.hpo.torchvision import CNN +from ax.benchmark.problems.registry import get_problem +from ax.core.arm import Arm +from ax.core.trial import Trial +from ax.utils.common.testutils import TestCase +from ax.utils.testing.benchmark_stubs import TestDataset + + +class TestPyTorchCNNTorchvision(TestCase): + def setUp(self) -> None: + self.parameters = { + "lr": 1e-1, + "momentum": 0.5, + "weight_decay": 0.5, + "step_size": 10, + "gamma": 0.5, + } + super().setUp() + + def test_problem_properties(self) -> None: + num_trials = 173 + + with patch.dict( + "ax.benchmark.problems.hpo.torchvision._REGISTRY", + {"MNIST": TestDataset, "FashionMNIST": TestDataset}, + ): + + self.assertEqual( + get_problem(problem_name="hpo_pytorch_cnn_MNIST").name, + "HPO_PyTorchCNN_Torchvision::MNIST", + ) + problem = get_problem( + problem_name="hpo_pytorch_cnn_FashionMNIST", num_trials=num_trials + ) + + self.assertEqual(problem.name, "HPO_PyTorchCNN_Torchvision::FashionMNIST") + self.assertIsInstance(problem, BenchmarkProblem) + self.assertEqual(problem.optimal_value, 1.0) + self.assertSetEqual( + set(problem.search_space.parameters.keys()), + {"lr", "momentum", "weight_decay", "step_size", "gamma"}, + ) + self.assertFalse(problem.optimization_config.objective.minimize) + self.assertEqual(problem.num_trials, num_trials) + self.assertTrue(problem.is_noiseless) + self.assertFalse(problem.observe_noise_stds) + self.assertTrue(problem.has_ground_truth) + + def test_deterministic(self) -> None: + problem_name = choice(["MNIST", "FashionMNIST"]) + with patch.dict( + "ax.benchmark.problems.hpo.torchvision._REGISTRY", + {problem_name: TestDataset}, + ): + problem = get_problem(problem_name=f"hpo_pytorch_cnn_{problem_name}") + # pyre-fixme[6]: complaining that the annotation for Arm.parameters is + # too broad because it's not immutable + arm = Arm(parameters=self.parameters, name="0") + trial = Trial(experiment=MagicMock()).add_arm(arm=arm) + + result = problem.runner.run(trial=trial) + expected = 0.21875 + self.assertEqual( + result, + { + "Ys": {"0": [expected]}, + "Ystds": {"0": [0.0]}, + "Ys_true": {"0": [expected]}, + "outcome_names": ["accuracy"], + }, + ) + + with self.subTest("test caching"): + with patch( + "ax.benchmark.problems.hpo.torchvision.CNN", + wraps=CNN, + ) as mock_CNN: + problem.runner.run(trial=trial) + mock_CNN.assert_not_called() + + other_trial = Trial(experiment=MagicMock()).add_arm( + arm=Arm(parameters={**self.parameters, "lr": 0.9}, name="1") + ) + with patch( + "ax.benchmark.problems.hpo.torchvision.CNN", wraps=CNN + ) as mock_CNN: + problem.runner.run(trial=other_trial) + mock_CNN.assert_called_once() diff --git a/ax/benchmark/tests/problems/test_problem_storage.py b/ax/benchmark/tests/problems/test_problem_storage.py deleted file mode 100644 index 60ececb6512..00000000000 --- a/ax/benchmark/tests/problems/test_problem_storage.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem -from ax.storage.json_store.decoder import object_from_json -from ax.storage.json_store.encoder import object_to_json -from ax.utils.common.testutils import TestCase - - -class TestProblems(TestCase): - def test_torchvision_encode_decode(self) -> None: - original_object = PyTorchCNNTorchvisionBenchmarkProblem.from_dataset_name( - name="MNIST", num_trials=50 - ) - - json_object = object_to_json( - original_object, - ) - converted_object = object_from_json( - json_object, - ) - - self.assertEqual(original_object, converted_object) diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 366c5611122..fb95644c786 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -42,12 +42,13 @@ get_single_objective_benchmark_problem, get_sobol_benchmark_method, get_soo_surrogate, + TestDataset, ) from ax.utils.testing.core_stubs import get_experiment from ax.utils.testing.mock import fast_botorch_optimize from botorch.acquisition.logei import qLogNoisyExpectedImprovement -from botorch.acquisition.multi_objective.monte_carlo import ( - qNoisyExpectedHypervolumeImprovement, +from botorch.acquisition.multi_objective.logei import ( + qLogNoisyExpectedHypervolumeImprovement, ) from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP from botorch.models.gp_regression import SingleTaskGP @@ -322,6 +323,13 @@ def test_replication_sobol_surrogate(self) -> None: @fast_botorch_optimize def test_replication_mbm(self) -> None: + with patch.dict( + "ax.benchmark.problems.hpo.torchvision._REGISTRY", + {"MNIST": TestDataset}, + ): + mnist_problem = get_problem( + problem_name="hpo_pytorch_cnn_MNIST", name="MNIST", num_trials=6 + ) for method, problem, expected_name in [ ( get_sobol_botorch_modular_acquisition( @@ -358,13 +366,13 @@ def test_replication_mbm(self) -> None: ( get_sobol_botorch_modular_acquisition( model_cls=SingleTaskGP, - acquisition_cls=qNoisyExpectedHypervolumeImprovement, + acquisition_cls=qLogNoisyExpectedHypervolumeImprovement, distribute_replications=False, ), get_multi_objective_benchmark_problem( observe_noise_sd=True, num_trials=6 ), - "MBM::SingleTaskGP_qNEHVI", + "MBM::SingleTaskGP_qLogNEHVI", ), ( get_sobol_botorch_modular_acquisition( @@ -375,6 +383,15 @@ def test_replication_mbm(self) -> None: get_multi_objective_benchmark_problem(num_trials=6), "MBM::SAAS_qLogNEI", ), + ( + get_sobol_botorch_modular_acquisition( + model_cls=SingleTaskGP, + acquisition_cls=qLogNoisyExpectedImprovement, + distribute_replications=False, + ), + mnist_problem, + "MBM::SingleTaskGP_qLogNEI", + ), ]: with self.subTest(method=method, problem=problem): res = benchmark_replication(problem=problem, method=method, seed=0) diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 59606294bdd..e751c28adc7 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -17,9 +17,6 @@ import numpy as np import pandas as pd import torch -from ax.benchmark.problems.hpo.torchvision import ( - PyTorchCNNTorchvisionBenchmarkProblem as TorchvisionBenchmarkProblem, -) from ax.core.base_trial import BaseTrial from ax.core.data import Data from ax.core.experiment import Experiment @@ -227,11 +224,6 @@ def object_from_json( decoder_registry=decoder_registry, class_decoder_registry=class_decoder_registry, ) - elif _class == TorchvisionBenchmarkProblem: - return 
TorchvisionBenchmarkProblem.from_dataset_name( - name=object_json["name"], - num_trials=object_json["num_trials"], - ) elif _class in (SurrogateSpec, Surrogate): if "input_transform" in object_json: ( diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 761ddfa61f9..44a0836541c 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -6,12 +6,10 @@ # pyre-strict -import re import warnings from pathlib import Path from typing import Any -from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionBenchmarkProblem from ax.core import Experiment, ObservationFeatures from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment @@ -60,7 +58,6 @@ from ax.storage.transform_registry import TRANSFORM_REGISTRY from ax.utils.common.constants import Keys from ax.utils.common.serialization import serialize_init_args -from ax.utils.common.typeutils import not_none from ax.utils.common.typeutils_torch import torch_type_to_str from botorch.models.transforms.input import ChainedInputTransform, InputTransform from botorch.sampling.base import MCSampler @@ -688,18 +685,6 @@ def winsorization_config_to_dict(config: WinsorizationConfig) -> dict[str, Any]: } -def pytorch_cnn_torchvision_benchmark_problem_to_dict( - problem: PyTorchCNNTorchvisionBenchmarkProblem, -) -> dict[str, Any]: - # unit tests for this in benchmark suite - return { - "__type": problem.__class__.__name__, - "name": not_none(re.compile("(?<=::).*").search(problem.name)).group(), - "num_trials": problem.num_trials, - "observe_noise_stds": problem.observe_noise_stds, - } - - def risk_measure_to_dict( risk_measure: RiskMeasure, ) -> dict[str, Any]: diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index d4672b7237c..48ef289edf7 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -17,11 +17,7 @@ ) from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult from ax.benchmark.metrics.benchmark import BenchmarkMetric, GroundTruthBenchmarkMetric -from ax.benchmark.problems.hpo.pytorch_cnn import PyTorchCNNMetric -from ax.benchmark.problems.hpo.torchvision import ( - PyTorchCNNTorchvisionBenchmarkProblem, - PyTorchCNNTorchvisionRunner, -) +from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem from ax.benchmark.runners.botorch_test import ( BotorchTestProblemRunner, ParamBasedTestProblemRunner, @@ -151,7 +147,6 @@ parameter_distribution_to_dict, pathlib_to_dict, percentile_early_stopping_strategy_to_dict, - pytorch_cnn_torchvision_benchmark_problem_to_dict, range_parameter_to_dict, risk_measure_to_dict, robust_search_space_to_dict, @@ -246,9 +241,6 @@ pathlib.WindowsPath: pathlib_to_dict, pathlib.PurePosixPath: pathlib_to_dict, pathlib.PureWindowsPath: pathlib_to_dict, - PyTorchCNNTorchvisionBenchmarkProblem: pytorch_cnn_torchvision_benchmark_problem_to_dict, # noqa - PyTorchCNNMetric: metric_to_dict, - PyTorchCNNTorchvisionRunner: runner_to_dict, RangeParameter: range_parameter_to_dict, RiskMeasure: risk_measure_to_dict, RobustSearchSpace: robust_search_space_to_dict, @@ -373,9 +365,7 @@ "PurePosixPath": pathlib_from_json, "PureWindowsPath": pathlib_from_json, "PercentileEarlyStoppingStrategy": PercentileEarlyStoppingStrategy, - "PyTorchCNNTorchvisionBenchmarkProblem": PyTorchCNNTorchvisionBenchmarkProblem, - "PyTorchCNNMetric": PyTorchCNNMetric, - "PyTorchCNNTorchvisionRunner": PyTorchCNNTorchvisionRunner, + 
"PyTorchCNNTorchvisionParamBasedProblem": PyTorchCNNTorchvisionParamBasedProblem, "RangeParameter": RangeParameter, "ReductionCriterion": ReductionCriterion, "RiskMeasure": RiskMeasure, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 4e116cf5765..f5886141483 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -13,6 +13,7 @@ import numpy as np import torch +from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem from ax.core.metric import Metric from ax.core.objective import Objective @@ -402,6 +403,16 @@ def __post_init__(self, doesnt_serialize: None) -> None: self.assertEqual(recovered.not_a_field, 1) self.assertEqual(obj, recovered) + def test_EncodeDecode_torchvision_problem(self) -> None: + test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST") + self.assertIsNotNone(test_problem.train_loader) + self.assertIsNotNone(test_problem.test_loader) + as_json = object_to_json(obj=test_problem) + self.assertNotIn("train_loader", as_json) + recovered = object_from_json(as_json) + self.assertIsNotNone(recovered.train_loader) + self.assertEqual(test_problem, recovered) + def test_EncodeDecodeTorchTensor(self) -> None: x = torch.tensor( [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=torch.device("cpu") diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 80d3ac9daa8..e1b390c8a15 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -45,6 +45,8 @@ from botorch.models.gp_regression import SingleTaskGP from botorch.test_functions.multi_objective import BraninCurrin, ConstrainedBraninCurrin from botorch.test_functions.synthetic import Branin +from pyre_extensions import assert_is_instance +from torch.utils.data import Dataset def get_single_objective_benchmark_problem( @@ -240,3 +242,28 @@ def __init__( def evaluate_true(self, params: dict[str, float]) -> torch.Tensor: value = sum(elt**2 for elt in params.values()) return value * torch.ones(self.num_objectives, dtype=torch.double) + + +class TestDataset(Dataset): + def __init__( + self, + root: str = "", + train: bool = True, + download: bool = True, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + transform: Any = None, + ) -> None: + torch.manual_seed(0) + self.data: torch.Tensor = torch.randint( + low=0, high=256, size=(32, 1, 28, 28), dtype=torch.float32 + ) + self.targets: torch.Tensor = torch.randint( + low=0, high=10, size=(32,), dtype=torch.uint8 + ) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]: + target = assert_is_instance(self.targets[idx].item(), int) + return self.data[idx], target