Skip to content

Commit

Permalink
Register hook with the same signature (#275)
Browse files Browse the repository at this point in the history
This PR implements the first version of the `register_hook` utility with a
simple example.
  • Loading branch information
AdrianSosic authored Jun 20, 2024
2 parents b7e6959 + 748c847 commit ec077ac
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ _ `_optional` subpackage for managing optional dependencies
- Acquisition function for active learning: `qNIPV`
- Abstract `ContinuousNonlinearConstraint` class
- `ContinuousCardinalityConstraint` class and corresponding uniform sampling mechanism
- `register_hook` utility enabling user-defined augmentation of arbitrary callables

### Changed
- Passing an `Objective` to `Campaign` is now optional
Expand Down
51 changes: 49 additions & 2 deletions baybe/utils/basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Collection of small basic utilities."""

import functools
import inspect
from collections.abc import Callable, Collection, Iterable, Sequence
from dataclasses import dataclass
from inspect import signature
from typing import Any, TypeVar

from baybe.exceptions import UnidentifiedSubclassError
Expand Down Expand Up @@ -139,7 +140,7 @@ def filter_attributes(
Returns:
A dictionary mapping the matched attribute names to their values.
"""
params = signature(callable_).parameters
params = inspect.signature(callable_).parameters
return {
p: getattr(object, p)
for p in params
Expand Down Expand Up @@ -181,3 +182,49 @@ def find_subclass(base: type, name_or_abbr: str, /):
f"The class name or abbreviation '{name_or_abbr}' does not refer to any "
f"of the subclasses of '{base.__name__}'."
)


def register_hook(target: Callable, hook: Callable) -> Callable:
"""Register a hook for the given target callable.
Args:
target: The callable to which the hook is attached.
hook: The callable to be registered in the given callable.
Returns:
A wrapped callable that replaces the original target callable.
Raises:
TypeError: If the signature of the callable does not match the signature of the
target callable.
"""
target_params = inspect.signature(target, eval_str=True).parameters.values()
hook_params = inspect.signature(hook, eval_str=True).parameters.values()

if len(target_params) != len(hook_params):
raise TypeError(
f"'{target.__name__}' and '{hook.__name__}' have "
f"a different number of parameters."
)

for p1, p2 in zip(target_params, hook_params):
if p1.name != p2.name:
raise TypeError(
f"The parameter names of '{target.__name__}' "
f"and '{hook.__name__}' do not match."
)
if (p1.annotation != p2.annotation) and (
p2.annotation is not inspect.Parameter.empty
):
raise TypeError(
f"The type annotations of '{target.__name__}' "
f"and '{hook.__name__}' do not match."
)

@functools.wraps(target)
def wraps(*args, **kwargs):
result = target(*args, **kwargs)
hook(*args, **kwargs)
return result

return wraps
3 changes: 3 additions & 0 deletions examples/Custom_Hooks/Custom_Hooks_Header.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Custom Hooks

These examples demonstrate how to register and use custom hooks for all callables of package objects.
68 changes: 68 additions & 0 deletions examples/Custom_Hooks/callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
## Registering Custom Hooks

# This example demonstrates the basic mechanics of the
# {func}`register_hook <baybe.utils.basic.register_hook>` utility,
# which lets you hook into any callable of your choice:
# * We define a hook that is compatible with the general
# {meth}`RecommenderProtocol.recommend <baybe.recommenders.base.RecommenderProtocol.recommend>`
# interface,
# * attach it to a recommender,
# * and watch it take action.


### Imports


from baybe.parameters import NumericalDiscreteParameter
from baybe.recommenders import RandomRecommender
from baybe.searchspace import SearchSpace
from baybe.utils.basic import register_hook

### Defining the Hook

# We start by defining a hook that lets us inspect the names of the parameters involved
# in the recommendation process.
# For this purpose, we match its signature to that of
# {meth}`RecommenderProtocol.recommend <baybe.recommenders.base.RecommenderProtocol.recommend>`:
#

# ```{admonition} Signature components
# :class: important
# Note that, if provided, annotations must match **exactly** those of the target signature.
# However, annotations are completely optional
# — only the names/order of the signature parameters and their defaults matter.
#
# For the sake of demonstration, we only provide the annotation for the
# relevant `searchspace` parameter of the hook in the example below.
# ```


def print_parameter_names_hook(
self,
batch_size,
searchspace: SearchSpace,
objective=None,
measurements=None,
):
"""Print the names of the parameters spanning the search space."""
print(f"Search space parameters: {[p.name for p in searchspace.parameters]}")


### Monkeypatching

# Next, we create our recommender and monkeypatch its `recommend` method:

recommender = RandomRecommender()
RandomRecommender.recommend = register_hook(
RandomRecommender.recommend, print_parameter_names_hook
)

### Triggering the Hook

# When we now apply the recommender in a specific context, we immediately see the
# effect of the hook:

temperature = NumericalDiscreteParameter("Temperature", values=[90, 105, 120])
concentration = NumericalDiscreteParameter("Concentration", values=[0.057, 0.1, 0.153])
searchspace = SearchSpace.from_product([temperature, concentration])
recommendation = recommender.recommend(batch_size=3, searchspace=searchspace)
74 changes: 74 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Tests for utilities."""
import math

from contextlib import nullcontext

import numpy as np
import pandas as pd
import pytest
from pytest import param

from baybe.utils.basic import register_hook
from baybe.utils.memory import bytes_to_human_readable
from baybe.utils.numerical import closest_element
from baybe.utils.sampling_algorithms import DiscreteSamplingMethod, sample_numerical_df
Expand All @@ -14,6 +17,30 @@
_CLOSEST = _TARGET + 0.1


def f_plain(arg1, arg2):
pass


def f_reduced_plain(arg1):
pass


def f_annotated(arg1: str, arg2: int):
pass


def f_annotated_one_default(arg1: str, arg2: int = 1):
pass


def f_reversed_annotated(arg2: int, arg1: str):
pass


def f2_plain(arg, arg3):
pass


@pytest.mark.parametrize(
"as_ndarray", [param(False, id="list"), param(True, id="array")]
)
Expand Down Expand Up @@ -62,3 +89,50 @@ def test_discrete_sampling(fraction, method):
assert len(sampled) == len(
sampled.drop_duplicates()
), "Undersized sampling did not return unique points."


@pytest.mark.parametrize(
("target, hook, error"),
[
param(
f_annotated,
f_annotated_one_default,
None,
id="hook_with_defaults",
),
param(
f_annotated_one_default,
f_annotated,
None,
id="target_with_defaults",
),
param(
f_annotated,
f_plain,
None,
id="hook_without_annotations",
),
param(
f_annotated,
f_reversed_annotated,
TypeError,
id="different_order",
),
param(
f_annotated,
f2_plain,
TypeError,
id="different_names",
),
param(
f_annotated,
f_reduced_plain,
TypeError,
id="hook_missing_arguments",
),
],
)
def test_register_hook(target, hook, error):
"""Passing in-/consistent signatures to `register_hook` raises an/no error."""
with pytest.raises(error) if error is not None else nullcontext():
register_hook(target, hook)

0 comments on commit ec077ac

Please sign in to comment.