Skip to content

Commit 078abb6

Browse files
authoredDec 25, 2021
bpo-46032: Check types in singledispatch's register() at declaration time (GH-30050)
The registry() method of functools.singledispatch() functions checks now the first argument or the first parameter annotation and raises a TypeError if it is not supported. Previously unsupported "types" were ignored (e.g. typing.List[int]) or caused an error at calling time (e.g. list[int]).
1 parent 1b30660 commit 078abb6

File tree

3 files changed

+89
-5
lines changed

3 files changed

+89
-5
lines changed
 

‎Lib/functools.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,7 @@ def _compose_mro(cls, types):
740740
# Remove entries which are already present in the __mro__ or unrelated.
741741
def is_related(typ):
742742
return (typ not in bases and hasattr(typ, '__mro__')
743+
and not isinstance(typ, GenericAlias)
743744
and issubclass(cls, typ))
744745
types = [n for n in types if is_related(n)]
745746
# Remove entries which are strict bases of other entries (they will end up
@@ -841,9 +842,13 @@ def _is_union_type(cls):
841842
from typing import get_origin, Union
842843
return get_origin(cls) in {Union, types.UnionType}
843844

844-
def _is_valid_union_type(cls):
845+
def _is_valid_dispatch_type(cls):
846+
if isinstance(cls, type) and not isinstance(cls, GenericAlias):
847+
return True
845848
from typing import get_args
846-
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
849+
return (_is_union_type(cls) and
850+
all(isinstance(arg, type) and not isinstance(arg, GenericAlias)
851+
for arg in get_args(cls)))
847852

848853
def register(cls, func=None):
849854
"""generic_func.register(cls, func) -> func
@@ -852,9 +857,15 @@ def register(cls, func=None):
852857
853858
"""
854859
nonlocal cache_token
855-
if func is None:
856-
if isinstance(cls, type) or _is_valid_union_type(cls):
860+
if _is_valid_dispatch_type(cls):
861+
if func is None:
857862
return lambda f: register(cls, f)
863+
else:
864+
if func is not None:
865+
raise TypeError(
866+
f"Invalid first argument to `register()`. "
867+
f"{cls!r} is not a class or union type."
868+
)
858869
ann = getattr(cls, '__annotations__', {})
859870
if not ann:
860871
raise TypeError(
@@ -867,7 +878,7 @@ def register(cls, func=None):
867878
# only import typing if annotation parsing is necessary
868879
from typing import get_type_hints
869880
argname, cls = next(iter(get_type_hints(func).items()))
870-
if not isinstance(cls, type) and not _is_valid_union_type(cls):
881+
if not _is_valid_dispatch_type(cls):
871882
if _is_union_type(cls):
872883
raise TypeError(
873884
f"Invalid annotation for {argname!r}. "

‎Lib/test/test_functools.py

+68
Original file line numberDiff line numberDiff line change
@@ -2722,6 +2722,74 @@ def _(arg: int | float):
27222722
self.assertEqual(f(1), "types.UnionType")
27232723
self.assertEqual(f(1.0), "types.UnionType")
27242724

2725+
def test_register_genericalias(self):
2726+
@functools.singledispatch
2727+
def f(arg):
2728+
return "default"
2729+
2730+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2731+
f.register(list[int], lambda arg: "types.GenericAlias")
2732+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2733+
f.register(typing.List[int], lambda arg: "typing.GenericAlias")
2734+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2735+
f.register(list[int] | str, lambda arg: "types.UnionTypes(types.GenericAlias)")
2736+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2737+
f.register(typing.List[float] | bytes, lambda arg: "typing.Union[typing.GenericAlias]")
2738+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2739+
f.register(typing.Any, lambda arg: "typing.Any")
2740+
2741+
self.assertEqual(f([1]), "default")
2742+
self.assertEqual(f([1.0]), "default")
2743+
self.assertEqual(f(""), "default")
2744+
self.assertEqual(f(b""), "default")
2745+
2746+
def test_register_genericalias_decorator(self):
2747+
@functools.singledispatch
2748+
def f(arg):
2749+
return "default"
2750+
2751+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2752+
f.register(list[int])
2753+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2754+
f.register(typing.List[int])
2755+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2756+
f.register(list[int] | str)
2757+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2758+
f.register(typing.List[int] | str)
2759+
with self.assertRaisesRegex(TypeError, "Invalid first argument to "):
2760+
f.register(typing.Any)
2761+
2762+
def test_register_genericalias_annotation(self):
2763+
@functools.singledispatch
2764+
def f(arg):
2765+
return "default"
2766+
2767+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2768+
@f.register
2769+
def _(arg: list[int]):
2770+
return "types.GenericAlias"
2771+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2772+
@f.register
2773+
def _(arg: typing.List[float]):
2774+
return "typing.GenericAlias"
2775+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2776+
@f.register
2777+
def _(arg: list[int] | str):
2778+
return "types.UnionType(types.GenericAlias)"
2779+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2780+
@f.register
2781+
def _(arg: typing.List[float] | bytes):
2782+
return "typing.Union[typing.GenericAlias]"
2783+
with self.assertRaisesRegex(TypeError, "Invalid annotation for 'arg'"):
2784+
@f.register
2785+
def _(arg: typing.Any):
2786+
return "typing.Any"
2787+
2788+
self.assertEqual(f([1]), "default")
2789+
self.assertEqual(f([1.0]), "default")
2790+
self.assertEqual(f(""), "default")
2791+
self.assertEqual(f(b""), "default")
2792+
27252793

27262794
class CachedCostItem:
27272795
_cost = 1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
The ``registry()`` method of :func:`functools.singledispatch` functions
2+
checks now the first argument or the first parameter annotation and raises a
3+
TypeError if it is not supported. Previously unsupported "types" were
4+
ignored (e.g. ``typing.List[int]``) or caused an error at calling time (e.g.
5+
``list[int]``).

0 commit comments

Comments
 (0)
Please sign in to comment.