Skip to content

Commit d6b6480

Browse files
0x000Atushar-deepsource
authored andcommitted
stubgen: preserve string arguments in annotations (python#11292)
Fixes python#11222
1 parent 61ecaf8 commit d6b6480

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

mypy/stubgen.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def visit_unbound_type(self, t: UnboundType) -> str:
255255
s = t.name
256256
self.stubgen.import_tracker.require_name(s)
257257
if t.args:
258-
s += '[{}]'.format(self.list_str(t.args))
258+
s += '[{}]'.format(self.args_str(t.args))
259259
return s
260260

261261
def visit_none_type(self, t: NoneType) -> str:
@@ -264,6 +264,22 @@ def visit_none_type(self, t: NoneType) -> str:
264264
def visit_type_list(self, t: TypeList) -> str:
265265
return '[{}]'.format(self.list_str(t.items))
266266

267+
def args_str(self, args: Iterable[Type]) -> str:
268+
"""Convert an array of arguments to strings and join the results with commas.
269+
270+
The main difference from list_str is the preservation of quotes for string
271+
arguments
272+
"""
273+
types = ['builtins.bytes', 'builtins.unicode']
274+
res = []
275+
for arg in args:
276+
arg_str = arg.accept(self)
277+
if isinstance(arg, UnboundType) and arg.original_str_fallback in types:
278+
res.append("'{}'".format(arg_str))
279+
else:
280+
res.append(arg_str)
281+
return ', '.join(res)
282+
267283

268284
class AliasPrinter(NodeVisitor[str]):
269285
"""Visitor used to collect type aliases _and_ type variable definitions.

test-data/unit/stubgen.test

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,25 @@ def g(x: Foo = Foo()) -> Bar: ...
6464
def f(x: Foo) -> Bar: ...
6565
def g(x: Foo = ...) -> Bar: ...
6666

67+
[case testPreserveFunctionAnnotationWithArgs]
68+
def f(x: foo['x']) -> bar: ...
69+
def g(x: foo[x]) -> bar: ...
70+
def h(x: foo['x', 'y']) -> bar: ...
71+
def i(x: foo[x, y]) -> bar: ...
72+
def j(x: foo['x', y]) -> bar: ...
73+
def k(x: foo[x, 'y']) -> bar: ...
74+
def lit_str(x: Literal['str']) -> Literal['str']: ...
75+
def lit_int(x: Literal[1]) -> Literal[1]: ...
76+
[out]
77+
def f(x: foo['x']) -> bar: ...
78+
def g(x: foo[x]) -> bar: ...
79+
def h(x: foo['x', 'y']) -> bar: ...
80+
def i(x: foo[x, y]) -> bar: ...
81+
def j(x: foo['x', y]) -> bar: ...
82+
def k(x: foo[x, 'y']) -> bar: ...
83+
def lit_str(x: Literal['str']) -> Literal['str']: ...
84+
def lit_int(x: Literal[1]) -> Literal[1]: ...
85+
6786
[case testPreserveVarAnnotation]
6887
x: Foo
6988
[out]

0 commit comments

Comments
 (0)