From df42735b71357fff338829099685da6d5e429769 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Sun, 25 Aug 2024 13:33:35 -0700 Subject: [PATCH] Add `evaluate_oracle` method to `BenchmarkRunner` (#2705) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2705 Docstring should make this self-explanatory. In addition to the increased ease of development, enables reaping a lot of LOC in D61415525 Reviewed By: Balandat Differential Revision: D61431980 --- ax/benchmark/runners/base.py | 16 ++++++++++++++++ .../tests/runners/test_botorch_test_problem.py | 2 ++ 2 files changed, 18 insertions(+) diff --git a/ax/benchmark/runners/base.py b/ax/benchmark/runners/base.py index cfb0eebd6f9..b4fb19eabf4 100644 --- a/ax/benchmark/runners/base.py +++ b/ax/benchmark/runners/base.py @@ -51,6 +51,22 @@ def get_Y_true(self, arm: Arm) -> Tensor: """ ... + def evaluate_oracle(self, arm: Arm) -> Tensor: + """ + Evaluate oracle metric values at a parameterization. In the base class, + oracle=ground truth. + + This method can be customized for more complex setups based on different + notions of what the "oracle" value should be. For example, in a + multi-task or multi-fidelity problem, it might be appropriate to + evaluate at the target task or fidelity. In a simple single-task + single-fidelity problem, this could be the ground truth if available or the + "Y" value if the ground truth is not available. With a + preference-learned objective, the values might be true metrics evaluated + at the true utility function (which would be unobserved in reality). 
+ """ + return self.get_Y_true(arm=arm) + @abstractmethod def get_noise_stds(self) -> Union[None, float, dict[str, float]]: """ diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py index ea929723e3d..f6787e812fa 100644 --- a/ax/benchmark/tests/runners/test_botorch_test_problem.py +++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py @@ -150,6 +150,8 @@ def test_synthetic_runner(self) -> None: torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double ) self.assertTrue(torch.allclose(Y, expected_Y)) + oracle = runner.evaluate_oracle(arm=arm) + self.assertTrue(torch.equal(Y, oracle)) with self.subTest(f"test `run()`, {test_description}"): trial = Mock(spec=Trial)