Skip to content

Commit

Permalink
improve and move Choice logic from optimizer to parser
Browse files Browse the repository at this point in the history
  • Loading branch information
berombau committed Oct 30, 2024
1 parent a57a126 commit 7d79816
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 53 deletions.
33 changes: 2 additions & 31 deletions src/amltk/optimization/optimizers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def add_to_history(_, report: Trial.Report):
Sorry!
""" # noqa: E501

from __future__ import annotations

from collections.abc import Iterable, Sequence
Expand Down Expand Up @@ -291,37 +292,7 @@ def ask(
"""
if n is not None:
return (self.ask(n=None) for _ in range(n))

if any("__choice__" in k for k in self.space):
optuna_trial: optuna.Trial = self.study.ask()
# do all __choice__ suggestions with suggest_categorical
workspace = self.space.copy()
delete_other_options = []
for name, distribution in workspace.items():
if "__choice__" in name:
possible_choices = distribution.choices
choice_made = optuna_trial.suggest_categorical(name, choices=possible_choices)
for c in possible_choices:
if c != choice_made:
delete_other_options.append(f"{name}:{c}:")
# filter all parameters given the made choices
filtered_workspace = {k: v for k, v in workspace.items() if (
("__choice__" not in k) and
(not any(c in k for c in delete_other_options))
)}
# do all remaining suggestions with the correct suggest function
for name, distribution in filtered_workspace.items():
match distribution:
case optuna.distributions.CategoricalDistribution(choices=choices):
optuna_trial.suggest_categorical(name, choices=choices)
case optuna.distributions.IntDistribution(low=low, high=high, log=log):
optuna_trial.suggest_int(name, low=low, high=high, log=log)
case optuna.distributions.FloatDistribution(low=low, high=high):
optuna_trial.suggest_float(name, low=low, high=high)
case _:
raise ValueError(f"Unknown distribution: {distribution}")
else:
optuna_trial: optuna.Trial = self.study.ask(self.space)
optuna_trial = self.space.get_trial(self.study)
config = optuna_trial.params
trial_number = optuna_trial.number
unique_name = f"{trial_number=}"
Expand Down
83 changes: 76 additions & 7 deletions src/amltk/pipeline/parsers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,11 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np
import optuna
from optuna.distributions import (
BaseDistribution,
CategoricalChoiceType,
Expand All @@ -106,13 +108,78 @@
from amltk.pipeline.components import Choice

if TYPE_CHECKING:
from typing import TypeAlias

from amltk.pipeline import Node

PAIR = 2

OptunaSearchSpace: TypeAlias = dict[str, BaseDistribution]

@dataclass
class OptunaSearchSpace:
distributions: dict[str, BaseDistribution] = field(default_factory=dict)

def __repr__(self) -> str:
return f"OptunaSearchSpace({self.distributions})"

def __str__(self) -> str:
return str(self.distributions)

@classmethod
def parse(cls, *args, **kwargs) -> OptunaSearchSpace:
return parser(*args, **kwargs)

def sample_configuration(self) -> dict[str, Any]:
study = optuna.create_study()
trial = self.get_trial(study)
return trial.params

def get_trial(self, study: Study) -> optuna.Trial:
if any("__choice__" in k for k in self.distributions):
optuna_trial: optuna.Trial = study.ask()
# do all __choice__ suggestions with suggest_categorical
workspace = self.distributions.copy()
filter_patterns = []
for name, distribution in workspace.items():
if "__choice__" in name:
possible_choices = distribution.choices
choice_made = optuna_trial.suggest_categorical(
name,
choices=possible_choices,
)
for c in possible_choices:
if c != choice_made:
# deletable options have the name of the unwanted choices
filter_patterns.append(f":{c}:")
# filter all parameters for the unwanted choices
filtered_workspace = {
k: v
for k, v in workspace.items()
if (
("__choice__" not in k)
and (
not any(
filter_pattern in k for filter_pattern in filter_patterns
)
)
)
}
# do all remaining suggestions with the correct suggest function
for name, distribution in filtered_workspace.items():
match distribution:
case CategoricalDistribution(choices=choices):
optuna_trial.suggest_categorical(name, choices=choices)
case IntDistribution(
low=low,
high=high,
log=log,
):
optuna_trial.suggest_int(name, low=low, high=high, log=log)
case FloatDistribution(low=low, high=high):
optuna_trial.suggest_float(name, low=low, high=high)
case _:
raise ValueError(f"Unknown distribution: {distribution}")
else:
optuna_trial: optuna.Trial = study.ask(self.distributions)
return optuna_trial


def _convert_hp_to_optuna_distribution(
Expand Down Expand Up @@ -150,7 +217,7 @@ def _convert_hp_to_optuna_distribution(
raise ValueError(f"Could not parse {name} as a valid Optuna distribution.\n{hp=}")


def _parse_space(node: Node) -> OptunaSearchSpace:
def _parse_space(node: Node) -> dict[str, BaseDistribution]:
match node.space:
case None:
space = {}
Expand Down Expand Up @@ -206,7 +273,9 @@ def parser(
space[name] = CategoricalDistribution([child.name for child in children])

for child in children:
subspace = parser(child, flat=flat, conditionals=conditionals, delim=delim)
subspace = parser(
child, flat=flat, conditionals=conditionals, delim=delim
).distributions
if not flat:
subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}")

Expand All @@ -218,4 +287,4 @@ def parser(
)
space[name] = hp

return space
return OptunaSearchSpace(distributions=space)
20 changes: 18 additions & 2 deletions tests/optimizers/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ def opt_optuna(metric: Metric, tmp_path: Path) -> OptunaOptimizer:

@case
@parametrize("metric", [*metrics, metrics]) # Single obj and multi
def opt_optuna_choice(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
def opt_optuna_choice_hierarchical(metric: Metric, tmp_path: Path) -> OptunaOptimizer:
try:
from amltk.optimization.optimizers.optuna import OptunaOptimizer
except ImportError:
pytest.skip("Optuna is not installed")

c1 = Component(_A, name="hi1", space={"a": [1, 2, 3]})
c2 = Component(_B, name="hi2", space={"b": [4, 5, 6]})
pipeline = Choice([c1, c2], name="hi")
pipeline = Choice(c1, c2, name="hi")
return OptunaOptimizer.create(
space=pipeline,
metrics=metric,
Expand Down Expand Up @@ -173,3 +173,19 @@ def test_optuna_choice_output(optimizer: Optimizer):
trial = optimizer.ask()
keys = list(trial.config.keys())
assert any("__choice__" in k for k in keys), trial.config


@parametrize_with_cases("optimizer", cases=".", prefix="opt_optuna_choice")
def test_optuna_choice_no_params_left(optimizer: Optimizer):
    """Only hyperparameters of the selected option may remain in the config.

    For every ``__choice__`` entry of the asked trial, each other key that
    shares the entry's prefix has to contain the name of the chosen option.
    """
    config = optimizer.ask().config
    plain_keys = [key for key in config if "__choice__" not in key]

    for choice_key, chosen in config.items():
        if "__choice__" not in choice_key:
            continue

        prefix = choice_key.removesuffix("__choice__")
        params_for_choice = [key for key in plain_keys if key.startswith(prefix)]
        # Check that only params for the chosen choice are left
        assert all(chosen in key for key in params_for_choice), params_for_choice
49 changes: 36 additions & 13 deletions tests/pipeline/parsing/test_optuna_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from pytest_cases import case, parametrize_with_cases
from src.amltk.pipeline.components import Split

from amltk.pipeline import Component, Fixed, Node
from amltk.pipeline.components import Choice
Expand Down Expand Up @@ -91,7 +92,6 @@ def case_single_step_two_hp_different_types() -> Params:
return Params(item, expected) # type: ignore


# TODO: Testing for with and without conditions does not really make sense here
@case
def case_choice() -> Params:
item = Choice(
Expand All @@ -103,7 +103,7 @@ def case_choice() -> Params:

expected = {}

# Not flat and with conditions
# Not Flat and without conditions
space = OptunaSearchSpace(
{
"choice1:a:hp": CategoricalDistribution([1, 2, 3]),
Expand All @@ -112,9 +112,9 @@ def case_choice() -> Params:
"choice1:__choice__": CategoricalDistribution(["a", "b"]),
},
)
expected[(NOT_FLAT, CONDITIONED)] = space
expected[(NOT_FLAT, NOT_CONDITIONED)] = space

# Flat and with conditions
# Flat and without conditions
space = OptunaSearchSpace(
{
"a:hp": CategoricalDistribution([1, 2, 3]),
Expand All @@ -123,15 +123,36 @@ def case_choice() -> Params:
"choice1:__choice__": CategoricalDistribution(["a", "b"]),
},
)
expected[(FLAT, CONDITIONED)] = space
expected[(FLAT, NOT_CONDITIONED)] = space
return Params(item, expected) # type: ignore

# Not Flat and without conditions

@case
def case_nested_choices_with_split_and_choice() -> Params:
item = Choice(
Split(
Choice(
Component(object, name="a", space={"hp": [1, 2, 3]}),
Component(object, name="b", space={"hp2": (1, 10)}),
name="choice3",
),
Component(object, name="c", space={"hp3": (1, 10)}),
name="split2",
),
Component(object, name="d", space={"hp4": (1, 10)}),
name="choice1",
)
expected = {}

# Not flat and without conditions
space = OptunaSearchSpace(
{
"choice1:a:hp": CategoricalDistribution([1, 2, 3]),
"choice1:b:hp2": IntDistribution(1, 10),
"choice1:hp3": IntDistribution(1, 10),
"choice1:__choice__": CategoricalDistribution(["a", "b"]),
"choice1:split2:choice3:a:hp": CategoricalDistribution([1, 2, 3]),
"choice1:split2:choice3:b:hp2": IntDistribution(1, 10),
"choice1:split2:c:hp3": IntDistribution(1, 10),
"choice1:d:hp4": IntDistribution(1, 10),
"choice1:__choice__": CategoricalDistribution(["d", "split2"]),
"choice1:split2:choice3:__choice__": CategoricalDistribution(["a", "b"]),
},
)
expected[(NOT_FLAT, NOT_CONDITIONED)] = space
Expand All @@ -141,12 +162,14 @@ def case_choice() -> Params:
{
"a:hp": CategoricalDistribution([1, 2, 3]),
"b:hp2": IntDistribution(1, 10),
"choice1:hp3": IntDistribution(1, 10),
"choice1:__choice__": CategoricalDistribution(["a", "b"]),
"c:hp3": IntDistribution(1, 10),
"d:hp4": IntDistribution(1, 10),
"choice1:__choice__": CategoricalDistribution(["d", "split2"]),
"choice3:__choice__": CategoricalDistribution(["a", "b"]),
},
)
expected[(FLAT, NOT_CONDITIONED)] = space
return Params(item, expected) # type: ignore
return Params(item, expected)


@parametrize_with_cases("test_case", cases=".")
Expand Down

0 comments on commit 7d79816

Please sign in to comment.