Skip to content

Commit

Permalink
Match functions by parameter names and types
Browse files Browse the repository at this point in the history
  • Loading branch information
LoganEvans committed Mar 24, 2024
1 parent 1ae36a6 commit e4cca03
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 61 deletions.
36 changes: 17 additions & 19 deletions examples/templates/src/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,31 +52,29 @@ PYBIND11_MODULE(example, m) {
m.def("_get_double", &get<double>);
m.def("_get_Foo", &get<Foo>);

auto templateFunc =
pybind11::module_::import("funktools").attr("template").attr("Function");
auto getFunc = pybind11::module_::import("funktools")
.attr("template")
.attr("register_name")("get");

m.attr("get") = getFunc;

// Due to https://github.com/pybind/pybind11/issues/2486, basic types aren't
// automatically converted with `pybind11::type::of`.
m.attr("get") =
templateFunc("get", m.attr("_get_int"),
pybind11::make_tuple(pybind11::type::of(pybind11::int_())));
m.attr("get") = templateFunc(
"get", m.attr("_get_double"),
pybind11::make_tuple(pybind11::type::of(pybind11::float_())));
m.attr("get") = templateFunc("get", m.attr("_get_Foo"),
pybind11::make_tuple(pybind11::type::of<Foo>()));
getFunc[pybind11::type::of(pybind11::int_())] = m.attr("_get_int");
getFunc[pybind11::type::of(pybind11::float_())] = m.attr("_get_double");
getFunc[pybind11::type::of<Foo>()] = m.attr("_get_Foo");

m.def("_get_from_arg_int", &get_from_arg<int>);
m.def("_get_from_arg_double", &get_from_arg<double>);
m.def("_get_from_arg_Foo", &get_from_arg<Foo>);

m.attr("get_from_arg") =
templateFunc("get_from_arg", m.attr("_get_from_arg_int"),
pybind11::make_tuple(pybind11::type::of(pybind11::int_())));
m.attr("get_from_arg") = templateFunc(
"get_from_arg", m.attr("_get_from_arg_double"),
pybind11::make_tuple(pybind11::type::of(pybind11::float_())));
m.attr("get_from_arg") =
templateFunc("get_from_arg", m.attr("_get_from_arg_Foo"),
pybind11::make_tuple(pybind11::type::of<Foo>()));
auto getFromArgFunc = pybind11::module_::import("funktools")
.attr("template")
.attr("register_name")("get_from_arg");

m.attr("get_from_arg") = getFromArgFunc;

getFromArgFunc[pybind11::type::of(pybind11::int_())] = m.attr("_get_from_arg_int");
getFromArgFunc[pybind11::type::of(pybind11::float_())] = m.attr("_get_from_arg_double");
getFromArgFunc[pybind11::type::of<Foo>()] = m.attr("_get_from_arg_Foo");
}
4 changes: 4 additions & 0 deletions funktools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __getattr__(attr: str) -> typing.Callable:
elif attr == "template":
from ._template import _template
return _template()
elif attr == "template":
from ._template import TemplateException
return TemplateException
else:
raise AttributeError(f"Module 'funktools' has no attribute '{attr}'")

Expand All @@ -33,6 +36,7 @@ def __getattr__(attr: str) -> typing.Callable:
'Register',
'Throttle',
'template',
'TemplateException',
]

def __dir__():
Expand Down
215 changes: 181 additions & 34 deletions funktools/_template.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,198 @@
import inspect
from typing import Any, Callable
import typing

_template_funcs: dict[str, Callable] = {}

class TemplateException(Exception):
pass

def _infer_types(func):
annotations = inspect.get_annotations(func)
args = inspect.getfullargspec(func).args
return tuple([annotations[arg] for arg in args])

_template_funcs: dict[str, typing.Callable] = {}

class _TemplateFunctionBase:
def __init__(self, func, types):
if not isinstance(types, tuple):
types = tuple([types])
self._funcs = {types: func}

def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
return self._funcs[types](*args, **kwargs)
class _FuncArgInfo:
def __init__(
self,
func: typing.Callable,
annotations: dict = None,
fullargspec: inspect.FullArgSpec = None,
):
self._func = func
self._annotations = annotations

def __getitem__(self, types):
if not isinstance(types, tuple):
types = tuple([types])
return self._funcs[tuple(types)]
try:
self._fullargspec = inspect.getfullargspec(self.get_func())
except TypeError as ex:
self._fullargspec = None

self._arg2type = None

def get_func(self):
return self._func

def annotations(self):
if self._annotations is None:
self._annotations = inspect.get_annotations(self.get_func())

return self._annotations

def fullargspec(self):
return self._fullargspec

def is_match(self, *args, **kwargs):
argspec = self.fullargspec()
if argspec is None:
return False

matched = set()

