Make evaluation metadata accept iso formatted dates instead of ints in ms

Summary: Ints in ms don't work with pd.Timestamp, which the models ultimately need (https://fburl.com/code/ci5ji3kr) and which is what gets rendered in the data.

Differential Revision: D42050451

fbshipit-source-id: bd156bafdee39637720bd2705d6138874e7b633f
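
For context, a minimal illustration of the difference (the values below are illustrative and not part of the diff): pd.Timestamp interprets a bare int as nanoseconds since the epoch, so an epoch-millisecond value silently yields a date in 1970, while an ISO-formatted string parses to the intended moment.

import pandas as pd

ms_since_epoch = 1671550000000  # ~2022-12-20 expressed in epoch milliseconds
# A bare int is read as *nanoseconds* since the epoch, producing a bogus 1970 date.
print(pd.Timestamp(ms_since_epoch))         # 1970-01-01 00:27:51.550000
# An ISO-formatted string parses directly to the intended timestamp.
print(pd.Timestamp("2022-12-20T15:26:40"))  # 2022-12-20 15:26:40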
Daniel Cohen authored and facebook-github-bot committed Dec 20, 2022
1 parent 32fbe65 commit 8f980e3
Showing 5 changed files with 42 additions and 20 deletions.
4 changes: 2 additions & 2 deletions ax/core/data.py
@@ -317,8 +317,8 @@ def from_evaluations(
         evaluations: Dict[str, TTrialEvaluation],
         trial_index: int,
         sample_sizes: Optional[Dict[str, int]] = None,
-        start_time: Optional[int] = None,
-        end_time: Optional[int] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
     ) -> Data:
         """
         Convert dict of evaluations to Ax data object.
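
A hedged sketch of calling the updated signature (the arm name, metric name, mean/SEM values, and timestamps below are illustrative, not from the diff):

import pandas as pd
from ax.core.data import Data

# start_time / end_time are now ISO-formatted strings instead of epoch-ms ints.
data = Data.from_evaluations(
    evaluations={"0_0": {"branin": (1.23, 0.05)}},  # arm name -> {metric: (mean, SEM)}
    trial_index=0,
    start_time=pd.Timestamp("2022-12-20T09:00:00").isoformat(),
    end_time=pd.Timestamp("2022-12-20T09:05:00").isoformat(),
)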
4 changes: 2 additions & 2 deletions ax/core/tests/test_data.py
@@ -158,8 +158,8 @@ def testFromEvaluations(self) -> None:
             evaluations={"0_1": {"b": eval1}},
             trial_index=0,
             sample_sizes={"0_1": 2},
-            start_time=current_timestamp_in_millis(),
-            end_time=current_timestamp_in_millis(),
+            start_time=pd.Timestamp.now().isoformat(),
+            end_time=pd.Timestamp.now().isoformat(),
         )
         self.assertEqual(data.df["sem"].isnull()[0], sem is None)
         self.assertEqual(len(data.df), 1)
14 changes: 7 additions & 7 deletions ax/service/ax_client.py
@@ -663,7 +663,7 @@ def update_running_trial_with_intermediate_data(
         self,
         trial_index: int,
         raw_data: TEvaluationOutcome,
-        metadata: Optional[Dict[str, Union[str, int]]] = None,
+        metadata: Optional[Dict[str, str]] = None,
         sample_size: Optional[int] = None,
     ) -> None:
         """
@@ -733,7 +733,7 @@ def complete_trial(
         self,
         trial_index: int,
         raw_data: TEvaluationOutcome,
-        metadata: Optional[Dict[str, Union[str, int]]] = None,
+        metadata: Optional[Dict[str, str]] = None,
         sample_size: Optional[int] = None,
     ) -> None:
         """
@@ -784,7 +784,7 @@ def update_trial_data(
         self,
         trial_index: int,
         raw_data: TEvaluationOutcome,
-        metadata: Optional[Dict[str, Union[str, int]]] = None,
+        metadata: Optional[Dict[str, str]] = None,
         sample_size: Optional[int] = None,
     ) -> None:
         """
@@ -1554,7 +1554,7 @@ def _update_trial_with_raw_data(
         self,
         trial_index: int,
         raw_data: TEvaluationOutcome,
-        metadata: Optional[Dict[str, Union[str, int]]] = None,
+        metadata: Optional[Dict[str, str]] = None,
         sample_size: Optional[int] = None,
         complete_trial: bool = False,
         combine_with_last_data: bool = False,
@@ -1776,7 +1776,7 @@ def _make_evaluations_and_data(
         self,
         trial: BaseTrial,
         raw_data: Union[TEvaluationOutcome, Dict[str, TEvaluationOutcome]],
-        metadata: Optional[Dict[str, Union[str, int]]],
+        metadata: Optional[Dict[str, str]],
         sample_sizes: Optional[Dict[str, int]] = None,
     ) -> Tuple[Dict[str, TEvaluationOutcome], Data]:
         """Formats given raw data as Ax evaluations and `Data`.
@@ -1799,12 +1799,12 @@ def _make_evaluations_and_data(
             trial_index=trial.index,
             sample_sizes=sample_sizes or {},
             start_time=(
-                checked_cast_optional(int, metadata.get("start_time"))
+                checked_cast_optional(str, metadata.get("start_time"))
                 if metadata is not None
                 else None
             ),
             end_time=(
-                checked_cast_optional(int, metadata.get("end_time"))
+                checked_cast_optional(str, metadata.get("end_time"))
                 if metadata is not None
                 else None
            ),
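
A usage sketch of the corresponding AxClient call under the new string-typed metadata contract (the experiment setup is abbreviated and the parameter definition is illustrative, mirroring the test further below):

import pandas as pd
from ax.service.ax_client import AxClient

ax_client = AxClient()
ax_client.create_experiment(
    parameters=[{"name": "x", "type": "range", "bounds": [0.0, 1.0]}],
)
_, trial_index = ax_client.get_next_trial()

start = pd.Timestamp.now().isoformat()
# ... run the evaluation for the suggested parameters ...
ax_client.complete_trial(
    trial_index=trial_index,
    raw_data=1.0,
    metadata={
        "start_time": start,                         # ISO string, not epoch ms
        "end_time": pd.Timestamp.now().isoformat(),
    },
)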
32 changes: 29 additions & 3 deletions ax/service/tests/test_ax_client.py
@@ -9,13 +9,16 @@
 import time
 from itertools import product
 from math import ceil
+from random import random
 from typing import Dict, List, Optional, Set, Tuple, TYPE_CHECKING
 from unittest.mock import patch
 
 import numpy as np
+import pandas as pd
 import torch
 from ax.core.arm import Arm
 from ax.core.base_trial import TrialStatus
+from ax.core.data import Data
 from ax.core.generator_run import GeneratorRun
 from ax.core.metric import Metric
 from ax.core.optimization_config import MultiObjectiveOptimizationConfig
@@ -28,6 +31,7 @@
 )
 from ax.core.parameter_constraint import OrderConstraint
 from ax.core.search_space import HierarchicalSearchSpace
+from ax.core.trial import Trial
 from ax.core.types import ComparisonOp, TModelPredictArm, TParameterization, TParamValue
 from ax.exceptions.core import (
     DataRequiredError,
@@ -39,6 +43,7 @@
 from ax.metrics.branin import branin
 from ax.modelbridge.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM
 from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
+from ax.modelbridge.random import RandomModelBridge
 from ax.modelbridge.registry import Models
 from ax.service.ax_client import AxClient, ObjectiveProperties
 from ax.service.utils.best_point import (
@@ -53,7 +58,6 @@
 from ax.storage.sqa_store.sqa_config import SQAConfig
 from ax.storage.sqa_store.structs import DBSettings
 from ax.utils.common.testutils import TestCase
-from ax.utils.common.timeutils import current_timestamp_in_millis
 from ax.utils.common.typeutils import checked_cast, not_none
 from ax.utils.testing.core_stubs import DummyEarlyStoppingStrategy
 from ax.utils.testing.mock import fast_botorch_optimize
@@ -1555,6 +1559,28 @@ def test_trial_completion(self) -> None:
         self.assertEqual(best_trial_values[0], {"branin": -2.0})
         self.assertTrue(math.isnan(best_trial_values[1]["branin"]["branin"]))
 
+    def test_trial_completion_with_metadata(self) -> None:
+        ax_client = get_branin_optimization()
+        params, idx = ax_client.get_next_trial()
+        ax_client.complete_trial(
+            trial_index=idx,
+            raw_data={"branin": (0, 0.0)},
+            metadata={
+                "start_time": "2020-01-01",
+                "end_time": "2020-01-05 00:00:00",
+            },
+        )
+        with patch.object(
+            RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit
+        ) as mock_fit:
+            ax_client.get_next_trial()
+        mock_fit.assert_called_once()
+        features = mock_fit.call_args_list[0][1]["observations"][0].features
+        # we're asserting it's actually created real Timestamp objects
+        # for the observation features
+        self.assertEqual(features.start_time.day, 1)
+        self.assertEqual(features.end_time.day, 5)
+
     def test_abandon_trial(self) -> None:
         ax_client = get_branin_optimization()
 
@@ -1598,7 +1624,7 @@ def test_ttl_trial(self) -> None:
         self.assertEqual(ax_client.get_best_parameters()[0], params2)
 
     def test_start_and_end_time_in_trial_completion(self) -> None:
-        start_time = current_timestamp_in_millis()
+        start_time = pd.Timestamp.now().isoformat()
         ax_client = AxClient()
         ax_client.create_experiment(
             parameters=[
@@ -1613,7 +1639,7 @@ def test_start_and_end_time_in_trial_completion(self) -> None:
             raw_data=1.0,
             metadata={
                 "start_time": start_time,
-                "end_time": current_timestamp_in_millis(),
+                "end_time": pd.Timestamp.now().isoformat(),
             },
         )
         dat = ax_client.experiment.fetch_data().df
8 changes: 2 additions & 6 deletions ax/service/utils/instantiation.py
@@ -895,8 +895,6 @@ def make_experiment(
     def raw_data_to_evaluation(
         raw_data: TEvaluationOutcome,
         metric_names: List[str],
-        start_time: Optional[int] = None,
-        end_time: Optional[int] = None,
     ) -> TEvaluationOutcome:
         """Format the trial evaluation data to a standard `TTrialEvaluation`
         (mapping from metric names to a tuple of mean and SEM) representation, or
@@ -945,8 +943,8 @@ def data_and_evaluations_from_raw_data(
         metric_names: List[str],
         trial_index: int,
         sample_sizes: Dict[str, int],
-        start_time: Optional[int] = None,
-        end_time: Optional[int] = None,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
     ) -> Tuple[Dict[str, TEvaluationOutcome], Data]:
         """Transforms evaluations into Ax Data.
@@ -969,8 +967,6 @@ def data_and_evaluations_from_raw_data(
             arm_name: cls.raw_data_to_evaluation(
                 raw_data=raw_data[arm_name],
                 metric_names=metric_names,
-                start_time=start_time,
-                end_time=end_time,
             )
             for arm_name in raw_data
         }
