Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle string type references in cast() #418

Merged
merged 2 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,13 +709,19 @@ def visit_Attribute(self, node: cst.Attribute) -> Optional[bool]:

def visit_Call(self, node: cst.Call) -> Optional[bool]:
self.__top_level_attribute_stack.append(None)
qnames = self.scope.get_qualified_names_for(node)
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node)}
if "typing.NewType" in qnames or "typing.TypeVar" in qnames:
node.func.visit(self)
self.__in_type_hint.add(node)
for arg in node.args[1:]:
arg.visit(self)
return False
if "typing.cast" in qnames:
node.func.visit(self)
self.__in_type_hint.add(node)
if len(node.args) > 0:
node.args[0].visit(self)
return False
return True

def leave_Call(self, original_node: cst.Call) -> None:
Expand Down Expand Up @@ -750,12 +756,10 @@ def _handle_string_annotation(
return False

def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node.value)
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
qnames = {qn.name for qn in self.scope.get_qualified_names_for(node.value)}
if any(qn.startswith(("typing.", "typing_extensions.")) for qn in qnames):
self.__in_type_hint.add(node)
if any(
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
):
if "typing.Literal" in qnames or "typing_extensions.Literal" in qnames:
self.__in_ignored_subscript.add(node)
return True

Expand Down
33 changes: 28 additions & 5 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,23 +1037,24 @@ def g():
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, NewType, Optional, TypeVar, Callable
from a import A, B, C, D, E, F, G, H, I, J
from typing import Literal, NewType, Optional, TypeVar, Callable, cast
from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2
def x(a: A):
pass
def y(b: "B"):
pass
def z(c: Literal["C"]):
pass
DType = TypeVar("DType", bound=D)
EType = TypeVar("EType", bound="E")
DType = TypeVar("D2", bound=D)
EType = TypeVar("E2", bound="E")
FType = TypeVar("F")
GType = NewType("GType", "Optional[G]")
GType = NewType("G2", "Optional[G]")
HType = Optional["H"]
IType = Callable[..., I]

class Test(Generic[J]):
pass
casted = cast("K", "K2")
"""
)
imp = ensure_type(
Expand Down Expand Up @@ -1084,13 +1085,21 @@ class Test(Generic[J]):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["D2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["E2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)
Expand All @@ -1102,6 +1111,10 @@ class Test(Generic[J]):
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["G2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
Expand All @@ -1121,6 +1134,16 @@ class Test(Generic[J]):
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

assignment = list(scope["K"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

assignment = list(scope["K2"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 0)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
Expand Down