-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(Optimizer): Add a basic random search
- Loading branch information
1 parent
30610c3
commit 8d79bdb
Showing
1 changed file
with
150 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
"""An optimizer that uses ConfigSpace for random search.""" | ||
from __future__ import annotations | ||
|
||
from collections.abc import Iterable, Sequence | ||
from datetime import datetime | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Literal, overload | ||
from typing_extensions import override | ||
|
||
from amltk.optimization import Metric, Optimizer, Trial | ||
from amltk.pipeline import Node | ||
from amltk.randomness import as_int, randuid | ||
from amltk.store import PathBucket | ||
|
||
if TYPE_CHECKING: | ||
from typing_extensions import Self | ||
|
||
from ConfigSpace import ConfigurationSpace | ||
|
||
from amltk.types import Seed | ||
|
||
|
||
class RandomSearch(Optimizer[None]): | ||
"""An optimizer that uses ConfigSpace for random search.""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
space: ConfigurationSpace, | ||
bucket: PathBucket | None = None, | ||
metrics: Metric | Sequence[Metric], | ||
seed: Seed | None = None, | ||
) -> None: | ||
"""Initialize the optimizer. | ||
Args: | ||
space: The search space to search over. | ||
bucket: The bucket given to trials generated by this optimizer. | ||
metrics: The metrics to optimize. Unused for RandomSearch. | ||
seed: The seed to use for the optimization. | ||
""" | ||
metrics = metrics if isinstance(metrics, Sequence) else [metrics] | ||
super().__init__(metrics=metrics, bucket=bucket) | ||
seed = as_int(seed) | ||
space.seed(seed) | ||
self._counter = 0 | ||
self.seed = seed | ||
self.space = space | ||
|
||
@override | ||
@classmethod | ||
def create( | ||
cls, | ||
*, | ||
space: ConfigurationSpace | Node, | ||
metrics: Metric | Sequence[Metric], | ||
bucket: PathBucket | str | Path | None = None, | ||
seed: Seed | None = None, | ||
) -> Self: | ||
"""Create a random search optimizer. | ||
Args: | ||
space: The node to optimize | ||
metrics: The metrics to optimize | ||
bucket: The bucket to store the results in | ||
seed: The seed to use for the optimization | ||
""" | ||
seed = as_int(seed) | ||
match bucket: | ||
case None: | ||
bucket = PathBucket( | ||
f"{cls.__name__}-{datetime.now().isoformat()}", | ||
) | ||
case str() | Path(): | ||
bucket = PathBucket(bucket) | ||
case bucket: | ||
bucket = bucket # noqa: PLW0127 | ||
|
||
if isinstance(space, Node): | ||
space = space.search_space(parser=cls.preferred_parser()) | ||
|
||
return cls( | ||
space=space, | ||
seed=seed, | ||
bucket=bucket, | ||
metrics=metrics, | ||
) | ||
|
||
@overload | ||
def ask(self, n: int) -> Iterable[Trial[None]]: | ||
... | ||
|
||
@overload | ||
def ask(self, n: None = None) -> Trial[None]: | ||
... | ||
|
||
@override | ||
def ask( | ||
self, | ||
n: int | None = None, | ||
) -> Trial[None] | Iterable[Trial[None]]: | ||
"""Ask the optimizer for a new config. | ||
Args: | ||
n: The number of configs to ask for. If `None`, ask for a single config. | ||
Returns: | ||
The trial info for the new config. | ||
""" | ||
if n is None: | ||
configs = [self.space.sample_configuration()] | ||
else: | ||
configs = self.space.sample_configuration(n) | ||
|
||
trials: list[Trial[None]] = [] | ||
for config in configs: | ||
self._counter += 1 | ||
randuid_seed = self.seed + self._counter | ||
unique_name = f"trial-{randuid(4, seed=randuid_seed)}-{self._counter}" | ||
trial: Trial[None] = Trial.create( | ||
name=unique_name, | ||
config=dict(config), | ||
info=None, | ||
seed=self.seed, | ||
bucket=self.bucket / unique_name, | ||
metrics=self.metrics, | ||
) | ||
trials.append(trial) | ||
|
||
if n is None: | ||
return trials[0] | ||
|
||
return trials | ||
|
||
@override | ||
def tell(self, report: Trial.Report[None]) -> None: | ||
"""Tell the optimizer about the result of a trial. | ||
Does nothing for random search. | ||
Args: | ||
report: The report of the trial. | ||
""" | ||
|
||
@override | ||
@classmethod | ||
def preferred_parser(cls) -> Literal["configspace"]: | ||
"""The preferred parser for this optimizer.""" | ||
return "configspace" |