Skip to content

Commit

Permalink
Add evaluate_oracle method to BenchmarkRunner (#2705)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2705

Docstring should make this self-explanatory.

In addition to the increased ease of development, enables reaping a lot of LOC in D61415525

Reviewed By: Balandat

Differential Revision: D61431980
  • Loading branch information
esantorella authored and facebook-github-bot committed Aug 25, 2024
1 parent 003983d commit df42735
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
16 changes: 16 additions & 0 deletions ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ def get_Y_true(self, arm: Arm) -> Tensor:
"""
...

def evaluate_oracle(self, arm: Arm) -> Tensor:
"""
Evaluate oracle metric values at a parameterization. In the base class,
oracle=ground truth.
This method can be customized for more complex setups based on different
notions of what the "oracle" value should be. For example, in a
multi-task or multi-fidelity problem, it might be appropriate to
evaluate at the target task or fidelity. In a simple single-task
single-fidelity problem, this could the ground truth if available or the
"Y" value if the ground truth is not available. With a
preference-learned objective, the values might be true metrics evaluated
at the true utility function (which would be unobserved in reality).
"""
return self.get_Y_true(arm=arm)

@abstractmethod
def get_noise_stds(self) -> Union[None, float, dict[str, float]]:
"""
Expand Down
2 changes: 2 additions & 0 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def test_synthetic_runner(self) -> None:
torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double
)
self.assertTrue(torch.allclose(Y, expected_Y))
oracle = runner.evaluate_oracle(arm=arm)
self.assertTrue(torch.equal(Y, oracle))

with self.subTest(f"test `run()`, {test_description}"):
trial = Mock(spec=Trial)
Expand Down

0 comments on commit df42735

Please sign in to comment.