diff --git a/src/amltk/pipeline/parsers/optuna.py b/src/amltk/pipeline/parsers/optuna.py index 49c28d84..ce31e4a1 100644 --- a/src/amltk/pipeline/parsers/optuna.py +++ b/src/amltk/pipeline/parsers/optuna.py @@ -115,6 +115,11 @@ @dataclass class OptunaSearchSpace: + """A class to represent an Optuna search space. + + Wraps a dictionary of hyperparameters and their Optuna distributions. + """ + distributions: dict[str, BaseDistribution] = field(default_factory=dict) def __repr__(self) -> str: @@ -124,22 +129,29 @@ def __str__(self) -> str: return str(self.distributions) @classmethod - def parse(cls, *args, **kwargs) -> OptunaSearchSpace: + def parse(cls, *args: Any, **kwargs: Any) -> OptunaSearchSpace: + """Parse a Node into an Optuna search space.""" return parser(*args, **kwargs) def sample_configuration(self) -> dict[str, Any]: + """Sample a configuration from the search space using a default Optuna Study.""" study = optuna.create_study() trial = self.get_trial(study) return trial.params - def get_trial(self, study: Study) -> optuna.Trial: + def get_trial(self, study: optuna.Study) -> optuna.Trial: + """Get a trial from a given Optuna Study using this search space.""" + optuna_trial: optuna.Trial if any("__choice__" in k for k in self.distributions): - optuna_trial: optuna.Trial = study.ask() + optuna_trial = study.ask() # do all __choice__ suggestions with suggest_categorical workspace = self.distributions.copy() filter_patterns = [] for name, distribution in workspace.items(): - if "__choice__" in name: + if "__choice__" in name and isinstance( + distribution, + CategoricalDistribution, + ): possible_choices = distribution.choices choice_made = optuna_trial.suggest_categorical( name, @@ -178,7 +190,7 @@ def get_trial(self, study: Study) -> optuna.Trial: case _: raise ValueError(f"Unknown distribution: {distribution}") else: - optuna_trial: optuna.Trial = study.ask(self.distributions) + optuna_trial = study.ask(self.distributions) return optuna_trial @@ -274,7 +286,10 @@ def parser( for child in children: subspace = parser( - child, flat=flat, conditionals=conditionals, delim=delim + child, + flat=flat, + conditionals=conditionals, + delim=delim, ).distributions if not flat: subspace = prefix_keys(subspace, prefix=f"{node.name}{delim}") diff --git a/tests/pipeline/parsing/test_optuna_parser.py b/tests/pipeline/parsing/test_optuna_parser.py index 2b69dc49..789ab549 100644 --- a/tests/pipeline/parsing/test_optuna_parser.py +++ b/tests/pipeline/parsing/test_optuna_parser.py @@ -5,10 +5,9 @@ import pytest from pytest_cases import case, parametrize_with_cases -from src.amltk.pipeline.components import Split from amltk.pipeline import Component, Fixed, Node -from amltk.pipeline.components import Choice +from amltk.pipeline.components import Choice, Split try: from optuna.distributions import CategoricalDistribution, IntDistribution