Skip to content

Commit

Permalink
Add Access.is_type_hint for types used in classdef base and assignmen…
Browse files Browse the repository at this point in the history
…t values (#406)
  • Loading branch information
luciawlli authored Oct 28, 2020
1 parent 01c8098 commit a1b1ae4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
36 changes: 29 additions & 7 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,17 @@ def __new__(cls) -> "Tree":

is_annotation: bool

is_type_hint: bool

__assignments: Set["BaseAssignment"]

def __init__(self, node: cst.Name, scope: "Scope", is_annotation: bool) -> None:
def __init__(
self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
) -> None:
self.node = node
self.scope = scope
self.is_annotation = is_annotation
self.is_type_hint = is_type_hint
self.__assignments = set()

def __hash__(self) -> int:
Expand Down Expand Up @@ -646,7 +651,9 @@ def __init__(self, provider: "ScopeProvider") -> None:
self.__in_annotation: Set[
Union[cst.Call, cst.Annotation, cst.Subscript]
] = set()
self.__in_type_hint: Set[Union[cst.Call, cst.Annotation, cst.Subscript]] = set()
self.__in_ignored_subscript: Set[cst.Subscript] = set()
self.__ignore_annotation: int = 0

@contextmanager
def _new_scope(
Expand Down Expand Up @@ -705,15 +712,15 @@ def visit_Call(self, node: cst.Call) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node)
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
node.func.visit(self)
self.__in_annotation.add(node)
self.__in_type_hint.add(node)
for arg in node.args[1:]:
arg.visit(self)
return False
return True

def leave_Call(self, original_node: cst.Call) -> None:
self.__top_level_attribute_stack.pop()
self.__in_annotation.discard(original_node)
self.__in_type_hint.discard(original_node)

def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]:
self.__in_annotation.add(node)
Expand All @@ -732,7 +739,9 @@ def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[boo
def _handle_string_annotation(
self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> None:
if self.__in_annotation and not self.__in_ignored_subscript:
if (
self.__in_type_hint or self.__in_annotation
) and not self.__in_ignored_subscript:
value = node.evaluated_value
if value:
mod = cst.parse_module(value)
Expand All @@ -741,15 +750,15 @@ def _handle_string_annotation(
def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node.value)
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
self.__in_annotation.add(node)
self.__in_type_hint.add(node)
if any(
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
):
self.__in_ignored_subscript.add(node)
return True

def leave_Subscript(self, original_node: cst.Subscript) -> None:
self.__in_annotation.discard(original_node)
self.__in_type_hint.discard(original_node)
self.__in_ignored_subscript.discard(original_node)

def visit_Name(self, node: cst.Name) -> Optional[bool]:
Expand All @@ -758,7 +767,14 @@ def visit_Name(self, node: cst.Name) -> Optional[bool]:
if context == ExpressionContext.STORE:
self.scope.record_assignment(node.value, node)
elif context in (ExpressionContext.LOAD, ExpressionContext.DEL, None):
access = Access(node, self.scope, is_annotation=bool(self.__in_annotation))
access = Access(
node,
self.scope,
is_annotation=bool(
self.__in_annotation and not self.__ignore_annotation
),
is_type_hint=bool(self.__in_type_hint),
)
self.__deferred_accesses.append(
(access, self.__top_level_attribute_stack[-1])
)
Expand Down Expand Up @@ -817,6 +833,12 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
statement.visit(self)
return False

def visit_ClassDef_bases(self, node: cst.ClassDef) -> None:
self.__ignore_annotation += 1

def leave_ClassDef_bases(self, node: cst.ClassDef) -> None:
self.__ignore_annotation -= 1

def visit_Global(self, node: cst.Global) -> Optional[bool]:
for name_item in node.names:
self.scope.record_global_overwrite(name_item.name.value)
Expand Down
32 changes: 26 additions & 6 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,8 @@ def g():
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, NewType, Optional, TypeVar
from a import A, B, C, D, E, F, G, H
from typing import Literal, NewType, Optional, TypeVar, Callable
from a import A, B, C, D, E, F, G, H, I, J
def x(a: A):
pass
def y(b: "B"):
Expand All @@ -1031,6 +1031,10 @@ def z(c: Literal["C"]):
FType = TypeVar("F")
GType = NewType("GType", "Optional[G]")
HType = Optional["H"]
IType = Callable[..., I]
class Test(Generic[J]):
pass
"""
)
imp = ensure_type(
Expand Down Expand Up @@ -1058,13 +1062,15 @@ def z(c: Literal["C"]):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1074,13 +1080,27 @@ def z(c: Literal["C"]):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["I"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

assignment = list(scope["J"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
Expand Down

0 comments on commit a1b1ae4

Please sign in to comment.