Skip to content

Commit 5c73e6a

Browse files
authored
Support callback protocols (#5463)
Fixes #5453 The issue proposed an interesting idea. Allow callables as subtypes of protocols with `__call__`. IMO this is not just reasonable and type safe, but is also more clear that the extended callable syntax in `mypy_extensions`. For example: ```python class Combiner(Protocol): def __call__(self, *vals: bytes, maxlen: Optional[int] = None) -> List[bytes]: ... def batch_proc(data: Iterable[bytes], cb_results: Combiner) -> bytes: ... ``` The callback protocols: * Have cleaner familiar syntax in contrast to `Callable[[VarArg(bytes), DefaultNamedArg(Optional[int], 'maxlen')], List[bytes]]` (compare to above) * Allow to be more explicit/flexible about binding of type variables in generic callbacks (see tests for examples) * Support overloaded callbacks (this is simply impossible with the current extended callable syntax) * Are easy to implement If this will get some traction, I would actually propose to deprecate extended callable syntax in favor of callback protocols.
1 parent 85418fb commit 5c73e6a

File tree

8 files changed

+328
-11
lines changed

8 files changed

+328
-11
lines changed

docs/source/kinds_of_types.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ Any)`` function signature. Example:
186186
arbitrary_call(open) # Error: does not return an int
187187
arbitrary_call(1) # Error: 'int' is not callable
188188
189+
In situations where more precise or complex types of callbacks are
190+
necessary one can use flexible :ref:`callback protocols <callback_protocols>`.
189191
Lambdas are also supported. The lambda argument and return value types
190192
cannot be given explicitly; they are always inferred based on context
191193
using bidirectional type inference:

docs/source/protocols.rst

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,3 +409,54 @@ in ``typing`` such as ``Iterable``.
409409
``isinstance()`` with protocols is not completely safe at runtime.
410410
For example, signatures of methods are not checked. The runtime
411411
implementation only checks that all protocol members are defined.
412+
413+
.. _callback_protocols:
414+
415+
Callback protocols
416+
******************
417+
418+
Protocols can be used to define flexible callback types that are hard
419+
(or even impossible) to express using the ``Callable[...]`` syntax, such as variadic,
420+
overloaded, and complex generic callbacks. They are defined with a special ``__call__``
421+
member:
422+
423+
.. code-block:: python
424+
425+
from typing import Optional, Iterable, List
426+
from typing_extensions import Protocol
427+
428+
class Combiner(Protocol):
429+
def __call__(self, *vals: bytes, maxlen: Optional[int] = None) -> List[bytes]: ...
430+
431+
def batch_proc(data: Iterable[bytes], cb_results: Combiner) -> bytes:
432+
for item in data:
433+
...
434+
435+
def good_cb(*vals: bytes, maxlen: Optional[int] = None) -> List[bytes]:
436+
...
437+
def bad_cb(*vals: bytes, maxitems: Optional[int]) -> List[bytes]:
438+
...
439+
440+
batch_proc([], good_cb) # OK
441+
batch_proc([], bad_cb) # Error! Argument 2 has incompatible type because of
442+
# different name and kind in the callback
443+
444+
Callback protocols and ``Callable[...]`` types can be used interchangeably.
445+
Keyword argument names in ``__call__`` methods must be identical, unless
446+
a double underscore prefix is used. For example:
447+
448+
.. code-block:: python
449+
450+
from typing import Callable, TypeVar
451+
from typing_extensions import Protocol
452+
453+
T = TypeVar('T')
454+
455+
class Copy(Protocol):
456+
def __call__(self, __origin: T) -> T: ...
457+
458+
copy_a: Callable[[T], T]
459+
copy_b: Copy
460+
461+
copy_a = copy_b # OK
462+
copy_b = copy_a # Also OK

mypy/checker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3174,6 +3174,11 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context,
31743174
call = find_member('__call__', subtype, subtype)
31753175
if call:
31763176
self.msg.note_call(subtype, call, context)
3177+
if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance):
3178+
if supertype.type.is_protocol and supertype.type.protocol_members == ['__call__']:
3179+
call = find_member('__call__', supertype, subtype)
3180+
assert call is not None
3181+
self.msg.note_call(supertype, call, context)
31773182
return False
31783183

