From d39978b53ba602701dae56e050f325763b8e6583 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Wed, 12 Jun 2024 21:21:36 +0200 Subject: [PATCH] Fix validation logic --- baybe/utils/basic.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/baybe/utils/basic.py b/baybe/utils/basic.py index edba69274..cb48d7fd0 100644 --- a/baybe/utils/basic.py +++ b/baybe/utils/basic.py @@ -208,12 +208,18 @@ def register_hook(target: Callable, hook: Callable) -> Callable: ) for p1, p2 in zip(target_params, hook_params): - if p1.name != p2.name or p1.annotation != p2.annotation: - if p2.annotation is not inspect.Parameter.empty: - raise TypeError( - f"The signature of '{hook.__name__}' does not match the " - f"signature of '{target.__name__}'." - ) + if p1.name != p2.name: + raise TypeError( + f"The parameter names of '{target.__name__}' " + f"and '{hook.__name__}' do not match." + ) + if (p1.annotation != p2.annotation) and ( + p2.annotation is not inspect.Parameter.empty + ): + raise TypeError( + f"The type annotations of '{target.__name__}' " + f"and '{hook.__name__}' do not match." + ) @functools.wraps(target) def wraps(*args, **kwargs):