Skip to content

Commit

Permalink
stubgen: fix handling of Protocol and add testcase (#12129)
Browse files Browse the repository at this point in the history
### Description

This PR fixes #12072 by correctly handling `Protocol` definitions. Previously, the `Protocol` base class was removed when generating type stubs which causes problems with other packages that want to use type definition (because they see it as a regular class, not as a `Protocol`).

## Test Plan

Added a testcase to the stubgen testset.

Co-authored-by: 97littleleaf11 <11172084+97littleleaf11@users.noreply.github.com>
  • Loading branch information
citruz and 97littleleaf11 committed Mar 18, 2022
1 parent aa0d186 commit 21d957a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
6 changes: 6 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,12 @@ def visit_class_def(self, o: ClassDef) -> None:
base_types.append('metaclass=abc.ABCMeta')
self.import_tracker.add_import('abc')
self.import_tracker.require_name('abc')
elif self.analyzed and o.info.is_protocol:
type_str = 'Protocol'
if o.info.type_vars:
type_str += f'[{", ".join(o.info.type_vars)}]'
base_types.append(type_str)
self.add_typing_import('Protocol')
if base_types:
self.add('(%s)' % ', '.join(base_types))
self.add(':\n')
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -2581,6 +2581,30 @@ def f(x: int, y: int) -> int: ...
@t.overload
def f(x: t.Tuple[int, int]) -> int: ...

[case testProtocol_semanal]
from typing import Protocol, TypeVar

class P(Protocol):
def f(self, x: int, y: int) -> str:
...

T = TypeVar('T')
T2 = TypeVar('T2')
class PT(Protocol[T, T2]):
def f(self, x: T) -> T2:
...

[out]
from typing import Protocol, TypeVar

class P(Protocol):
def f(self, x: int, y: int) -> str: ...
T = TypeVar('T')
T2 = TypeVar('T2')

class PT(Protocol[T, T2]):
def f(self, x: T) -> T2: ...

[case testNonDefaultKeywordOnlyArgAfterAsterisk]
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
[out]
Expand Down

0 comments on commit 21d957a

Please sign in to comment.