31793184
def contains_none(self, t: Type) -> bool:

mypy/constraints.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,18 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
306306
def visit_instance(self, template: Instance) -> List[Constraint]:
307307
original_actual = actual = self.actual
308308
res = [] # type: List[Constraint]
309+
if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
310+
if template.type.protocol_members == ['__call__']:
311+
# Special case: a generic callback protocol
312+
if not any(is_same_type(template, t) for t in template.type.inferring):
313+
template.type.inferring.append(template)
314+
call = mypy.subtypes.find_member('__call__', template, actual)
315+
assert call is not None
316+
if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
317+
subres = infer_constraints(call, actual, self.direction)
318+
res.extend(subres)
319+
template.type.inferring.pop()
320+
return res
309321
if isinstance(actual, CallableType) and actual.fallback is not None:
310322
actual = actual.fallback
311323
if isinstance(actual, Overloaded) and actual.fallback is not None:

mypy/join.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mypy.maptype import map_instance_to_supertype
1212
from mypy.subtypes import (
1313
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
14-
is_protocol_implementation
14+
is_protocol_implementation, find_member
1515
)
1616
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT
1717

@@ -154,6 +154,10 @@ def visit_instance(self, t: Instance) -> Type:
154154
return nominal
155155
return structural
156156
elif isinstance(self.s, FunctionLike):
157+
if t.type.is_protocol:
158+
call = unpack_callback_protocol(t)
159+
if call:
160+
return join_types(call, self.s)
157161
return join_types(t, self.s.fallback)
158162
elif isinstance(self.s, TypeType):
159163
return join_types(t, self.s)
@@ -174,8 +178,11 @@ def visit_callable_type(self, t: CallableType) -> Type:
174178
elif isinstance(self.s, Overloaded):
175179
# Switch the order of arguments to that we'll get to visit_overloaded.
176180
return join_types(t, self.s)
177-
else:
178-
return join_types(t.fallback, self.s)
181+
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
182+
call = unpack_callback_protocol(self.s)
183+
if call:
184+
return join_types(t, call)
185+
return join_types(t.fallback, self.s)
179186

180187
def visit_overloaded(self, t: Overloaded) -> Type:
181188
# This is more complex than most other cases. Here are some
@@ -224,6 +231,10 @@ def visit_overloaded(self, t: Overloaded) -> Type:
224231
else:
225232
return Overloaded(result)
226233
return join_types(t.fallback, s.fallback)
234+
elif isinstance(s, Instance) and s.type.is_protocol:
235+
call = unpack_callback_protocol(s)
236+
if call:
237+
return join_types(t, call)
227238
return join_types(t.fallback, s)
228239

229240
def visit_tuple_type(self, t: TupleType) -> Type:
@@ -436,3 +447,10 @@ def join_type_list(types: List[Type]) -> Type:
436447
for t in types[1:]:
437448
joined = join_types(joined, t)
438449
return joined
450+
451+
452+
def unpack_callback_protocol(t: Instance) -> Optional[Type]:
453+
assert t.type.is_protocol
454+
if t.type.protocol_members == ['__call__']:
455+
return find_member('__call__', t, t)
456+
return None

mypy/meet.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from collections import OrderedDict
22
from typing import List, Optional, Tuple
33

