diff --git a/examples/templates/src/example.cpp b/examples/templates/src/example.cpp index 877a743..bd2199f 100644 --- a/examples/templates/src/example.cpp +++ b/examples/templates/src/example.cpp @@ -54,7 +54,7 @@ PYBIND11_MODULE(example, m) { auto getFunc = pybind11::module_::import("funktools") .attr("template") - .attr("register_name")("get"); + .attr("make")("get"); m.attr("get") = getFunc; @@ -70,7 +70,7 @@ PYBIND11_MODULE(example, m) { auto getFromArgFunc = pybind11::module_::import("funktools") .attr("template") - .attr("register_name")("get_from_arg"); + .attr("make")("get_from_arg"); m.attr("get_from_arg") = getFromArgFunc; diff --git a/funktools/_template.py b/funktools/_template.py index a89f6b5..87e513a 100644 --- a/funktools/_template.py +++ b/funktools/_template.py @@ -6,18 +6,19 @@ class TemplateException(Exception): pass -_template_funcs: dict[str, typing.Callable] = {} +class TemplateFunction: + pass class _FuncArgInfo: + """Information about the arguments for a function.""" + def __init__( self, func: typing.Callable, - annotations: dict = None, fullargspec: inspect.FullArgSpec = None, ): self._func = func - self._annotations = annotations try: self._fullargspec = inspect.getfullargspec(self.get_func()) @@ -26,19 +27,20 @@ def __init__( self._arg2type = None - def get_func(self): + def get_func(self) -> typing.Callable: return self._func - def annotations(self): - if self._annotations is None: - self._annotations = inspect.get_annotations(self.get_func()) + def annotations(self) -> dict: + return self.fullargspec().annotations - return self._annotations - - def fullargspec(self): + def fullargspec(self) -> None | inspect.FullArgSpec: return self._fullargspec - def is_match(self, *args, **kwargs): + def is_match(self, *args, **kwargs) -> bool: + """Check if this function can handle the provided arguments. + + This checks argument names as well as types provided in annotations. + """ argspec = self.fullargspec() if argspec is None: return False @@ -50,8 +52,14 @@ def is_match(self, *args, **kwargs): matched.update(set(argspec.kwonlydefaults.keys())) matched_via_default = len(matched.difference(set(kwargs.keys()))) + if ( + len(argspec.args) + len(argspec.kwonlyargs) + != len(args) + len(kwargs) + matched_via_default + ): + return False + for name, val in zip(argspec.args, args): - if want_type := self.annotations().get(name): + if want_type := argspec.annotations.get(name): if not isinstance(val, want_type): return False @@ -62,154 +70,183 @@ def is_match(self, *args, **kwargs): if name not in legal_names: return False - if want_type := self.annotations().get(name): + if want_type := argspec.annotations.get(name): if not isinstance(val, want_type): return False matched.add(name) - if ( - len(matched) - == len(argspec.args) + len(argspec.kwonlyargs) - == len(args) + len(kwargs) + matched_via_default - ): - return True - - return False - - -def _make_typed_template_function(name: str): - """Create a _TemplateFunctionBase object with type `name`""" - - def __init__(self): - self._types2funcs = {} - self._func_arg_infos = [] + if len(matched) != len(argspec.args) + len(argspec.kwonlyargs): + return False - def __call__(self, *args, **kwargs): - for func_arg_info in self._func_arg_infos: - if func_arg_info.is_match(*args, **kwargs): - return func_arg_info.get_func()(*args, **kwargs) + return True - if len(args) == 1: - types = type(args[0]) - else: - types = tuple([type(arg) for arg in args]) - if func := self._types2funcs.get(types): - return func(*args, **kwargs) +class _template: + def __call__(self, func): + """Add a func to the TemplateFunction. + + If all arguments of func are annotated, this will allow the function to + be retrieved from the TemplateFunction using square brackets and the + specified types. + + The function can also be found by matching argument types and names at a + callsite. + + >>> template = _template() # Use `from funktools import template` + >>> @template + ... def foo(a: int, b: str): + ... return "foo(a: int, b: str)" + >>> + >>> @template + ... def foo(a: float, b: int): + ... return "foo(a: float, b: int)" + ... + >>> @template + ... def foo(a, *, c): + ... return "foo(a, *, c)" + ... + >>> foo(1, "b") + "foo(a: int, b: str)" + >>> foo(2.2, 1) + "foo(a: float, b: int)" + >>> foo(None, c="c") + "foo(a, *, c)" + """ + name = func.__name__ - raise TemplateException("Cannot find templated function matching signature") + # Check the surrounding scope for an object with this name -- if + # it exists and it's a TemplateFunction, we avoid making a new one. + if ( + template_func := inspect.stack()[1].frame.f_locals.get(name) + ) and isinstance(template_func, TemplateFunction): + template_func.add(func) + return template_func - __call__.__annotations__ = {} + typed_template_func = self.make(func.__name__) + typed_template_func.add(func) + return typed_template_func def __getitem__(self, types): - try: - return self._types2funcs[types] - except KeyError: - raise TemplateException( - f"Cannot find templated function with types: {types}" - ) + """Adds a function using types as a lookup key.""" + def _trampoline(func): + name = func.__name__ - __getitem__.__annotations__ = {} + if ( + template_func := inspect.stack()[1].frame.f_locals.get(name) + ) and isinstance(template_func, TemplateFunction): + template_func[types] = func + return template_func - def __setitem__(self, types: typing.Type | tuple, func: typing.Callable): - self._types2funcs[types] = func + typed_template_func = self.make(func.__name__) + typed_template_func[types] = func + return typed_template_func - annotations = self.__class__.__getitem__.__annotations__ + return _trampoline - if isinstance(types, tuple): - types = typing.Tuple[types] + @staticmethod + def make(name: str): + """Create a TemplateFunction object with type `name`.""" - annotations["types"] = typing.Union[types, annotations.get("types", types)] + def __init__(self): + self._types2funcs = {} + self._func_arg_infos = [] - for param, types in inspect.get_annotations(func).items(): - annotations[param] = typing.Union[types, annotations.get(param, types)] + def __call__(self, *args, **kwargs): + for func_arg_info in self._func_arg_infos: + if func_arg_info.is_match(*args, **kwargs): + return func_arg_info.get_func()(*args, **kwargs) - annotations["return"] = typing.Callable + if len(args) == 1: + types = type(args[0]) + else: + types = tuple([type(arg) for arg in args]) - self._append_func_arg_info(func) + if func := self._types2funcs.get(types): + return func(*args, **kwargs) - def add(self, func: typing.Callable): - annotations = inspect.get_annotations(func) - args_annotated = len(annotations) - if "return" in annotations: - args_annotated = args_annotated - 1 + raise TemplateException("Cannot find templated function matching signature") - func_arg_info = self._append_func_arg_info(func, annotations) + __call__.__annotations__ = {} - types = None - argspec = func_arg_info.fullargspec() - if argspec and args_annotated == len(argspec.args): - types = tuple([annotations[arg] for arg in argspec.args]) + def __getitem__(self, types): + try: + return self._types2funcs[types] + except KeyError: + raise TemplateException( + f"Cannot find templated function with types: {types}" + ) - if len(types) == 1: - types = types[0] + __getitem__.__annotations__ = {} - self[types] = func + def __setitem__(self, types: typing.Type | tuple, func: typing.Callable): + self._types2funcs[types] = func - def _append_func_arg_info( - self, - func: typing.Callable, - annotations: dict = None, - fullargspec: inspect.FullArgSpec = None, - ): - arg_info = _FuncArgInfo(func, annotations, fullargspec) - - call_annotations = self.__class__.__call__.__annotations__ - if argspec := arg_info.fullargspec(): - all_args = arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs - for arg in all_args: - arg_type = arg_info.annotations().get(arg, typing.Any) - call_annotations[arg] = arg_type | call_annotations.get(arg, arg_type) - - self._func_arg_infos.append(arg_info) - return arg_info - - TemplateFuncMeta = type(name, (type,), {}) - return TemplateFuncMeta( - name, - (), - { - "__init__": __init__, - "__call__": __call__, - "__getitem__": __getitem__, - "__setitem__": __setitem__, - "__name__": name, - "__annotations__": {}, - "add": add, - "_append_func_arg_info": _append_func_arg_info, - }, - )() + annotations = self.__class__.__getitem__.__annotations__ + if isinstance(types, tuple): + types = typing.Tuple[types] -class _template: - def __call__(self, func): - name = func.__name__ + annotations["types"] = typing.Union[types, annotations.get("types", types)] - if template_func := _template_funcs.get(name): - template_func.add(func) - return template_func + for param, types in inspect.get_annotations(func).items(): + annotations[param] = typing.Union[types, annotations.get(param, types)] - typed_template_func = self.register_name(func.__name__) - typed_template_func.add(func) - return typed_template_func + annotations["return"] = typing.Callable - def __getitem__(self, types): - def _trampoline(func): - name = func.__name__ - if template_func := _template_funcs.get(name): - template_func[types] = func - return template_func + self._append_func_arg_info(func) - typed_template_func = self.register_name(func.__name__) - typed_template_func[types] = func - return typed_template_func + def add(self, func: typing.Callable): + func_arg_info = self._append_func_arg_info(func) + argspec = func_arg_info.fullargspec() - return _trampoline + annotations = argspec.annotations + args_annotated = len(annotations) + if "return" in annotations: + args_annotated = args_annotated - 1 - def register_name(self, name: str): - typed_template_func = _make_typed_template_function(name) - _template_funcs[name] = typed_template_func + types = None + if argspec and args_annotated == len(argspec.args): + types = tuple([annotations[arg] for arg in argspec.args]) - return typed_template_func + if len(types) == 1: + types = types[0] + + self[types] = func + def _append_func_arg_info( + self, + func: typing.Callable, + fullargspec: inspect.FullArgSpec = None, + ): + arg_info = _FuncArgInfo(func, fullargspec) + + call_annotations = self.__class__.__call__.__annotations__ + if argspec := arg_info.fullargspec(): + all_args = ( + arg_info.fullargspec().args + arg_info.fullargspec().kwonlyargs + ) + for arg in all_args: + arg_type = arg_info.annotations().get(arg, typing.Any) + call_annotations[arg] = arg_type | call_annotations.get( + arg, arg_type + ) + + self._func_arg_infos.append(arg_info) + return arg_info + + TemplateFuncMeta = type(name, (type,), {}) + return TemplateFuncMeta( + name, + (TemplateFunction,), + { + "__init__": __init__, + "__call__": __call__, + "__getitem__": __getitem__, + "__setitem__": __setitem__, + "__name__": name, + "__annotations__": {}, + "add": add, + "_append_func_arg_info": _append_func_arg_info, + }, + )()