Skip to content

Commit

Permalink
Use scope lookup to append functions
Browse files Browse the repository at this point in the history
  • Loading branch information
LoganEvans committed Mar 24, 2024
1 parent e4cca03 commit 81166f6
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 129 deletions.
4 changes: 2 additions & 2 deletions examples/templates/src/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ PYBIND11_MODULE(example, m) {

auto getFunc = pybind11::module_::import("funktools")
.attr("template")
.attr("register_name")("get");
.attr("make")("get");

m.attr("get") = getFunc;

Expand All @@ -70,7 +70,7 @@ PYBIND11_MODULE(example, m) {

auto getFromArgFunc = pybind11::module_::import("funktools")
.attr("template")
.attr("register_name")("get_from_arg");
.attr("make")("get_from_arg");

m.attr("get_from_arg") = getFromArgFunc;

Expand Down
291 changes: 164 additions & 127 deletions funktools/_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@ class TemplateException(Exception):
pass


_template_funcs: dict[str, typing.Callable] = {}
class TemplateFunction:
pass


class _FuncArgInfo:
"""Information about the arguments for a function."""

def __init__(
self,
func: typing.Callable,
annotations: dict = None,
fullargspec: inspect.FullArgSpec = None,
):
self._func = func
self._annotations = annotations

try:
self._fullargspec = inspect.getfullargspec(self.get_func())
Expand All @@ -26,19 +27,20 @@ def __init__(

self._arg2type = None

def get_func(self):
def get_func(self) -> typing.Callable:
return self._func

def annotations(self):
if self._annotations is None:
self._annotations = inspect.get_annotations(self.get_func())
def annotations(self) -> dict:
return self.fullargspec().annotations

return self._annotations

def fullargspec(self):
def fullargspec(self) -> None | inspect.FullArgSpec:
return self._fullargspec

def is_match(self, *args, **kwargs):
def is_match(self, *args, **kwargs) -> bool:
"""Check if this function can handle the provided arguments.
This checks argument names as well as types provided in annotations.
"""
argspec = self.fullargspec()
if argspec is None:
return False
Expand All @@ -50,8 +52,14 @@ def is_match(self, *args, **kwargs):
matched.update(set(argspec.kwonlydefaults.keys()))
matched_via_default = len(matched.difference(set(kwargs.keys())))

if (
len(argspec.args) + len(argspec.kwonlyargs)
!= len(args) + len(kwargs) + matched_via_default
):
return False

for name, val in zip(argspec.args, args):
if want_type := self.annotations().get(name):
if want_type := argspec.annotations.get(name):
if not isinstance(val, want_type):
return False

Expand All @@ -62,154 +70,183 @@ def is_match(self, *args, **kwargs):
if name not in legal_names:
return False

if want_type := self.annotations().get(name):
if want_type := argspec.annotations.get(name):
if not isinstance(val, want_type):
return False

matched.add(name)

if (
len(matched)
== len(argspec.args) + len(argspec.kwonlyargs)
== len(args) + len(kwargs) + matched_via_default
):
return True

return False


def _make_typed_template_function(name: str):
"""Create a _TemplateFunctionBase object with type `name`"""

def __init__(self):
self._types2funcs = {}
self._func_arg_infos = []
if len(matched) != len(argspec.args) + len(argspec.kwonlyargs):
return False

def __call__(self, *args, **kwargs):
for func_arg_info in self._func_arg_infos:
if func_arg_info.is_match(*args, **kwargs):
return func_arg_info.get_func()(*args, **kwargs)
return True

if len(args) == 1:
types = type(args[0])
else:
types = tuple([type(arg) for arg in args])

if func := self._types2funcs.get(types):
return func(*args, **kwargs)
class _template:
def __call__(self, func):
"""Add a func to the TemplateFunction.
If all arguments of func are annotated, this will allow the function to
be retrieved from the TemplateFunction using square brackets and the
specified types.
The function can also be found by matching argument types and names at a
callsite.
>>> template = _template() # Use `from funktools import template`
>>> @template
... def foo(a: int, b: str):
... return "foo(a: int, b: str)"
>>>
>>> @template
... def foo(a: float, b: int):
... return "foo(a: float, b: int)"
...
>>> @template
... def foo(a, *, c):
... return "foo(a, *, c)"
...
>>> foo(1, "b")
"foo(a: int, b: str)"
>>> foo(2.2, 1)
"foo(a: float, b: int)"
>>> foo(None, c="c")
"foo(a, *, c)"
"""
name = func.__name__

raise TemplateException("Cannot find templated function matching signature")
# Check the surrounding scope for an object with this name -- if
# it exists and it's a TemplateFunction, we avoid making a new one.
if (
template_func := inspect.stack()[1].frame.f_locals.get(name)
) and isinstance(template_func, TemplateFunction):
template_func.add(func)
return template_func

__call__.__annotations__ = {}
typed_template_func = self.make(func.__name__)
typed_template_func.add(func)
return typed_template_func

def __getitem__(self, types):
try:
return self._types2funcs[types]
except KeyError:
raise TemplateException(
f"Cannot find templated function with types: {types}"
)
"""Adds a function using types as a lookup key."""
def _trampoline(func):
name = func.__name__

__getitem__.__annotations__ = {}
if (
template_func := inspect.stack()[1].frame.f_locals.get(name)
) and isinstance(template_func, TemplateFunction):
template_func[types] = func
return template_func

def __setitem__(self, types: typing.Type | tuple, func: typing.Callable):
self._types2funcs[types] = func
typed_template_func = self.make(func.__name__)
typed_template_func[types] = func
return typed_template_func

annotations = self.__class__.__getitem__.__annotations__
return _trampoline

if isinstance(types, tuple):
types = typing.Tuple[types]
@staticmethod
def make(name: str):
"""Create a TemplateFunction object with type `name`."""

annotations["types"] = typing.Union[types, annotations.get("types", types)]
def __init__(self):
self._types2funcs = {}
self._func_arg_infos = []

for param, types in inspect.get_annotations(func).items():
annotations[param] = typing.Union[types, annotations.get(param, types)]
def __call__(self, *args, **kwargs):
for func_arg_info in self._func_arg_infos:
if func_arg_info.is_match(*args, **kwargs):
return func_arg_info.get_func()(*args, **kwargs)

annotations["return"] = typing.Callable
if len(args) == 1:
types = type(args[0])
else:
types = tuple([type(arg) for arg in args])

self._append_func_arg_info(func)
if func := self._types2funcs.get(types):
return func(*args, **kwargs)

def add(self, func: typing.Callable):
annotations = inspect.get_annotations(func)
args_annotated = len(annotations)
if "return" in annotations:
args_annotated = args_annotated - 1
raise TemplateException("Cannot find templated function matching signature")

func_arg_info = self._append_func_arg_info(func, annotations)
__call__.__annotations__ = {}

types = None
argspec = func_arg_info.fullargspec()
if argspec and args_annotated == len(argspec.args):
types = tuple([annotations[arg] for arg in argspec.args])
def __getitem__(self, types):
try:
return self._types2funcs[types]
except KeyError:
raise TemplateException(
f"Cannot find templated function with types: {types}"
)

if len(types) == 1:
types = types[0]
__getitem__.__annotations__ = {}

self[types] = func
def __setitem__(self, types: typing.Type | tuple, func: typing.Callable):
self._types2funcs[types] = func

def _append_func_arg_info(
self,
func: typing.Callable,
annotations: dict = None,
fullargspec: inspect.FullArgSpec = None,
):
arg_info = _FuncArgInfo(func, annotations, fullargspec)

call_annotations = self.__class__.__call__.__annotations__
if argspec := arg_info.fullargspec():
all_args = arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs
for arg in all_args:
arg_type = arg_info.annotations().get(arg, typing.Any)
call_annotations[arg] = arg_type | call_annotations.get(arg, arg_type)

self._func_arg_infos.append(arg_info)
return arg_info

TemplateFuncMeta = type(name, (type,), {})
return TemplateFuncMeta(
name,
(),
{
"__init__": __init__,
"__call__": __call__,
"__getitem__": __getitem__,
"__setitem__": __setitem__,
"__name__": name,
"__annotations__": {},
"add": add,
"_append_func_arg_info": _append_func_arg_info,
},
)()
annotations = self.__class__.__getitem__.__annotations__

if isinstance(types, tuple):
types = typing.Tuple[types]

class _template:
def __call__(self, func):
name = func.__name__
annotations["types"] = typing.Union[types, annotations.get("types", types)]

if template_func := _template_funcs.get(name):
template_func.add(func)
return template_func
for param, types in inspect.get_annotations(func).items():
annotations[param] = typing.Union[types, annotations.get(param, types)]

typed_template_func = self.register_name(func.__name__)
typed_template_func.add(func)
return typed_template_func
annotations["return"] = typing.Callable

def __getitem__(self, types):
def _trampoline(func):
name = func.__name__
if template_func := _template_funcs.get(name):
template_func[types] = func
return template_func
self._append_func_arg_info(func)

typed_template_func = self.register_name(func.__name__)
typed_template_func[types] = func
return typed_template_func
def add(self, func: typing.Callable):
func_arg_info = self._append_func_arg_info(func)
argspec = func_arg_info.fullargspec()

return _trampoline
annotations = argspec.annotations
args_annotated = len(annotations)
if "return" in annotations:
args_annotated = args_annotated - 1

def register_name(self, name: str):
typed_template_func = _make_typed_template_function(name)
_template_funcs[name] = typed_template_func
types = None
if argspec and args_annotated == len(argspec.args):
types = tuple([annotations[arg] for arg in argspec.args])

return typed_template_func
if len(types) == 1:
types = types[0]

self[types] = func

def _append_func_arg_info(
self,
func: typing.Callable,
fullargspec: inspect.FullArgSpec = None,
):
arg_info = _FuncArgInfo(func, fullargspec)

call_annotations = self.__class__.__call__.__annotations__
if argspec := arg_info.fullargspec():
all_args = (
arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs
)
for arg in all_args:
arg_type = arg_info.annotations().get(arg, typing.Any)
call_annotations[arg] = arg_type | call_annotations.get(
arg, arg_type
)

self._func_arg_infos.append(arg_info)
return arg_info

TemplateFuncMeta = type(name, (type,), {})
return TemplateFuncMeta(
name,
(TemplateFunction,),
{
"__init__": __init__,
"__call__": __call__,
"__getitem__": __getitem__,
"__setitem__": __setitem__,
"__name__": name,
"__annotations__": {},
"add": add,
"_append_func_arg_info": _append_func_arg_info,
},
)()

0 comments on commit 81166f6

Please sign in to comment.