From f1f2226afddd48c3b6baa9531f11dc3a88669be0 Mon Sep 17 00:00:00 2001 From: Logan Evans Date: Sun, 24 Mar 2024 00:19:55 -0700 Subject: [PATCH] Match functions by parameter names and types --- funktools/__init__.py | 4 + funktools/_template.py | 193 ++++++++++++++++++++++++++++++++++------- test/test_template.py | 74 ++++++++++++++-- 3 files changed, 230 insertions(+), 41 deletions(-) diff --git a/funktools/__init__.py b/funktools/__init__.py index 7d3b2f3..12fcc41 100644 --- a/funktools/__init__.py +++ b/funktools/__init__.py @@ -22,6 +22,9 @@ def __getattr__(attr: str) -> typing.Callable: elif attr == "template": from ._template import _template return _template() + elif attr == "template": + from ._template import TemplateException + return TemplateException else: raise AttributeError(f"Module 'funktools' has no attribute '{attr}'") @@ -33,6 +36,7 @@ def __getattr__(attr: str) -> typing.Callable: 'Register', 'Throttle', 'template', + 'TemplateException', ] def __dir__(): diff --git a/funktools/_template.py b/funktools/_template.py index 5e67d4b..6ffb902 100644 --- a/funktools/_template.py +++ b/funktools/_template.py @@ -1,56 +1,184 @@ import inspect -from typing import Any, Callable +import typing -_template_funcs: dict[str, Callable] = {} +class TemplateException(Exception): + pass -def _infer_types(func): - annotations = inspect.get_annotations(func) - args = inspect.getfullargspec(func).args - return tuple([annotations[arg] for arg in args]) +_template_funcs: dict[str, typing.Callable] = {} -class _TemplateFunctionBase: - def __init__(self, func, types): - if not isinstance(types, tuple): - types = tuple([types]) - self._funcs = {types: func} - def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) - return self._funcs[types](*args, **kwargs) +class _FuncArgInfo: + def __init__( + self, + func: typing.Callable, + annotations: dict = None, + fullargspec: inspect.FullArgSpec = None, + ): + self._func = func + self._annotations = annotations + self._fullargspec = fullargspec + self._arg2type = None - def __getitem__(self, types): - if not isinstance(types, tuple): - types = tuple([types]) - return self._funcs[tuple(types)] + def get_func(self): + return self._func + + def annotations(self): + if self._annotations is None: + self._annotations = inspect.get_annotations(self.get_func()) + + return self._annotations + + def fullargspec(self): + if self._fullargspec is None: + self._fullargspec = inspect.getfullargspec(self.get_func()) + + return self._fullargspec + + def is_match(self, *args, **kwargs): + matched = set() + + argspec = self.fullargspec() + matched_via_default = 0 + if argspec.kwonlydefaults: + matched.update(set(argspec.kwonlydefaults.keys())) + matched_via_default = len(matched.difference(set(kwargs.keys()))) + + for name, val in zip(argspec.args, args): + if want_type := self.annotations().get(name): + if not isinstance(val, want_type): + return False + + matched.add(name) - def __setitem__(self, types, func): - if not isinstance(types, tuple): - types = (types,) - self._funcs[types] = func + legal_names = set(argspec.args + argspec.kwonlyargs) + for name, val in kwargs.items(): + if name not in legal_names: + return False - def add(self, func): - self[_infer_types(func)] = func + if want_type := self.annotations().get(name): + if not isinstance(val, want_type): + return False + + matched.add(name) + + if ( + len(matched) + == len(argspec.args) + len(argspec.kwonlyargs) + == len(args) + len(kwargs) + matched_via_default + ): + return True + + return False def _make_typed_template_function(name: str): """Create a _TemplateFunctionBase object with type `name`""" + + def __init__(self): + self._types2funcs = {} + self._func_arg_infos = [] + + def __call__(self, *args, **kwargs): + for func_arg_info in self._func_arg_infos: + if func_arg_info.is_match(*args, **kwargs): + return func_arg_info.get_func()(*args, **kwargs) + raise TemplateException("Cannot find templated function matching signature") + + __call__.__annotations__ = {} + + def __getitem__(self, types): + try: + return self._types2funcs[types] + except KeyError: + raise TemplateException( + f"Cannot find templated function with types: {types}" + ) + + __getitem__.__annotations__ = {} + + def __setitem__(self, types: typing.Type | tuple, func: typing.Callable): + self._types2funcs[types] = func + + annotations = self.__class__.__getitem__.__annotations__ + + if isinstance(types, tuple): + types = typing.Tuple[types] + + annotations["types"] = typing.Union[types, annotations.get("types", types)] + + for param, types in inspect.get_annotations(func).items(): + annotations[param] = typing.Union[types, annotations.get(param, types)] + + annotations["return"] = typing.Callable + + self._append_func_info(func) + + def add(self, func: typing.Callable): + annotations = inspect.get_annotations(func) + args_annotated = len(annotations) + if "return" in annotations: + args_annotated = args_annotated - 1 + + types = None + argspec = inspect.getfullargspec(func) + if args_annotated == len(argspec.args): + types = tuple([annotations[arg] for arg in argspec.args]) + + if len(types) == 1: + types = types[0] + + self[types] = func + + self._append_func_info(func, annotations, argspec) + + def _append_func_info( + self, + func: typing.Callable, + annotations: dict = None, + fullargspec: inspect.FullArgSpec = None, + ): + arg_info = _FuncArgInfo(func, annotations, fullargspec) + + call_annotations = self.__class__.__call__.__annotations__ + all_args = arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs + for arg in all_args: + arg_type = arg_info.annotations().get(arg, typing.Any) + call_annotations[arg] = arg_type | call_annotations.get(arg, arg_type) + + self._func_arg_infos.append(arg_info) + TemplateFuncMeta = type(name, (type,), {}) - return TemplateFuncMeta(name, (_TemplateFunctionBase,), {'__name__': name}) + return TemplateFuncMeta( + name, + (), + { + "__init__": __init__, + "__call__": __call__, + "__getitem__": __getitem__, + "__setitem__": __setitem__, + "__name__": name, + "__annotations__": {}, + "add": add, + "_append_func_info": _append_func_info, + }, + )() class _template: def __call__(self, func): name = func.__name__ + if template_func := _template_funcs.get(name): - template_func[_infer_types(func)] = func + template_func.add(func) return template_func typed_template_func = _make_typed_template_function(func.__name__) - template_func = typed_template_func(func, _infer_types(func)) - _template_funcs[name] = template_func - return template_func + typed_template_func.add(func) + _template_funcs[name] = typed_template_func + + return typed_template_func def __getitem__(self, types): def _trampoline(func): @@ -60,9 +188,8 @@ def _trampoline(func): return template_func typed_template_func = _make_typed_template_function(func.__name__) - template_func = typed_template_func(func, types) - _template_funcs[name] = template_func - return template_func + typed_template_func[types] = func + _template_funcs[name] = typed_template_func + return typed_template_func return _trampoline - diff --git a/test/test_template.py b/test/test_template.py index b993b6e..850ebbc 100644 --- a/test/test_template.py +++ b/test/test_template.py @@ -1,53 +1,77 @@ from funktools import template -class Foo: pass -class Bar: pass -class Baz: pass -class Qux: pass -class Fee: pass +class Foo: + pass + + +class Bar: + pass + + +class Baz: + pass + + +class Qux: + pass + + +class Fee: + pass + @template def funk(): return "empty" + @template[Foo] def funk(): return "Foo" + @template[Bar] def funk(): return "Bar" + @template[Baz] def funk(baz: Baz): return "Baz" + @template[Foo, Bar, Baz] def funk(baz: Baz): return "Foo, Bar, Baz" + @template def funk(val: int): return "int" + def foo_qux(): return "Qux" + + funk[Qux] = foo_qux + def foo_fee(fee: Fee): return "Fee" + + funk.add(foo_fee) -def test_calls() -> None: + +def test_typed_calls() -> None: assert funk() == "empty" assert funk[()]() == "empty" assert funk[Foo]() == "Foo" assert funk[Bar]() == "Bar" assert funk[Baz](Baz()) == "Baz" assert funk(Baz()) == "Baz" - # TODO(lpe): Should this one eventually allow - # funk[Foo, Bar](Baz()) as an equivalent call? assert funk[Foo, Bar, Baz](Baz()) == "Foo, Bar, Baz" assert funk[int](31) == "int" assert funk(42) == "int" @@ -55,3 +79,37 @@ def test_calls() -> None: assert funk[Fee](Fee()) == "Fee" assert funk(Fee()) == "Fee" + +@template +def funky(a: int, b) -> str: + return f"funky({a}: int, {b})" + + +@template +def funky(a: int, c: float) -> str: + return f"funky({a}: int, {c}: float)" + + +@template +def funky(a: int, b: float, *, c: str, d: tuple = (1,)) -> str: + return f"funky({a}: int, {b}: float, *, {c}: str, {d}: tuple)" + + +def test_arg_calls() -> None: + assert funky(1, 2.3) == "funky(1: int, 2.3)" + assert funky(1, b=2.3) == "funky(1: int, 2.3)" + assert funky(a=1, b=2.3) == "funky(1: int, 2.3)" + assert funky(b=2.3, a=1) == "funky(1: int, 2.3)" + + assert funky(1, c=3.4) == "funky(1: int, 3.4: float)" + assert funky(a=1, c=3.4) == "funky(1: int, 3.4: float)" + assert funky(c=3.4, a=1) == "funky(1: int, 3.4: float)" + + assert ( + funky(1, 2.3, d=(1, 2), c="abc") + == "funky(1: int, 2.3: float, *, abc: str, (1, 2): tuple)" + ) + assert ( + funky(a=1, b=2.3, c="abc") + == "funky(1: int, 2.3: float, *, abc: str, (1,): tuple)" + )