Skip to content

Commit

Permalink
Avoid errors in telemetry due to node-based GenerationStrategy (#2554)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2554

The Ax telemetry module currently assumes that `GenerationStrategy` objects are all step-based. This diff a warning and returns a dummy record for those cases to avoid errors with node-based GSs.

The proper fix here is to fully support node-based GSs in telemetry, but that will be a bit more work it seems.

Reviewed By: saitcakmak

Differential Revision: D59193222
  • Loading branch information
Balandat authored and facebook-github-bot committed Jul 1, 2024
1 parent ff1445a commit 0a32d5b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ax/telemetry/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AxClientCreatedRecord:

# Dimensionality of transformed SearchSpace can often be much higher due to one-hot
# encoding of unordered ChoiceParameters
transformed_dimensionality: int
transformed_dimensionality: Optional[int]

@classmethod
def from_ax_client(cls, ax_client: AxClient) -> AxClientCreatedRecord:
Expand Down
18 changes: 13 additions & 5 deletions ax/telemetry/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@

# pyre-strict

import warnings
from datetime import datetime
from typing import Any, Dict, List, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type

from ax.core.experiment import Experiment

from ax.exceptions.core import AxWarning
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy

from ax.modelbridge.modelbridge_utils import (
extract_search_space_digest,
transform_search_space,
)

from ax.modelbridge.registry import ModelRegistryBase, Models, SearchSpace
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.cast import Cast
Expand All @@ -32,11 +31,20 @@

def _get_max_transformed_dimensionality(
search_space: SearchSpace, generation_strategy: GenerationStrategy
) -> int:
) -> Optional[int]:
"""
Get dimensionality of transformed SearchSpace for all steps in the
GenerationStrategy and return the maximum.
"""
if generation_strategy.is_node_based:
warnings.warn(
"`_get_max_transformed_dimensionality` does not fully support node-based "
"generation strategies. This will result in an incomplete record.",
category=AxWarning,
stacklevel=4,
)
# TODO [T192965545]: Support node-based generation strategies in telemetry
return None

transforms_by_step = [
_extract_transforms_and_configs(step=step)
Expand Down
29 changes: 25 additions & 4 deletions ax/telemetry/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass
from math import inf
from typing import Optional

from ax.exceptions.core import AxWarning
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.telemetry.common import INITIALIZATION_MODELS, OTHER_MODELS

Expand All @@ -26,17 +29,35 @@ class GenerationStrategyCreatedRecord:
generation_strategy_name: str

# -1 indicates unlimited trials requested, 0 indicates no trials requested
num_requested_initialization_trials: int # Typically the number of Sobol trials
num_requested_bayesopt_trials: int
num_requested_other_trials: int
num_requested_initialization_trials: Optional[
int # Typically the number of Sobol trials
]
num_requested_bayesopt_trials: Optional[int]
num_requested_other_trials: Optional[int]

# Minimum `max_parallelism` across GenerationSteps, i.e. the bottleneck
max_parallelism: int
max_parallelism: Optional[int]

@classmethod
def from_generation_strategy(
cls, generation_strategy: GenerationStrategy
) -> GenerationStrategyCreatedRecord:
if generation_strategy.is_node_based:
warnings.warn(
"`GenerationStrategyCreatedRecord` does not fully support node-based "
"generation strategies. This will result in an incomplete record.",
category=AxWarning,
stacklevel=4,
)
# TODO [T192965545]: Support node-based generation strategies in telemetry
return cls(
generation_strategy_name=generation_strategy.name,
num_requested_initialization_trials=None,
num_requested_bayesopt_trials=None,
num_requested_other_trials=None,
max_parallelism=None,
)

# Minimum `max_parallelism` across GenerationSteps, i.e. the bottleneck
true_max_parallelism = min(
step.max_parallelism or inf for step in generation_strategy._steps
Expand Down
2 changes: 1 addition & 1 deletion ax/telemetry/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class SchedulerCreatedRecord:

# Dimensionality of transformed SearchSpace can often be much higher due to one-hot
# encoding of unordered ChoiceParameters
transformed_dimensionality: int
transformed_dimensionality: Optional[int]

@classmethod
def from_scheduler(cls, scheduler: Scheduler) -> SchedulerCreatedRecord:
Expand Down
24 changes: 23 additions & 1 deletion ax/telemetry/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@

# pyre-strict

from ax.exceptions.core import AxWarning
from ax.telemetry.generation_strategy import GenerationStrategyCreatedRecord
from ax.utils.common.testutils import TestCase
from ax.utils.testing.modeling_stubs import get_generation_strategy
from ax.utils.testing.modeling_stubs import (
get_generation_strategy,
sobol_gpei_generation_node_gs,
)


class TestGenerationStrategy(TestCase):
Expand All @@ -25,3 +29,21 @@ def test_generation_strategy_created_record_from_generation_strategy(self) -> No
max_parallelism=3,
)
self.assertEqual(record, expected)

def test_generation_strategy_created_record_node_based(self) -> None:
gs = sobol_gpei_generation_node_gs()
with self.assertWarnsRegex(
AxWarning,
"`GenerationStrategyCreatedRecord` does not fully support node-based*",
):
record = GenerationStrategyCreatedRecord.from_generation_strategy(
generation_strategy=gs
)
expected = GenerationStrategyCreatedRecord(
generation_strategy_name="Sobol+GPEI_Nodes",
num_requested_initialization_trials=None,
num_requested_bayesopt_trials=None,
num_requested_other_trials=None,
max_parallelism=None,
)
self.assertEqual(record, expected)

0 comments on commit 0a32d5b

Please sign in to comment.