4-
from mypy.join import is_similar_callables, combine_similar_callables, join_type_list
4+
from mypy.join import (
5+
is_similar_callables, combine_similar_callables, join_type_list, unpack_callback_protocol
6+
)
57
from mypy.types import (
68
Type, AnyType, TypeVisitor, UnboundType, NoneTyp, TypeVarType, Instance, CallableType,
79
TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType,
@@ -297,12 +299,15 @@ def visit_instance(self, t: Instance) -> Type:
297299
return UninhabitedType()
298300
else:
299301
return NoneTyp()
302+
elif isinstance(self.s, FunctionLike) and t.type.is_protocol:
303+
call = unpack_callback_protocol(t)
304+
if call:
305+
return meet_types(call, self.s)
300306
elif isinstance(self.s, TypeType):
301307
return meet_types(t, self.s)
302308
elif isinstance(self.s, TupleType):
303309
return meet_types(t, self.s)
304-
else:
305-
return self.default(self.s)
310+
return self.default(self.s)
306311

307312
def visit_callable_type(self, t: CallableType) -> Type:
308313
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
@@ -313,8 +318,11 @@ def visit_callable_type(self, t: CallableType) -> Type:
313318
# Return a plain None or <uninhabited> instead of a weird function.
314319
return self.default(self.s)
315320
return result
316-
else:
317-
return self.default(self.s)
321+
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
322+
call = unpack_callback_protocol(self.s)
323+
if call:
324+
return meet_types(t, call)
325+
return self.default(self.s)
318326

319327
def visit_overloaded(self, t: Overloaded) -> Type:
320328
# TODO: Implement a better algorithm that covers at least the same cases
@@ -329,6 +337,10 @@ def visit_overloaded(self, t: Overloaded) -> Type:
329337
return t
330338
else:
331339
return meet_types(t.fallback, s.fallback)
340+
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
341+
call = unpack_callback_protocol(self.s)
342+
if call:
343+
return meet_types(t, call)
332344
return meet_types(t.fallback, s)
333345

334346
def visit_tuple_type(self, t: TupleType) -> Type:

mypy/subtypes.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,13 @@ def visit_callable_type(self, left: CallableType) -> bool:
249249
elif isinstance(right, Overloaded):
250250
return all(self._is_subtype(left, item) for item in right.items())
251251
elif isinstance(right, Instance):
252+
if right.type.is_protocol and right.type.protocol_members == ['__call__']:
253+
# OK, a callable can implement a protocol with a single `__call__` member.
254+
# TODO: we should probably explicitly exclude self-types in this case.
255+
call = find_member('__call__', right, left)
256+
assert call is not None
257+
if self._is_subtype(left, call):
258+
return True
252259
return self._is_subtype(left.fallback, right)
253260
elif isinstance(right, TypeType):
254261
# This is unsound, we don't check the __init__ signature.
@@ -315,6 +322,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
315322
def visit_overloaded(self, left: Overloaded) -> bool:
316323
right = self.right
317324
if isinstance(right, Instance):
325+
if right.type.is_protocol and right.type.protocol_members == ['__call__']:
326+
# same as for CallableType
327+
call = find_member('__call__', right, left)
328+
assert call is not None
329+
if self._is_subtype(left, call):
330+
return True
318331
return self._is_subtype(left.fallback, right)
319332
elif isinstance(right, CallableType):
320333
for item in left.items():
@@ -439,6 +452,7 @@ def f(self) -> A: ...
439452
# nominal subtyping currently ignores '__init__' and '__new__' signatures
440453
if member in ('__init__', '__new__'):
441454
continue
455+
ignore_names = member != '__call__' # __call__ can be passed kwargs
442456
# The third argument below indicates to what self type is bound.
443457
# We always bind self to the subtype. (Similarly to nominal types).
444458
supertype = find_member(member, right, left)
@@ -453,7 +467,7 @@ def f(self) -> A: ...
453467
# Nominal check currently ignores arg names
454468
# NOTE: If we ever change this, be sure to also change the call to
455469
# SubtypeVisitor.build_subtype_kind(...) down below.
456-
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
470+
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=ignore_names)
457471
else:
458472
is_compat = is_proper_subtype(subtype, supertype)
459473
if not is_compat:
@@ -476,8 +490,9 @@ def f(self) -> A: ...
476490
return False
477491

478492
if not proper_subtype:
479-
# Nominal check currently ignores arg names
480-
subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=True)
493+
# Nominal check currently ignores arg names, but __call__ is special for protocols
494+
ignore_names = right.type.protocol_members != ['__call__']
495+
subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=ignore_names)
481496
else:
482497
subtype_kind = ProperSubtypeVisitor.build_subtype_kind()
483498
TypeState.record_subtype_cache_entry(subtype_kind, left, right)

0 commit comments

Comments
 (0)