Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-89263: Add typing.get_overloads #31716

Merged
merged 37 commits into from
Apr 16, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
2ee377d
initial
JelleZijlstra Mar 6, 2022
831b565
Implementation, tests, and docs
JelleZijlstra Mar 7, 2022
f03f8a9
fix versionadded
JelleZijlstra Mar 7, 2022
404668a
Merge branch 'main' into funcregistry
JelleZijlstra Mar 8, 2022
7a5b0d1
make get_key_for_callable private
JelleZijlstra Mar 8, 2022
6998255
doc updates; remove unnecessary try-except
JelleZijlstra Mar 9, 2022
26bb908
Merge remote-tracking branch 'upstream/main' into funcregistry
JelleZijlstra Mar 27, 2022
f52b757
rename method
JelleZijlstra Mar 27, 2022
fc6a925
Don't store singledispatch in the registry
JelleZijlstra Mar 27, 2022
b524244
more tests
JelleZijlstra Mar 27, 2022
e95558e
and another
JelleZijlstra Mar 27, 2022
31fd72d
fix line length in new tests
JelleZijlstra Mar 27, 2022
7041ad3
Update Doc/library/functools.rst
JelleZijlstra Mar 27, 2022
e26b0db
Update Doc/library/typing.rst
JelleZijlstra Mar 27, 2022
1bf89fb
only for overload
JelleZijlstra Apr 2, 2022
83ac432
Merge remote-tracking branch 'upstream/main' into funcregistry
JelleZijlstra Apr 2, 2022
dfdbdc7
fix tests
JelleZijlstra Apr 2, 2022
e16c8d0
undo stray changes, fix NEWS entry
JelleZijlstra Apr 2, 2022
b3d2227
remove extra import
JelleZijlstra Apr 2, 2022
9727eee
Apply suggestions from code review
JelleZijlstra Apr 2, 2022
2e374b8
Apply suggestions from code review
JelleZijlstra Apr 3, 2022
ff03b12
Guido's feedback
JelleZijlstra Apr 3, 2022
17f0710
Optimizations suggested by Guido and Alex
JelleZijlstra Apr 3, 2022
2346970
inline _get_firstlineno, store outer objects for classmethod/staticme…
JelleZijlstra Apr 3, 2022
f2053a0
use defaultdict
JelleZijlstra Apr 3, 2022
b6131ad
another optimization
JelleZijlstra Apr 4, 2022
506bd66
Update Lib/typing.py
JelleZijlstra Apr 7, 2022
e9a2100
Merge remote-tracking branch 'upstream/main' into funcregistry
JelleZijlstra Apr 8, 2022
2b1a5cc
Merge remote-tracking branch 'upstream/main' into funcregistry
JelleZijlstra Apr 9, 2022
103bfd4
Simpler implementation (thanks Guido)
JelleZijlstra Apr 9, 2022
d453f7f
More comments and tests
JelleZijlstra Apr 9, 2022
450afeb
Merge remote-tracking branch 'upstream/main' into funcregistry
JelleZijlstra Apr 14, 2022
ea62287
simplify clear_overloads
JelleZijlstra Apr 14, 2022
905253c
use partial
JelleZijlstra Apr 14, 2022
debbf8a
add test
JelleZijlstra Apr 14, 2022
754c134
docs changes (thanks Alex)
JelleZijlstra Apr 14, 2022
1ad8224
Merge remote-tracking branch 'upstream/main' into funcregistry
JelleZijlstra Apr 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions Doc/library/functools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,8 @@ The :mod:`functools` module defines the following functions:
.. versionchanged:: 3.7
The :func:`register` attribute now supports using type annotations.

.. versionchanged:: 3.11
Implementation functions can now be retrieved using :func:`get_variants`.

.. class:: singledispatchmethod(func)

Expand Down Expand Up @@ -587,6 +589,9 @@ The :mod:`functools` module defines the following functions:

.. versionadded:: 3.8

.. versionchanged:: 3.11
Implementation functions are now registered using :func:`register_variant`.
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved


.. function:: update_wrapper(wrapper, wrapped, assigned=WRAPPER_ASSIGNMENTS, updated=WRAPPER_UPDATES)

Expand Down Expand Up @@ -664,6 +669,31 @@ The :mod:`functools` module defines the following functions:
would have been ``'wrapper'``, and the docstring of the original :func:`example`
would have been lost.

.. function:: get_variants(func)

