Skip to content

Commit

Permalink
Match functions by parameter names and types
Browse files Browse the repository at this point in the history
  • Loading branch information
LoganEvans committed Mar 24, 2024
1 parent 1ae36a6 commit f1f2226
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 41 deletions.
4 changes: 4 additions & 0 deletions funktools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __getattr__(attr: str) -> typing.Callable:
elif attr == "template":
from ._template import _template
return _template()
elif attr == "template":
from ._template import TemplateException
return TemplateException
else:
raise AttributeError(f"Module 'funktools' has no attribute '{attr}'")

Expand All @@ -33,6 +36,7 @@ def __getattr__(attr: str) -> typing.Callable:
'Register',
'Throttle',
'template',
'TemplateException',
]

def __dir__():
Expand Down
193 changes: 160 additions & 33 deletions funktools/_template.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,184 @@
import inspect
from typing import Any, Callable
import typing

_template_funcs: dict[str, Callable] = {}

class TemplateException(Exception):
pass

def _infer_types(func):
annotations = inspect.get_annotations(func)
args = inspect.getfullargspec(func).args
return tuple([annotations[arg] for arg in args])

_template_funcs: dict[str, typing.Callable] = {}

class _TemplateFunctionBase:
def __init__(self, func, types):
if not isinstance(types, tuple):
types = tuple([types])
self._funcs = {types: func}

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
return self._funcs[types](*args, **kwargs)
class _FuncArgInfo:
def __init__(
self,
func: typing.Callable,
annotations: dict = None,
fullargspec: inspect.FullArgSpec = None,
):
self._func = func
self._annotations = annotations
self._fullargspec = fullargspec
self._arg2type = None

def __getitem__(self, types):
if not isinstance(types, tuple):
types = tuple([types])
return self._funcs[tuple(types)]
def get_func(self):
return self._func

def annotations(self):
if self._annotations is None:
self._annotations = inspect.get_annotations(self.get_func())

return self._annotations

def fullargspec(self):
if self._fullargspec is None:
self._fullargspec = inspect.getfullargspec(self.get_func())

return self._fullargspec

def is_match(self, *args, **kwargs):
matched = set()

argspec = self.fullargspec()
matched_via_default = 0
if argspec.kwonlydefaults:
matched.update(set(argspec.kwonlydefaults.keys()))
matched_via_default = len(matched.difference(set(kwargs.keys())))

for name, val in zip(argspec.args, args):
if want_type := self.annotations().get(name):
if not isinstance(val, want_type):
return False

matched.add(name)

def __setitem__(self, types, func):
if not isinstance(types, tuple):
types = (types,)
self._funcs[types] = func
legal_names = set(argspec.args + argspec.kwonlyargs)
for name, val in kwargs.items():
if name not in legal_names:
return False

def add(self, func):
self[_infer_types(func)] = func
if want_type := self.annotations().get(name):
if not isinstance(val, want_type):
return False

matched.add(name)

if (
len(matched)
== len(argspec.args) + len(argspec.kwonlyargs)
== len(args) + len(kwargs) + matched_via_default
):
return True

return False


def _make_typed_template_function(name: str):
"""Create a _TemplateFunctionBase object with type `name`"""

def __init__(self):
self._types2funcs = {}
self._func_arg_infos = []

def __call__(self, *args, **kwargs):
for func_arg_info in self._func_arg_infos:
if func_arg_info.is_match(*args, **kwargs):
return func_arg_info.get_func()(*args, **kwargs)
raise TemplateException("Cannot find templated function matching signature")

__call__.__annotations__ = {}

def __getitem__(self, types):
try:
return self._types2funcs[types]
except KeyError:
raise TemplateException(
f"Cannot find templated function with types: {types}"
)

__getitem__.__annotations__ = {}

def __setitem__(self, types: typing.Type | tuple, func: typing.Callable):
self._types2funcs[types] = func

annotations = self.__class__.__getitem__.__annotations__

if isinstance(types, tuple):
types = typing.Tuple[types]

annotations["types"] = typing.Union[types, annotations.get("types", types)]

for param, types in inspect.get_annotations(func).items():
annotations[param] = typing.Union[types, annotations.get(param, types)]

annotations["return"] = typing.Callable

self._append_func_info(func)

def add(self, func: typing.Callable):
annotations = inspect.get_annotations(func)
args_annotated = len(annotations)
if "return" in annotations:
args_annotated = args_annotated - 1

types = None
argspec = inspect.getfullargspec(func)
if args_annotated == len(argspec.args):
types = tuple([annotations[arg] for arg in argspec.args])

if len(types) == 1:
types = types[0]

self[types] = func

