Skip to content

Commit

Permalink
Make evaluation metadata accept iso formatted dates instead of ints i…
Browse files Browse the repository at this point in the history
…n ms (#1334)

Summary:
Pull Request resolved: #1334

Because ints in ms don't work with pd.Timestamp, which the models ultimately need (https://fburl.com/code/ci5ji3kr) and which render in data

Differential Revision: D42050451

fbshipit-source-id: 1c52446c15d2900f3f5829fe72d05d1ab953210d
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Dec 21, 2022
1 parent 9d0f93e commit b099cb8
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 37 deletions.
23 changes: 18 additions & 5 deletions ax/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
from functools import reduce
from hashlib import md5
from typing import Any, Dict, Iterable, Optional, Set, Type
from typing import Any, Dict, Iterable, Optional, Set, Type, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -317,8 +317,8 @@ def from_evaluations(
evaluations: Dict[str, TTrialEvaluation],
trial_index: int,
sample_sizes: Optional[Dict[str, int]] = None,
start_time: Optional[int] = None,
end_time: Optional[int] = None,
start_time: Optional[Union[int, str]] = None,
end_time: Optional[Union[int, str]] = None,
) -> Data:
"""
Convert dict of evaluations to Ax data object.
Expand All @@ -330,9 +330,13 @@ def from_evaluations(
trial_index: Trial index to which this data belongs.
sample_sizes: Number of samples collected for each arm.
start_time: Optional start time of run of the trial that produced this
data, in milliseconds.
data, in milliseconds or iso format. Milliseconds will be automatically
converted to iso format because iso format automatically works with the
pandas column type `Timestamp`.
end_time: Optional end time of run of the trial that produced this
data, in milliseconds.
data, in milliseconds or iso format. Milliseconds will be automatically
converted to iso format because iso format automatically works with the
pandas column type `Timestamp`.
Returns:
Ax Data object.
Expand All @@ -349,6 +353,11 @@ def from_evaluations(
for metric_name, value in evaluation.items()
]
if start_time is not None or end_time is not None:
if isinstance(start_time, int):
start_time = _ms_epoch_to_isoformat(start_time)
if isinstance(end_time, int):
end_time = _ms_epoch_to_isoformat(end_time)

for record in records:
record.update({"start_time": start_time, "end_time": end_time})
if sample_sizes:
Expand Down Expand Up @@ -442,6 +451,10 @@ def clone_without_metrics(data: Data, excluded_metric_names: Iterable[str]) -> D
)


def _ms_epoch_to_isoformat(epoch: int) -> str:
return pd.Timestamp(epoch, unit="ms").isoformat()


def custom_data_class(
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
Expand Down
30 changes: 25 additions & 5 deletions ax/core/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,41 @@ def testCustomData(self) -> None:
with self.assertRaises(ValueError):
Data(df=pd.DataFrame([data_entry2]))

def testFromEvaluations(self) -> None:
def testFromEvaluationsIsoFormat(self) -> None:
now = pd.Timestamp.now()
day = now.day
for sem in (0.5, None):
eval1 = (3.7, sem) if sem is not None else 3.7
data = Data.from_evaluations(
evaluations={"0_1": {"b": eval1}},
trial_index=0,
sample_sizes={"0_1": 2},
start_time=current_timestamp_in_millis(),
end_time=current_timestamp_in_millis(),
start_time=now.isoformat(),
end_time=now.isoformat(),
)
self.assertEqual(data.df["sem"].isnull()[0], sem is None)
self.assertEqual(len(data.df), 1)
self.assertNotEqual(data, Data(self.df))
self.assertIn("start_time", data.df)
self.assertIn("end_time", data.df)
self.assertEqual(data.df["start_time"][0].day, day)
self.assertEqual(data.df["end_time"][0].day, day)

def testFromEvaluationsMillisecondFormat(self) -> None:
now_ms = current_timestamp_in_millis()
day = pd.Timestamp(now_ms, unit="ms").day
for sem in (0.5, None):
eval1 = (3.7, sem) if sem is not None else 3.7
data = Data.from_evaluations(
evaluations={"0_1": {"b": eval1}},
trial_index=0,
sample_sizes={"0_1": 2},
start_time=now_ms,
end_time=now_ms,
)
self.assertEqual(data.df["sem"].isnull()[0], sem is None)
self.assertEqual(len(data.df), 1)
self.assertNotEqual(data, Data(self.df))
self.assertEqual(data.df["start_time"][0].day, day)
self.assertEqual(data.df["end_time"][0].day, day)

def testFromFidelityEvaluations(self) -> None:
for sem in (0.5, None):
Expand Down
20 changes: 4 additions & 16 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,7 @@
from ax.utils.common.docutils import copy_doc
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import (
checked_cast,
checked_cast_complex,
checked_cast_optional,
not_none,
)
from ax.utils.common.typeutils import checked_cast, checked_cast_complex, not_none
from botorch.utils.sampling import manual_seed

logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -1792,22 +1787,15 @@ def _make_evaluations_and_data(
a batched trial or a 1-arm trial.
"""
raw_data_by_arm = self._raw_data_by_arm(trial=trial, raw_data=raw_data)
metadata = metadata if metadata is not None else {}

evaluations, data = self.data_and_evaluations_from_raw_data(
raw_data=raw_data_by_arm,
metric_names=list(self.metric_names),
trial_index=trial.index,
sample_sizes=sample_sizes or {},
start_time=(
checked_cast_optional(int, metadata.get("start_time"))
if metadata is not None
else None
),
end_time=(
checked_cast_optional(int, metadata.get("end_time"))
if metadata is not None
else None
),
start_time=metadata.get("start_time"),
end_time=metadata.get("end_time"),
)
return evaluations, data

Expand Down
51 changes: 48 additions & 3 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from unittest.mock import patch

import numpy as np
import pandas as pd
import torch
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
Expand All @@ -39,6 +40,7 @@
from ax.metrics.branin import branin
from ax.modelbridge.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.random import RandomModelBridge
from ax.modelbridge.registry import Models
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.best_point import (
Expand All @@ -53,7 +55,6 @@
from ax.storage.sqa_store.sqa_config import SQAConfig
from ax.storage.sqa_store.structs import DBSettings
from ax.utils.common.testutils import TestCase
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.testing.core_stubs import DummyEarlyStoppingStrategy
from ax.utils.testing.mock import fast_botorch_optimize
Expand Down Expand Up @@ -1555,6 +1556,50 @@ def test_trial_completion(self) -> None:
self.assertEqual(best_trial_values[0], {"branin": -2.0})
self.assertTrue(math.isnan(best_trial_values[1]["branin"]["branin"]))

def test_trial_completion_with_metadata_with_iso_times(self) -> None:
ax_client = get_branin_optimization()
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=idx,
raw_data={"branin": (0, 0.0)},
metadata={
"start_time": "2020-01-01",
"end_time": "2020-01-05 00:00:00",
},
)
with patch.object(
RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit
) as mock_fit:
ax_client.get_next_trial()
mock_fit.assert_called_once()
features = mock_fit.call_args_list[0][1]["observations"][0].features
# we're asserting it's actually created real Timestamp objects
# for the observation features
self.assertEqual(features.start_time.day, 1)
self.assertEqual(features.end_time.day, 5)

def test_trial_completion_with_metadata_milisecond_times(self) -> None:
ax_client = get_branin_optimization()
params, idx = ax_client.get_next_trial()
ax_client.complete_trial(
trial_index=idx,
raw_data={"branin": (0, 0.0)},
metadata={
"start_time": int(pd.Timestamp("2020-01-01").timestamp() * 1000),
"end_time": int(pd.Timestamp("2020-01-05").timestamp() * 1000),
},
)
with patch.object(
RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit
) as mock_fit:
ax_client.get_next_trial()
mock_fit.assert_called_once()
features = mock_fit.call_args_list[0][1]["observations"][0].features
# we're asserting it's actually created real Timestamp objects
# for the observation features
self.assertEqual(features.start_time.day, 1)
self.assertEqual(features.end_time.day, 5)

def test_abandon_trial(self) -> None:
ax_client = get_branin_optimization()

Expand Down Expand Up @@ -1598,7 +1643,7 @@ def test_ttl_trial(self) -> None:
self.assertEqual(ax_client.get_best_parameters()[0], params2)

def test_start_and_end_time_in_trial_completion(self) -> None:
start_time = current_timestamp_in_millis()
start_time = pd.Timestamp.now().isoformat()
ax_client = AxClient()
ax_client.create_experiment(
parameters=[
Expand All @@ -1613,7 +1658,7 @@ def test_start_and_end_time_in_trial_completion(self) -> None:
raw_data=1.0,
metadata={
"start_time": start_time,
"end_time": current_timestamp_in_millis(),
"end_time": pd.Timestamp.now().isoformat(),
},
)
dat = ax_client.experiment.fetch_data().df
Expand Down
16 changes: 8 additions & 8 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,8 +895,6 @@ def make_experiment(
def raw_data_to_evaluation(
raw_data: TEvaluationOutcome,
metric_names: List[str],
start_time: Optional[int] = None,
end_time: Optional[int] = None,
) -> TEvaluationOutcome:
"""Format the trial evaluation data to a standard `TTrialEvaluation`
(mapping from metric names to a tuple of mean and SEM) representation, or
Expand Down Expand Up @@ -945,8 +943,8 @@ def data_and_evaluations_from_raw_data(
metric_names: List[str],
trial_index: int,
sample_sizes: Dict[str, int],
start_time: Optional[int] = None,
end_time: Optional[int] = None,
start_time: Optional[Union[int, str]] = None,
end_time: Optional[Union[int, str]] = None,
) -> Tuple[Dict[str, TEvaluationOutcome], Data]:
"""Transforms evaluations into Ax Data.
Expand All @@ -961,16 +959,18 @@ def data_and_evaluations_from_raw_data(
sample_sizes: Number of samples collected for each arm, may be empty
if unavailable.
start_time: Optional start time of run of the trial that produced this
data, in milliseconds.
data, in milliseconds or iso format. Milliseconds will eventually be
converted to iso format because iso format automatically works with the
pandas column type `Timestamp`.
end_time: Optional end time of run of the trial that produced this
data, in milliseconds.
data, in milliseconds or iso format. Milliseconds will eventually be
converted to iso format because iso format automatically works with the
pandas column type `Timestamp`.
"""
evaluations = {
arm_name: cls.raw_data_to_evaluation(
raw_data=raw_data[arm_name],
metric_names=metric_names,
start_time=start_time,
end_time=end_time,
)
for arm_name in raw_data
}
Expand Down

0 comments on commit b099cb8

Please sign in to comment.