Return all registered function variants for this function. Function variants are
objects that represent some subset of the functionality of a function, for
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
example overloads decorated with :func:`typing.overload` or :func:`singledispatch`
implementation functions.

Variants are registered by calling :func:`register_variant`.

.. versionadded:: 3.11

.. function:: register_variant(func, variant)

Register *variant* for function *func* that can later be retrieved using
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
:func:`get_variants`.

.. versionadded:: 3.11

.. function:: clear_variants(func=None)

Clear all registered variants for the given *func*. If *func* is None, clear
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
all variants.

.. versionadded:: 3.11


.. _partial-objects:

Expand Down
3 changes: 3 additions & 0 deletions Doc/library/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,9 @@ Functions and decorators

See :pep:`484` for details and comparison with other typing semantics.

.. versionchanged:: 3.11
Overloaded functions are now registered using :func:`functools.register_variant`.
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved

.. decorator:: final

A decorator to indicate to type checkers that the decorated method
Expand Down
64 changes: 62 additions & 2 deletions Lib/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

from abc import get_cache_token
from collections import namedtuple
# import types, weakref # Deferred to single_dispatch()
import types
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
# import weakref # Deferred to single_dispatch()
from reprlib import recursive_repr
from _thread import RLock
from types import GenericAlias
Expand Down Expand Up @@ -653,6 +654,63 @@ def cache(user_function, /):
return lru_cache(maxsize=None)(user_function)


################################################################################
### Function variant registry
################################################################################

# {key: [variant]}
_variant_registry = {}
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved


def register_variant(func, variant):
"""Register a function variant."""
key = _get_key_for_callable(func)
_variant_registry.setdefault(key, []).append(variant)


def get_variants(func):
"""Get all function variants for the given function."""
key = _get_key_for_callable(func)
variants = list(_variant_registry.get(key, []))

# We directly retrieve variants from the singledispatch
# and singledispatchmethod registries.
if isinstance(func, singledispatchmethod):
variants += func.dispatcher.registry.values()
else:
try:
registry = func.registry
except AttributeError:
pass
else:
if isinstance(registry, types.MappingProxyType):
variants += registry.values()
return variants


def clear_variants(func=None):
"""Clear all variants for the given function (or all functions)."""
if func is None:
_variant_registry.clear()
else:
key = _get_key_for_callable(func)
_variant_registry.pop(key, None)


def _get_key_for_callable(func):
"""Return a key for the given callable.

This key can be used to register the callable in the variant registry
with register_variant() or to get variants for this callable with get_variants().

If no key can be created (because the object is not of a supported type), raise
AttributeError.
"""
# classmethod and staticmethod
func = getattr(func, "__func__", func)
return f"{func.__module__}.{func.__qualname__}"


################################################################################
### singledispatch() - single-dispatch generic function decorator
################################################################################
Expand Down Expand Up @@ -809,11 +867,12 @@ def singledispatch(func):
# There are many programs that use functools without singledispatch, so we
# trade-off making singledispatch marginally slower for the benefit of
# making start-up of such applications slightly faster.
import types, weakref
import weakref

registry = {}
dispatch_cache = weakref.WeakKeyDictionary()
cache_token = None
outer_func = func

