Skip to content

Commit

Permalink
feat(Pipeline): Optimize pipelines directly with optimize() (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman authored Jan 26, 2024
1 parent 4198de7 commit bded378
Show file tree
Hide file tree
Showing 14 changed files with 1,108 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/amltk/_richutil/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from amltk._richutil.renderable import RichRenderable
from amltk._richutil.renderers import Function, rich_make_column_selector
from amltk._richutil.util import df_to_table, richify
from amltk._richutil.util import df_to_table, is_jupyter, richify

__all__ = [
"df_to_table",
"richify",
"RichRenderable",
"Function",
"rich_make_column_selector",
"is_jupyter",
]
23 changes: 23 additions & 0 deletions src/amltk/_richutil/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# where rich not being installed.
from __future__ import annotations

import os
from concurrent.futures import ProcessPoolExecutor
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -70,3 +71,25 @@ def df_to_table(
table.add_row(str(index), *[str(cell) for cell in row])

return table


def is_jupyter() -> bool:
"""Return True if running in a Jupyter environment."""
# https://github.com/Textualize/rich/blob/fd981823644ccf50d685ac9c0cfe8e1e56c9dd35/rich/console.py#L518-L535
try:
get_ipython # type: ignore[name-defined] # noqa: B018
except NameError:
return False
ipython = get_ipython() # type: ignore[name-defined] # noqa: F821
shell = ipython.__class__.__name__
if (
"google.colab" in str(ipython.__class__)
or os.getenv("DATABRICKS_RUNTIME_VERSION")
or shell == "ZMQInteractiveShell"
):
return True # Jupyter notebook or qtconsole

if shell == "TerminalInteractiveShell":
return False # Terminal running IPython

return False # Other type (?)
43 changes: 43 additions & 0 deletions src/amltk/_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from typing import Any


def threadpoolctl_heuristic(item_contained_in_node: Any | None) -> bool:
"""Heuristic to determine if we should automatically set threadpoolctl.
This is done by detecting if it's a scikit-learn `BaseEstimator` but this may
be extended in the future.
!!! tip
The reason to have this heuristic is that when running scikit-learn, or any
multithreaded model, in parallel, they will over subscribe to threads. This
causes a significant performance hit as most of the time is spent switching
thread contexts instead of work. This can be particularly bad for HPO where
we are evaluating multiple models in parallel on the same system.
The recommened thread count is 1 per core with no additional information to
act upon.
!!! todo
This is potentially not an issue if running on multiple nodes of some cluster,
as they do not share logical cores and hence do not clash.
Args:
item_contained_in_node: The item with which to base the heuristic on.
Returns:
Whether we should automatically set threadpoolctl.
"""
if item_contained_in_node is None or not isinstance(item_contained_in_node, type):
return False

try:
# NOTE: sklearn depends on threadpoolctl so it will be installed.
from sklearn.base import BaseEstimator

return issubclass(item_contained_in_node, BaseEstimator)
except ImportError:
return False
Empty file added src/amltk/evalutors/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions src/amltk/evalutors/evaluation_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Evaluation protocols for how a trial and a pipeline should be evaluated.
TODO: Sorry
"""
from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING

from amltk.scheduling import Plugin

if TYPE_CHECKING:
from amltk.optimization import Trial
from amltk.pipeline import Node
from amltk.scheduling import Scheduler, Task


class EvaluationProtocol:
"""A protocol for how a trial should be evaluated on a pipeline."""

fn: Callable[[Trial, Node], Trial.Report]

def task(
self,
scheduler: Scheduler,
plugins: Plugin | Iterable[Plugin] | None = None,
) -> Task[[Trial, Node], Trial.Report]:
"""Create a task for this protocol.
Args:
scheduler: The scheduler to use for the task.
plugins: The plugins to use for the task.
Returns:
The created task.
"""
_plugins: tuple[Plugin, ...]
match plugins:
case None:
_plugins = ()
case Plugin():
_plugins = (plugins,)
case Iterable():
_plugins = tuple(plugins)

return scheduler.task(self.fn, plugins=_plugins)


class CustomProtocol(EvaluationProtocol):
"""A custom evaluation protocol based on a user function."""

def __init__(self, fn: Callable[[Trial, Node], Trial.Report]) -> None:
"""Initialize the protocol.
Args:
fn: The function to use for the evaluation.
"""
super().__init__()
self.fn = fn
11 changes: 9 additions & 2 deletions src/amltk/optimization/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def target_function(trial: Trial) -> Trial.Report:
from collections import defaultdict
from collections.abc import Callable, Hashable, Iterable, Iterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal, TypeVar
from typing import TYPE_CHECKING, Literal, TypeVar, overload
from typing_extensions import override

import pandas as pd
Expand Down Expand Up @@ -527,7 +527,14 @@ def sortby(

return sorted(history.reports, key=sort_key, reverse=reverse)

@override
@overload
def __getitem__(self, key: int | str) -> Trial.Report:
...

@overload
def __getitem__(self, key: slice) -> Trial.Report:
...

def __getitem__( # type: ignore
self,
key: int | str | slice,
Expand Down
97 changes: 96 additions & 1 deletion src/amltk/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@
"""
from __future__ import annotations

import logging
from abc import abstractmethod
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from datetime import datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
Generic,
ParamSpec,
Protocol,
TypeVar,
overload,
)
from typing_extensions import Self

from more_itertools import all_unique

Expand All @@ -36,11 +40,14 @@
from amltk.optimization.metric import Metric
from amltk.optimization.trial import Trial
from amltk.pipeline import Node
from amltk.types import Seed

I = TypeVar("I") # noqa: E741
P = ParamSpec("P")
ParserOutput = TypeVar("ParserOutput")

logger = logging.getLogger(__name__)


class Optimizer(Generic[I]):
"""An optimizer protocol.
Expand Down Expand Up @@ -123,3 +130,91 @@ def preferred_parser(
"""
return None

@classmethod
@abstractmethod
def create(
cls,
*,
space: Node,
metrics: Metric | Sequence[Metric],
bucket: str | Path | PathBucket | None = None,
seed: Seed | None = None,
) -> Self:
"""Create this optimizer.
!!! note
Subclasses should override this with more specific configuration
but these arguments should be all that's necessary to create the optimizer.
Args:
space: The space to optimize over.
bucket: The bucket for where to store things related to the trial.
metrics: The metrics to optimize.
seed: The seed to use for the optimizer.
Returns:
The optimizer.
"""

class CreateSignature(Protocol):
"""A Protocol which defines the keywords required to create an
optimizer with deterministic behavior at a desired location.
This protocol matches the `Optimizer.create` classmethod, however we also
allow any function which accepts the keyword arguments to create an
Optimizer.
"""

def __call__(
self,
*,
space: Node,
metrics: Metric | Sequence[Metric],
bucket: PathBucket | None = None,
seed: Seed | None = None,
) -> Optimizer:
"""A function which creates an optimizer for node.optimize should
accept the following keyword arguments.
Args:
space: The node to optimize
metrics: The metrics to optimize
bucket: The bucket to store the results in
seed: The seed to use for the optimization
"""
...

@classmethod
def _get_known_importable_optimizer_classes(cls) -> Iterator[type[Optimizer]]:
"""Get all developer known optimizer classes. This is used for defaults.
Do not rely on this functionality and prefer to give concrete optimizers to
functionality requiring one. This is intended for convenience of particular
quickstart methods.
"""
# NOTE: We can't use the `Optimizer.__subclasses__` method as the optimizers
# are not imported by any other module initially and so they do no exist
# until imported. Hence this manual iteration. For now, we be explicit and
# only if the optimizer list grows should we consider dynamic importing.
try:
from amltk.optimization.optimizers.smac import SMACOptimizer

yield SMACOptimizer
except ImportError as e:
logger.debug("Failed to import SMACOptimizer", exc_info=e)

try:
from amltk.optimization.optimizers.optuna import OptunaOptimizer

yield OptunaOptimizer
except ImportError as e:
logger.debug("Failed to import OptunaOptimizer", exc_info=e)

try:
from amltk.optimization.optimizers.neps import NEPSOptimizer

yield NEPSOptimizer
except ImportError as e:
logger.debug("Failed to import NEPSOptimizer", exc_info=e)
4 changes: 2 additions & 2 deletions src/amltk/optimization/optimizers/neps.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def __init__(
self,
*,
space: SearchSpace,
loss_metric: Metric,
loss_metric: Metric | Sequence[Metric],
cost_metric: Metric | None = None,
optimizer: BaseOptimizer,
working_dir: Path,
Expand Down Expand Up @@ -307,7 +307,7 @@ def create( # noqa: PLR0913
| Mapping[str, ConfigurationSpace | Parameter]
| Node
),
metrics: Metric,
metrics: Metric | Sequence[Metric],
cost_metric: Metric | None = None,
bucket: PathBucket | str | Path | None = None,
searcher: str | BaseOptimizer = "default",
Expand Down
Loading

0 comments on commit bded378

Please sign in to comment.