From 21d957afc6241c57eaa364c70920e3927ff94fdd Mon Sep 17 00:00:00 2001 From: citruz <3756270+citruz@users.noreply.github.com> Date: Fri, 18 Mar 2022 02:57:02 +0100 Subject: [PATCH] stubgen: fix handling of Protocol and add testcase (#12129) ### Description This PR fixes #12072 by correctly handling `Protocol` definitions. Previously, the `Protocol` base class was removed when generating type stubs which causes problems with other packages that want to use type definition (because they see it as a regular class, not as a `Protocol`). ## Test Plan Added a testcase to the stubgen testset. Co-authored-by: 97littleleaf11 <11172084+97littleleaf11@users.noreply.github.com> --- mypy/stubgen.py | 6 ++++++ test-data/unit/stubgen.test | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index dafb446a835a..6db5aa75d102 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -851,6 +851,12 @@ def visit_class_def(self, o: ClassDef) -> None: base_types.append('metaclass=abc.ABCMeta') self.import_tracker.add_import('abc') self.import_tracker.require_name('abc') + elif self.analyzed and o.info.is_protocol: + type_str = 'Protocol' + if o.info.type_vars: + type_str += f'[{", ".join(o.info.type_vars)}]' + base_types.append(type_str) + self.add_typing_import('Protocol') if base_types: self.add('(%s)' % ', '.join(base_types)) self.add(':\n') diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 6592791f9aa3..62fae21df4f4 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -2581,6 +2581,30 @@ def f(x: int, y: int) -> int: ... @t.overload def f(x: t.Tuple[int, int]) -> int: ... +[case testProtocol_semanal] +from typing import Protocol, TypeVar + +class P(Protocol): + def f(self, x: int, y: int) -> str: + ... + +T = TypeVar('T') +T2 = TypeVar('T2') +class PT(Protocol[T, T2]): + def f(self, x: T) -> T2: + ... + +[out] +from typing import Protocol, TypeVar + +class P(Protocol): + def f(self, x: int, y: int) -> str: ... +T = TypeVar('T') +T2 = TypeVar('T2') + +class PT(Protocol[T, T2]): + def f(self, x: T) -> T2: ... + [case testNonDefaultKeywordOnlyArgAfterAsterisk] def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ... [out]