Consolidate BenchmarkMetric functionality in one file
Summary:
There is no longer any need for multiple files here now that much of the `BenchmarkMetric` functionality has been removed.

D61736027 follows up by moving `benchmark/metrics/benchmark.py` to `benchmark/benchmark_metric.py` and moving the corresponding test file.
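
As a hypothetical sketch of what that follow-up move means for callers (module paths are taken from the summary above; the "after" path is not part of this diff):

```python
# Sketch only: how the import path changes with the follow-up move described
# above. The "after" path is hypothetical until D61736027 lands.

# Before (current location, consolidated in this diff):
from ax.benchmark.metrics.benchmark import BenchmarkMetric  # noqa: F401

# After D61736027:
# from ax.benchmark.benchmark_metric import BenchmarkMetric
```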

Reviewed By: Balandat

Differential Revision: D61432000
esantorella authored and facebook-github-bot committed Aug 24, 2024
1 parent 0099295 commit f08058d
Showing 2 changed files with 56 additions and 91 deletions.
66 changes: 56 additions & 10 deletions ax/benchmark/metrics/benchmark.py
@@ -5,13 +5,14 @@

# pyre-strict

from __future__ import annotations

from typing import Any, Optional

from ax.benchmark.metrics.utils import _fetch_trial_data
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.metric import Metric, MetricFetchResult

from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.utils.common.result import Err, Ok


class BenchmarkMetric(Metric):
@@ -48,14 +49,59 @@ def __init__(
self.outcome_index = outcome_index

def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
"""
Args:
trial: The trial from which to fetch data.
kwargs: Unsupported and will raise an exception.
Returns:
A MetricFetchResult containing the data for the requested metric.
"""
if len(kwargs) > 0:
raise NotImplementedError(
f"Arguments {set(kwargs)} are not supported in "
f"{self.__class__.__name__}.fetch_trial_data."
)
return _fetch_trial_data(
trial=trial,
metric_name=self.name,
outcome_index=self.outcome_index,
include_noise_sd=self.observe_noise_sd,
)
outcome_index = self.outcome_index
if outcome_index is None:
# Look up the index based on the outcome name under which we track the data
# as part of `run_metadata`.
outcome_names = trial.run_metadata.get("outcome_names")
if outcome_names is None:
raise RuntimeError(
"Trials' `run_metadata` must contain `outcome_names` if "
"no `outcome_index` is provided."
)
outcome_index = outcome_names.index(self.name)

try:
arm_names = list(trial.arms_by_name.keys())
all_Ys = trial.run_metadata["Ys"]
Ys = [all_Ys[arm_name][outcome_index] for arm_name in arm_names]

if self.observe_noise_sd:
stdvs = [
trial.run_metadata["Ystds"][arm_name][outcome_index]
for arm_name in arm_names
]
else:
stdvs = [float("nan")] * len(Ys)

df = pd.DataFrame(
{
"arm_name": arm_names,
"metric_name": self.name,
"mean": Ys,
"sem": stdvs,
"trial_index": trial.index,
}
)
return Ok(value=Data(df=df))

except Exception as e:
return Err(
MetricFetchE(
message=f"Failed to obtain data for trial {trial.index}",
exception=e,
)
)
81 changes: 0 additions & 81 deletions ax/benchmark/metrics/utils.py

This file was deleted.
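
Not part of the commit, but as a reading aid for the inlined `fetch_trial_data` logic above: a minimal sketch of the `run_metadata` layout it reads. The arm name `"0_0"` and metric names `"f1"`/`"f2"` are made-up placeholders.

```python
# Hypothetical data, illustrating the run_metadata fields read by the inlined
# fetch_trial_data above: "outcome_names" (used when outcome_index is None),
# "Ys" (arm name -> one mean per outcome), and "Ystds" (only read when
# observe_noise_sd is True).
run_metadata = {
    "outcome_names": ["f1", "f2"],
    "Ys": {"0_0": [0.12, 1.5]},
    "Ystds": {"0_0": [0.01, 0.02]},
}

outcome_index = run_metadata["outcome_names"].index("f1")  # 0
mean = run_metadata["Ys"]["0_0"][outcome_index]            # 0.12
sem = run_metadata["Ystds"]["0_0"][outcome_index]          # 0.01
```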
