Skip to content

Commit

Permalink
feat(Optimizer): Add a basic random search
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Apr 24, 2024
1 parent 30610c3 commit 8d79bdb
Showing 1 changed file with 150 additions and 0 deletions.
150 changes: 150 additions & 0 deletions src/amltk/optimization/optimizers/random_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""An optimizer that uses ConfigSpace for random search."""
from __future__ import annotations

from collections.abc import Iterable, Sequence
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Literal, overload
from typing_extensions import override

from amltk.optimization import Metric, Optimizer, Trial
from amltk.pipeline import Node
from amltk.randomness import as_int, randuid
from amltk.store import PathBucket

if TYPE_CHECKING:
from typing_extensions import Self

from ConfigSpace import ConfigurationSpace

from amltk.types import Seed


class RandomSearch(Optimizer[None]):
"""An optimizer that uses ConfigSpace for random search."""

def __init__(
self,
*,
space: ConfigurationSpace,
bucket: PathBucket | None = None,
metrics: Metric | Sequence[Metric],
seed: Seed | None = None,
) -> None:
"""Initialize the optimizer.
Args:
space: The search space to search over.
bucket: The bucket given to trials generated by this optimizer.
metrics: The metrics to optimize. Unused for RandomSearch.
seed: The seed to use for the optimization.
"""
metrics = metrics if isinstance(metrics, Sequence) else [metrics]
super().__init__(metrics=metrics, bucket=bucket)
seed = as_int(seed)
space.seed(seed)
self._counter = 0
self.seed = seed
self.space = space

@override
@classmethod
def create(
cls,
*,
space: ConfigurationSpace | Node,
metrics: Metric | Sequence[Metric],
bucket: PathBucket | str | Path | None = None,
seed: Seed | None = None,
) -> Self:
"""Create a random search optimizer.
Args:
space: The node to optimize
metrics: The metrics to optimize
bucket: The bucket to store the results in
seed: The seed to use for the optimization
"""
seed = as_int(seed)
match bucket:
case None:
bucket = PathBucket(
f"{cls.__name__}-{datetime.now().isoformat()}",
)
case str() | Path():
bucket = PathBucket(bucket)
case bucket:
bucket = bucket # noqa: PLW0127

if isinstance(space, Node):
space = space.search_space(parser=cls.preferred_parser())

return cls(
space=space,
seed=seed,
bucket=bucket,
metrics=metrics,
)

@overload
def ask(self, n: int) -> Iterable[Trial[None]]:
...

@overload
def ask(self, n: None = None) -> Trial[None]:
...

@override
def ask(
self,
n: int | None = None,
) -> Trial[None] | Iterable[Trial[None]]:
"""Ask the optimizer for a new config.
Args:
n: The number of configs to ask for. If `None`, ask for a single config.
Returns:
The trial info for the new config.
"""
if n is None:
configs = [self.space.sample_configuration()]
else:
configs = self.space.sample_configuration(n)

trials: list[Trial[None]] = []
for config in configs:
self._counter += 1
randuid_seed = self.seed + self._counter
unique_name = f"trial-{randuid(4, seed=randuid_seed)}-{self._counter}"
trial: Trial[None] = Trial.create(
name=unique_name,
config=dict(config),
info=None,
seed=self.seed,
bucket=self.bucket / unique_name,
metrics=self.metrics,
)
trials.append(trial)

if n is None:
return trials[0]

return trials

@override
def tell(self, report: Trial.Report[None]) -> None:
"""Tell the optimizer about the result of a trial.
Does nothing for random search.
Args:
report: The report of the trial.
"""

@override
@classmethod
def preferred_parser(cls) -> Literal["configspace"]:
"""The preferred parser for this optimizer."""
return "configspace"

0 comments on commit 8d79bdb

Please sign in to comment.