matched_via_default = 0
if argspec.kwonlydefaults:
matched.update(set(argspec.kwonlydefaults.keys()))
matched_via_default = len(matched.difference(set(kwargs.keys())))

for name, val in zip(argspec.args, args):
if want_type := self.annotations().get(name):
if not isinstance(val, want_type):
return False

matched.add(name)

def __setitem__(self, types, func):
if not isinstance(types, tuple):
types = (types,)
self._funcs[types] = func
legal_names = set(argspec.args + argspec.kwonlyargs)
for name, val in kwargs.items():
if name not in legal_names:
return False

def add(self, func):
self[_infer_types(func)] = func
if want_type := self.annotations().get(name):
if not isinstance(val, want_type):
return False

matched.add(name)

if (
len(matched)
== len(argspec.args) + len(argspec.kwonlyargs)
== len(args) + len(kwargs) + matched_via_default
):
return True

return False


def _make_typed_template_function(name: str):
"""Create a _TemplateFunctionBase object with type `name`"""

def __init__(self):
self._types2funcs = {}
self._func_arg_infos = []

def __call__(self, *args, **kwargs):
for func_arg_info in self._func_arg_infos:
if func_arg_info.is_match(*args, **kwargs):
return func_arg_info.get_func()(*args, **kwargs)

if len(args) == 1:
types = type(args[0])
else:
types = tuple([type(arg) for arg in args])

if func := self._types2funcs.get(types):
return func(*args, **kwargs)

raise TemplateException("Cannot find templated function matching signature")

__call__.__annotations__ = {}

def __getitem__(self, types):
try:
return self._types2funcs[types]
except KeyError:
raise TemplateException(
f"Cannot find templated function with types: {types}"
)

__getitem__.__annotations__ = {}

def __setitem__(self, types: typing.Type | tuple, func: typing.Callable):
self._types2funcs[types] = func

annotations = self.__class__.__getitem__.__annotations__

if isinstance(types, tuple):
types = typing.Tuple[types]

annotations["types"] = typing.Union[types, annotations.get("types", types)]

for param, types in inspect.get_annotations(func).items():
annotations[param] = typing.Union[types, annotations.get(param, types)]

annotations["return"] = typing.Callable

self._append_func_arg_info(func)

def add(self, func: typing.Callable):
annotations = inspect.get_annotations(func)
args_annotated = len(annotations)
if "return" in annotations:
args_annotated = args_annotated - 1

func_arg_info = self._append_func_arg_info(func, annotations)

types = None
argspec = func_arg_info.fullargspec()
if argspec and args_annotated == len(argspec.args):
types = tuple([annotations[arg] for arg in argspec.args])

if len(types) == 1:
types = types[0]

self[types] = func

def _append_func_arg_info(
self,
func: typing.Callable,
annotations: dict = None,
fullargspec: inspect.FullArgSpec = None,
):
arg_info = _FuncArgInfo(func, annotations, fullargspec)

call_annotations = self.__class__.__call__.__annotations__
if argspec := arg_info.fullargspec():
all_args = arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs
for arg in all_args:
arg_type = arg_info.annotations().get(arg, typing.Any)
call_annotations[arg] = arg_type | call_annotations.get(arg, arg_type)

self._func_arg_infos.append(arg_info)
return arg_info

TemplateFuncMeta = type(name, (type,), {})
return TemplateFuncMeta(name, (_TemplateFunctionBase,), {'__name__': name})
return TemplateFuncMeta(
name,
(),
{
"__init__": __init__,
"__call__": __call__,
"__getitem__": __getitem__,
"__setitem__": __setitem__,
"__name__": name,
"__annotations__": {},
"add": add,
"_append_func_arg_info": _append_func_arg_info,
},
)()


class _template:
def __call__(self, func):
name = func.__name__

if template_func := _template_funcs.get(name):
template_func[_infer_types(func)] = func
template_func.add(func)
return template_func

typed_template_func = _make_typed_template_function(func.__name__)
template_func = typed_template_func(func, _infer_types(func))
_template_funcs[name] = template_func
return template_func
typed_template_func = self.register_name(func.__name__)
typed_template_func.add(func)
return typed_template_func

def __getitem__(self, types):
def _trampoline(func):
Expand All @@ -59,10 +201,15 @@ def _trampoline(func):
template_func[types] = func
return template_func

typed_template_func = _make_typed_template_function(func.__name__)
template_func = typed_template_func(func, types)
_template_funcs[name] = template_func
return template_func
typed_template_func = self.register_name(func.__name__)
typed_template_func[types] = func
return typed_template_func

return _trampoline

def register_name(self, name: str):
typed_template_func = _make_typed_template_function(name)
_template_funcs[name] = typed_template_func

return typed_template_func

Loading

0 comments on commit e4cca03

Please sign in to comment.