Skip to content

Commit

Permalink
begin on support for Choice in Optuna parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
berombau committed Oct 29, 2024
1 parent b680838 commit a57a126
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 7 deletions.
31 changes: 30 additions & 1 deletion src/amltk/optimization/optimizers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,36 @@ def ask(
if n is not None:
return (self.ask(n=None) for _ in range(n))

optuna_trial: optuna.Trial = self.study.ask(self.space)
if any("__choice__" in k for k in self.space):
optuna_trial: optuna.Trial = self.study.ask()
# do all __choice__ suggestions with suggest_categorical
workspace = self.space.copy()
delete_other_options = []
for name, distribution in workspace.items():
if "__choice__" in name:
possible_choices = distribution.choices
choice_made = optuna_trial.suggest_categorical(name, choices=possible_choices)
for c in possible_choices:
if c != choice_made:
delete_other_options.append(f"{name}:{c}:")
# filter all parameters given the made choices
filtered_workspace = {k: v for k, v in workspace.items() if (
("__choice__" not in k) and
(not any(c in k for c in delete_other_options))
)}
# do all remaining suggestions with the correct suggest function
for name, distribution in filtered_workspace.items():
match distribution:
case optuna.distributions.CategoricalDistribution(choices=choices):
optuna_trial.suggest_categorical(name, choices=choices)
case optuna.distributions.IntDistribution(low=low, high=high, log=log):
optuna_trial.suggest_int(name, low=low, high=high, log=log)
case optuna.distributions.FloatDistribution(low=low, high=high):
optuna_trial.suggest_float(name, low=low, high=high)
case _:
raise ValueError(f"Unknown distribution: {distribution}")
else:
optuna_trial: optuna.Trial = self.study.ask(self.space)
config = optuna_trial.params
trial_number = optuna_trial.number
unique_name = f"{trial_number=}"
Expand Down
16 changes: 10 additions & 6 deletions src/amltk/pipeline/parsers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,17 @@
)

from amltk._functional import prefix_keys
from amltk.pipeline.components import Choice

if TYPE_CHECKING:
from typing import TypeAlias

from amltk.pipeline import Node

OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution]

PAIR = 2

OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution]


