Skip to content

Commit

Permalink
add pre-commit changes
Browse files Browse the repository at this point in the history
  • Loading branch information
berombau committed Oct 30, 2024
1 parent 7d79816 commit 39e002b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
27 changes: 21 additions & 6 deletions src/amltk/pipeline/parsers/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@

@dataclass
class OptunaSearchSpace:
"""A class to represent an Optuna search space.
Wraps a dictionary of hyperparameters and their Optuna distributions.
"""

distributions: dict[str, BaseDistribution] = field(default_factory=dict)

def __repr__(self) -> str:
Expand All @@ -124,22 +129,29 @@ def __str__(self) -> str:
return str(self.distributions)

@classmethod
def parse(cls, *args, **kwargs) -> OptunaSearchSpace:
def parse(cls, *args: Any, **kwargs: Any) -> OptunaSearchSpace:
"""Parse a Node into an Optuna search space."""
return parser(*args, **kwargs)

def sample_configuration(self) -> dict[str, Any]:
"""Sample a configuration from the search space using a default Optuna Study."""
study = optuna.create_study()
trial = self.get_trial(study)
return trial.params

def get_trial(self, study: Study) -> optuna.Trial:
def get_trial(self, study: optuna.Study) -> optuna.Trial:
"""Get a trial from a given Optuna Study using this search space."""
optuna_trial: optuna.Trial
if any("__choice__" in k for k in self.distributions):
optuna_trial: optuna.Trial = study.ask()
optuna_trial = study.ask()
# do all __choice__ suggestions with suggest_categorical
workspace = self.distributions.copy()
filter_patterns = []
for name, distribution in workspace.items():
if "__choice__" in name:
if "__choice__" in name and isinstance(
distribution,
CategoricalDistribution,
):
possible_choices = distribution.choices
choice_made = optuna_trial.suggest_categorical(
name,
Expand Down Expand Up @@ -178,7 +190,7 @@ def get_trial(self, study: Study) -> optuna.Trial:
case _:
raise ValueError(f"Unknown distribution: {distribution}")
else:
optuna_trial: optuna.Trial = study.ask(self.distributions)
optuna_trial = study.ask(self.distributions)
return optuna_trial


Expand Down Expand Up @@ -274,7 +286,10 @@ def parser(

for child in children:
subspace = parser(
child, flat=flat, conditionals=conditionals, delim=delim
child,
flat=flat,
conditionals=conditionals,
delim=delim,
).distributions
if not flat:
subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}")
Expand Down
3 changes: 1 addition & 2 deletions tests/pipeline/parsing/test_optuna_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@

import pytest
from pytest_cases import case, parametrize_with_cases
from src.amltk.pipeline.components import Split

from amltk.pipeline import Component, Fixed, Node
from amltk.pipeline.components import Choice
from amltk.pipeline.components import Choice, Split

try:
from optuna.distributions import CategoricalDistribution, IntDistribution
Expand Down

0 comments on commit 39e002b

Please sign in to comment.