def dispatch(cls):
"""generic_func.dispatch(cls) -> <function implementation>
Expand Down Expand Up @@ -948,6 +1007,7 @@ def _method(*args, **kwargs):

_method.__isabstractmethod__ = self.__isabstractmethod__
_method.register = self.register
_method.registry = self.dispatcher.registry
update_wrapper(_method, self.func)
return _method

Expand Down
167 changes: 167 additions & 0 deletions Lib/test/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,173 @@ def cached_staticmeth(x, y):
return 3 * x + y


class MethodHolder:
@classmethod
def clsmethod(cls): ...
@staticmethod
def stmethod(): ...
def method(self): ...


class TestVariantRegistry(unittest.TestCase):
def test_get_key_for_callable(self):
self.assertEqual(
functools._get_key_for_callable(len),
"builtins.len",
)
self.assertEqual(
functools._get_key_for_callable(py_cached_func),
f"{__name__}.py_cached_func",
)
self.assertEqual(
functools._get_key_for_callable(MethodHolder.clsmethod),
f"{__name__}.MethodHolder.clsmethod",
)
self.assertEqual(
functools._get_key_for_callable(MethodHolder.stmethod),
f"{__name__}.MethodHolder.stmethod",
)
self.assertEqual(
functools._get_key_for_callable(MethodHolder.method),
f"{__name__}.MethodHolder.method",
)

def test_get_variants(self):
def func1():
pass

def func2():
pass

obj1 = object()
obj2 = object()
self.assertEqual(functools.get_variants(func1), [])
self.assertEqual(functools.get_variants(func2), [])

functools.register_variant(func1, obj1)
self.assertEqual(functools.get_variants(func1), [obj1])
self.assertEqual(functools.get_variants(func2), [])

functools.register_variant(func1, obj2)
self.assertEqual(functools.get_variants(func1), [obj1, obj2])
self.assertEqual(functools.get_variants(func2), [])

def test_clear_variants(self):
def func1():
pass

def func2():
pass

obj1 = object()

functools.register_variant(func1, obj1)
self.assertEqual(functools.get_variants(func1), [obj1])
self.assertEqual(functools.get_variants(func2), [])

functools.clear_variants(func2)
self.assertEqual(functools.get_variants(func1), [obj1])
self.assertEqual(functools.get_variants(func2), [])

functools.clear_variants(func1)
self.assertEqual(functools.get_variants(func1), [])
self.assertEqual(functools.get_variants(func2), [])

functools.register_variant(func1, obj1)
functools.register_variant(func2, obj1)
self.assertEqual(functools.get_variants(func1), [obj1])
self.assertEqual(functools.get_variants(func2), [obj1])

functools.clear_variants()
self.assertEqual(functools.get_variants(func1), [])
self.assertEqual(functools.get_variants(func2), [])

def test_singledispatch_interaction(self):
@functools.singledispatch
def func(obj):
return "base"

original_func = func.registry[object]
self.assertEqual(functools.get_variants(func), [original_func])

@func.register(int)
def func_int(obj):
return "int"

self.assertEqual(
functools.get_variants(func), [original_func, func_int]
)

def weird_func():
pass

weird_func.registry = 42
# shouldn't crash if the registry attribute exists but is not
# a mapping proxy
self.assertEqual(functools.get_variants(weird_func), [])

def test_singledispatchmethod_interaction(self):
class A:
@functools.singledispatchmethod
def t(self, arg):
self.arg = "base"

@t.register(int)
def int_t(self, arg):
self.arg = "int"

@t.register(str)
def str_t(self, arg):
self.arg = "str"

expected = [
A.t.registry[object],
A.int_t,
A.str_t,
]
self.assertEqual(functools.get_variants(A.t), expected)
method_object = A.__dict__["t"] # bypass the descriptor
self.assertEqual(functools.get_variants(method_object), expected)

def test_both_singledispatch_and_overload(self):
from typing import overload

def complex_func(arg: str) -> int:
...

str_overload = complex_func
overload(complex_func)

def complex_func(arg: int) -> str:
...

int_overload = complex_func
overload(complex_func)

@functools.singledispatch
def complex_func(arg: object):
raise NotImplementedError

@complex_func.register
def str_variant(arg: str) -> int:
return int(arg)

@complex_func.register
def int_variant(arg: int) -> str:
return str(arg)

self.assertEqual(
functools.get_variants(complex_func),
[
str_overload,
int_overload,
complex_func.registry[object],
str_variant,
int_variant,
],
)


class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
@functools.singledispatch
Expand Down
36 changes: 35 additions & 1 deletion Lib/test/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import collections
from functools import lru_cache
import functools
import inspect
import pickle
import re
Expand All @@ -9,7 +10,7 @@
from unittest import TestCase, main, skipUnless, skip
from copy import copy, deepcopy

from typing import Any, NoReturn, Never, assert_never
from typing import Any, NoReturn, Never, assert_never, overload
from typing import TypeVar, TypeVarTuple, Unpack, AnyStr
from typing import T, KT, VT # Not in __all__.
from typing import Union, Optional, Literal
Expand Down Expand Up @@ -3819,6 +3820,39 @@ def blah():

blah()

def test_variant_registry(self):
# Test the interaction with the variants registry in
# the functools module.
def blah():
pass

overload1 = blah
overload(blah)

def blah():
pass

overload2 = blah
overload(blah)

def blah():
pass

self.assertEqual(functools.get_variants(blah), [overload1, overload2])

def test_variant_registry_repeated(self):
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
for _ in range(2):
def blah():
pass

overload_func = blah
overload(blah)

def blah():
pass

self.assertEqual(functools.get_variants(blah), [overload_func])


# Definitions needed for features introduced in Python 3.6

Expand Down
Loading