self._append_func_info(func, annotations, argspec)

def _append_func_info(
self,
func: typing.Callable,
annotations: dict = None,
fullargspec: inspect.FullArgSpec = None,
):
arg_info = _FuncArgInfo(func, annotations, fullargspec)

call_annotations = self.__class__.__call__.__annotations__
all_args = arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs
for arg in all_args:
arg_type = arg_info.annotations().get(arg, typing.Any)
call_annotations[arg] = arg_type | call_annotations.get(arg, arg_type)

self._func_arg_infos.append(arg_info)

TemplateFuncMeta = type(name, (type,), {})
return TemplateFuncMeta(name, (_TemplateFunctionBase,), {'__name__': name})
return TemplateFuncMeta(
name,
(),
{
"__init__": __init__,
"__call__": __call__,
"__getitem__": __getitem__,
"__setitem__": __setitem__,
"__name__": name,
"__annotations__": {},
"add": add,
"_append_func_info": _append_func_info,
},
)()


class _template:
def __call__(self, func):
name = func.__name__

if template_func := _template_funcs.get(name):
template_func[_infer_types(func)] = func
template_func.add(func)
return template_func

typed_template_func = _make_typed_template_function(func.__name__)
template_func = typed_template_func(func, _infer_types(func))
_template_funcs[name] = template_func
return template_func
typed_template_func.add(func)
_template_funcs[name] = typed_template_func

return typed_template_func

def __getitem__(self, types):
def _trampoline(func):
Expand All @@ -60,9 +188,8 @@ def _trampoline(func):
return template_func

typed_template_func = _make_typed_template_function(func.__name__)
template_func = typed_template_func(func, types)
_template_funcs[name] = template_func
return template_func
typed_template_func[types] = func
_template_funcs[name] = typed_template_func
return typed_template_func

return _trampoline

74 changes: 66 additions & 8 deletions test/test_template.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,115 @@
from funktools import template


class Foo: pass
class Bar: pass
class Baz: pass
class Qux: pass
class Fee: pass
class Foo:
pass


class Bar:
pass


class Baz:
pass


class Qux:
pass


class Fee:
pass


@template
def funk():
return "empty"


@template[Foo]
def funk():
return "Foo"


@template[Bar]
def funk():
return "Bar"


@template[Baz]
def funk(baz: Baz):
return "Baz"


@template[Foo, Bar, Baz]
def funk(baz: Baz):
return "Foo, Bar, Baz"


@template
def funk(val: int):
return "int"


def foo_qux():
return "Qux"


funk[Qux] = foo_qux


def foo_fee(fee: Fee):
return "Fee"


funk.add(foo_fee)

def test_calls() -> None:

def test_typed_calls() -> None:
assert funk() == "empty"
assert funk[()]() == "empty"
assert funk[Foo]() == "Foo"
assert funk[Bar]() == "Bar"
assert funk[Baz](Baz()) == "Baz"
assert funk(Baz()) == "Baz"
# TODO(lpe): Should this one eventually allow
# funk[Foo, Bar](Baz()) as an equivalent call?
assert funk[Foo, Bar, Baz](Baz()) == "Foo, Bar, Baz"
assert funk[int](31) == "int"
assert funk(42) == "int"
assert funk[Qux]() == "Qux"
assert funk[Fee](Fee()) == "Fee"
assert funk(Fee()) == "Fee"


@template
def funky(a: int, b) -> str:
return f"funky({a}: int, {b})"


@template
def funky(a: int, c: float) -> str:
return f"funky({a}: int, {c}: float)"


@template
def funky(a: int, b: float, *, c: str, d: tuple = (1,)) -> str:
return f"funky({a}: int, {b}: float, *, {c}: str, {d}: tuple)"


def test_arg_calls() -> None:
assert funky(1, 2.3) == "funky(1: int, 2.3)"
assert funky(1, b=2.3) == "funky(1: int, 2.3)"
assert funky(a=1, b=2.3) == "funky(1: int, 2.3)"
assert funky(b=2.3, a=1) == "funky(1: int, 2.3)"

assert funky(1, c=3.4) == "funky(1: int, 3.4: float)"
assert funky(a=1, c=3.4) == "funky(1: int, 3.4: float)"
assert funky(c=3.4, a=1) == "funky(1: int, 3.4: float)"

assert (
funky(1, 2.3, d=(1, 2), c="abc")
== "funky(1: int, 2.3: float, *, abc: str, (1, 2): tuple)"
)
assert (
funky(a=1, b=2.3, c="abc")
== "funky(1: int, 2.3: float, *, abc: str, (1,): tuple)"
)

0 comments on commit f1f2226

Please sign in to comment.