def _convert_hp_to_optuna_distribution(
name: str,
Expand Down Expand Up @@ -196,12 +197,15 @@ def parser(
delim: The delimiter to use for the names of the hyperparameters.
"""
if conditionals:
raise NotImplementedError("Conditionals are not yet supported with Optuna.")

space = prefix_keys(_parse_space(node), prefix=f"{node.name}{delim}")

for child in node.nodes:
children = node.nodes

if isinstance(node, Choice) and any(children):
name = f"{node.name}{delim}__choice__"
space[name] = CategoricalDistribution([child.name for child in children])

for child in children:
subspace = parser(child, flat=flat, conditionals=conditionals, delim=delim)
if not flat:
subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}")
Expand Down
31 changes: 31 additions & 0 deletions tests/optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from amltk.optimization import Metric, Optimizer, Trial
from amltk.pipeline import Component
from amltk.pipeline.components import Choice
from amltk.profiling import Timer

if TYPE_CHECKING:
Expand All @@ -24,6 +25,10 @@ class _A:
pass


class _B:
    """Second empty placeholder type, used as the item of a test component."""


metrics = [
Metric("score_bounded", minimize=False, bounds=(0, 1)),
Metric("score_unbounded", minimize=False),
Expand Down Expand Up @@ -87,6 +92,25 @@ def opt_optuna(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
)


@case
@parametrize("metric", [*metrics, metrics])  # Single obj and multi
def opt_optuna_choice(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
    """Build an Optuna optimizer whose search space is a `Choice` of two components.

    Args:
        metric: A single metric, or the full list of metrics for the
            multi-objective parametrization.
        tmp_path: Temporary directory used as the optimizer's bucket.

    Returns:
        An `OptunaOptimizer` created over the `Choice` pipeline with a fixed seed.

    The case is skipped when Optuna cannot be imported.
    """
    try:
        from amltk.optimization.optimizers.optuna import OptunaOptimizer
    except ImportError:
        pytest.skip("Optuna is not installed")

    first = Component(_A, name="hi1", space={"a": [1, 2, 3]})
    second = Component(_B, name="hi2", space={"b": [4, 5, 6]})
    # Pass the alternatives as separate positional arguments so that each
    # component is its own option of the choice.
    # NOTE(review): a single list argument is presumably coerced into one
    # child node rather than two alternatives -- confirm against `Choice`.
    pipeline = Choice(first, second, name="hi")
    return OptunaOptimizer.create(
        space=pipeline,
        metrics=metric,
        seed=42,
        bucket=tmp_path,
    )


@case
@parametrize("metric", [*metrics]) # Single obj
def opt_neps(metric: Metric, tmp_path: Path) -> NEPSOptimizer:
Expand Down Expand Up @@ -142,3 +166,10 @@ def test_batched_ask_generates_unique_configs(optimizer: Optimizer):
batch = list(optimizer.ask(10))
assert len(batch) == 10
assert all_unique(batch)


@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice")
def test_optuna_choice_output(optimizer: Optimizer):
    """Check that an asked trial's config has a key containing `__choice__`."""
    trial = optimizer.ask()
    config = trial.config
    # On failure the whole config is shown as the assertion message.
    assert any("__choice__" in key for key in config), config
196 changes: 196 additions & 0 deletions tests/pipeline/parsing/test_optuna_parser.py
Original file line number Diff line number Diff line change
@@ -1 +1,197 @@
# TODO: Fill this in
from __future__ import annotations

from dataclasses import dataclass

import pytest
from pytest_cases import case, parametrize_with_cases

from amltk.pipeline import Component, Fixed, Node
from amltk.pipeline.components import Choice

try:
from optuna.distributions import CategoricalDistribution, IntDistribution

from amltk.pipeline.parsers.optuna import OptunaSearchSpace
except ImportError:
pytest.skip("Optuna not installed", allow_module_level=True)


# Readable names for the `flat` argument passed to `search_space`;
# used as the first element of the keys of `Params.expected`.
FLAT = True
NOT_FLAT = False
# Readable names for the `conditionals` argument passed to `search_space`;
# used as the second element of the keys of `Params.expected`.
CONDITIONED = True
NOT_CONDITIONED = False


@dataclass
class Params:
    """A test case for parsing a Node into an Optuna search space.

    `expected` maps each `(flat, conditionals)` pair to the search space the
    parser is expected to produce for `root` under those settings.
    """

    # Root node of the pipeline to parse.
    root: Node
    # Expected search space, keyed by `(flat, conditionals)`.
    expected: dict[tuple[bool, bool], OptunaSearchSpace]


@case
def case_single_frozen() -> Params:
    """A single `Fixed` node; an empty search space is expected in every mode."""
    node = Fixed(object(), name="a")
    empty = OptunaSearchSpace()
    modes = [
        (NOT_FLAT, CONDITIONED),
        (NOT_FLAT, NOT_CONDITIONED),
        (FLAT, CONDITIONED),
        (FLAT, NOT_CONDITIONED),
    ]
    # Every mode shares the same (empty) expected space object.
    return Params(node, dict.fromkeys(modes, empty))  # type: ignore


@case
def case_single_component() -> Params:
    """A single component with one categorical hyperparameter.

    The same space, keyed as `a:hp`, is expected in every mode.
    """
    component = Component(object, name="a", space={"hp": [1, 2, 3]})
    search_space = OptunaSearchSpace({"a:hp": CategoricalDistribution([1, 2, 3])})
    by_mode = {
        (flat, conditioned): search_space
        for flat in (NOT_FLAT, FLAT)
        for conditioned in (CONDITIONED, NOT_CONDITIONED)
    }
    return Params(component, by_mode)  # type: ignore


@case
def case_single_step_two_hp() -> Params:
    """A single component with two categorical hyperparameters.

    Both hyperparameters are expected under the `a:` prefix in every mode.
    """
    component = Component(
        object,
        name="a",
        space={"hp": [1, 2, 3], "hp2": [1, 2, 3]},
    )

    hyperparameters = {
        "a:hp": CategoricalDistribution([1, 2, 3]),
        "a:hp2": CategoricalDistribution([1, 2, 3]),
    }
    search_space = OptunaSearchSpace(hyperparameters)

    modes = (
        (NOT_FLAT, CONDITIONED),
        (NOT_FLAT, NOT_CONDITIONED),
        (FLAT, CONDITIONED),
        (FLAT, NOT_CONDITIONED),
    )
    return Params(component, {mode: search_space for mode in modes})  # type: ignore


@case
def case_single_step_two_hp_different_types() -> Params:
    """A single component mixing a list-valued and a tuple-valued hyperparameter.

    The list `[1, 2, 3]` is expected as a categorical distribution and the
    tuple `(1, 10)` as an integer distribution, identically in every mode.
    """
    component = Component(object, name="a", space={"hp": [1, 2, 3], "hp2": (1, 10)})

    search_space = OptunaSearchSpace(
        {
            "a:hp": CategoricalDistribution([1, 2, 3]),
            "a:hp2": IntDistribution(1, 10),
        },
    )

    by_mode = {}
    for flat in (NOT_FLAT, FLAT):
        for conditioned in (CONDITIONED, NOT_CONDITIONED):
            by_mode[(flat, conditioned)] = search_space
    return Params(component, by_mode)  # type: ignore


# TODO: Testing for with and without conditions does not really make sense here
@case
def case_choice() -> Params:
    """A `Choice` between two components that also has its own hyperparameter.

    The expected space holds the hyperparameters of both children, the
    choice's own `hp3`, and a categorical `choice1:__choice__` entry listing
    the child names. Child keys carry the `choice1:` prefix only when not
    flat; the expectation is the same with and without conditionals.
    """
    choice = Choice(
        Component(object, name="a", space={"hp": [1, 2, 3]}),
        Component(object, name="b", space={"hp2": (1, 10)}),
        name="choice1",
        space={"hp3": (1, 10)},
    )

    def _nested_space() -> OptunaSearchSpace:
        # Expected when not flat: child keys are prefixed with the choice name.
        return OptunaSearchSpace(
            {
                "choice1:a:hp": CategoricalDistribution([1, 2, 3]),
                "choice1:b:hp2": IntDistribution(1, 10),
                "choice1:hp3": IntDistribution(1, 10),
                "choice1:__choice__": CategoricalDistribution(["a", "b"]),
            },
        )

    def _flat_space() -> OptunaSearchSpace:
        # Expected when flat: child keys carry only their own name.
        return OptunaSearchSpace(
            {
                "a:hp": CategoricalDistribution([1, 2, 3]),
                "b:hp2": IntDistribution(1, 10),
                "choice1:hp3": IntDistribution(1, 10),
                "choice1:__choice__": CategoricalDistribution(["a", "b"]),
            },
        )

    # A fresh space object is built for each mode.
    by_mode = {
        (NOT_FLAT, CONDITIONED): _nested_space(),
        (FLAT, CONDITIONED): _flat_space(),
        (NOT_FLAT, NOT_CONDITIONED): _nested_space(),
        (FLAT, NOT_CONDITIONED): _flat_space(),
    }
    return Params(choice, by_mode)  # type: ignore


@parametrize_with_cases("test_case", cases=".")
def test_parsing_pipeline(test_case: Params) -> None:
    """Parse the case's root node in every listed mode and compare to the expectation."""
    root = test_case.root

    for mode, expected in test_case.expected.items():
        flat, conditioned = mode
        parsed_space = root.search_space("optuna", flat=flat, conditionals=conditioned)
        assert parsed_space == expected, (
            f"Failed for {flat=}, {conditioned=}.\n{parsed_space}\n{expected}"
        )


@parametrize_with_cases("test_case", cases=".")
def test_parsing_does_not_mutate_space_of_nodes(test_case: Params) -> None:
    """Compare each node's `space` before parsing with its `space` after each parse."""
    root = test_case.root

    def _spaces_by_path() -> dict:
        # Map the path to each node (as a tuple) to that node's space.
        return {tuple(path): step.space for path, step in root.walk()}

    spaces_before = _spaces_by_path()

    for flat, conditioned in test_case.expected:
        root.search_space("optuna", flat=flat, conditionals=conditioned)
        spaces_after = _spaces_by_path()
        assert spaces_before == spaces_after


@parametrize_with_cases("test_case", cases=".")
def test_parsing_twice_produces_same_space(test_case: Params) -> None:
    """Parse the same root twice per mode and require both results to be equal."""
    root = test_case.root

    for flat, conditioned in test_case.expected:
        options = {"flat": flat, "conditionals": conditioned}
        first = root.search_space("optuna", **options)
        second = root.search_space("optuna", **options)
        assert first == second

0 comments on commit a57a126

Please sign in to comment.