Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix is_annotation for types used in classdef base and assign value #406

Merged
merged 4 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,17 @@ def __new__(cls) -> "Tree":

is_annotation: bool

is_type_hint: bool

__assignments: Set["BaseAssignment"]

def __init__(self, node: cst.Name, scope: "Scope", is_annotation: bool) -> None:
def __init__(
self, node: cst.Name, scope: "Scope", is_annotation: bool, is_type_hint: bool
) -> None:
self.node = node
self.scope = scope
self.is_annotation = is_annotation
self.is_type_hint = is_type_hint
self.__assignments = set()

def __hash__(self) -> int:
Expand Down Expand Up @@ -646,7 +651,9 @@ def __init__(self, provider: "ScopeProvider") -> None:
self.__in_annotation: Set[
Union[cst.Call, cst.Annotation, cst.Subscript]
] = set()
self.__in_type_hint: Set[Union[cst.Call, cst.Annotation, cst.Subscript]] = set()
self.__in_ignored_subscript: Set[cst.Subscript] = set()
self.__ignore_annotation: int = 0

@contextmanager
def _new_scope(
Expand Down Expand Up @@ -705,15 +712,15 @@ def visit_Call(self, node: cst.Call) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node)
if any(qn.name in {"typing.NewType", "typing.TypeVar"} for qn in qnames):
node.func.visit(self)
self.__in_annotation.add(node)
self.__in_type_hint.add(node)
for arg in node.args[1:]:
arg.visit(self)
return False
return True

def leave_Call(self, original_node: cst.Call) -> None:
self.__top_level_attribute_stack.pop()
self.__in_annotation.discard(original_node)
self.__in_type_hint.discard(original_node)

def visit_Annotation(self, node: cst.Annotation) -> Optional[bool]:
self.__in_annotation.add(node)
Expand All @@ -732,7 +739,9 @@ def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> Optional[boo
def _handle_string_annotation(
self, node: Union[cst.SimpleString, cst.ConcatenatedString]
) -> None:
if self.__in_annotation and not self.__in_ignored_subscript:
if (
self.__in_type_hint or self.__in_annotation
) and not self.__in_ignored_subscript:
value = node.evaluated_value
if value:
mod = cst.parse_module(value)
Expand All @@ -741,15 +750,15 @@ def _handle_string_annotation(
def visit_Subscript(self, node: cst.Subscript) -> Optional[bool]:
qnames = self.scope.get_qualified_names_for(node.value)
if any(qn.name.startswith(("typing.", "typing_extensions.")) for qn in qnames):
self.__in_annotation.add(node)
self.__in_type_hint.add(node)
if any(
qn.name in {"typing.Literal", "typing_extensions.Literal"} for qn in qnames
):
self.__in_ignored_subscript.add(node)
return True

def leave_Subscript(self, original_node: cst.Subscript) -> None:
self.__in_annotation.discard(original_node)
self.__in_type_hint.discard(original_node)
self.__in_ignored_subscript.discard(original_node)

def visit_Name(self, node: cst.Name) -> Optional[bool]:
Expand All @@ -758,7 +767,14 @@ def visit_Name(self, node: cst.Name) -> Optional[bool]:
if context == ExpressionContext.STORE:
self.scope.record_assignment(node.value, node)
elif context in (ExpressionContext.LOAD, ExpressionContext.DEL, None):
access = Access(node, self.scope, is_annotation=bool(self.__in_annotation))
access = Access(
node,
self.scope,
is_annotation=bool(
self.__in_annotation and not self.__ignore_annotation
),
is_type_hint=bool(self.__in_type_hint),
)
self.__deferred_accesses.append(
(access, self.__top_level_attribute_stack[-1])
)
Expand Down Expand Up @@ -817,6 +833,12 @@ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
statement.visit(self)
return False

def visit_ClassDef_bases(self, node: cst.ClassDef) -> None:
self.__ignore_annotation += 1

def leave_ClassDef_bases(self, node: cst.ClassDef) -> None:
luciawlli marked this conversation as resolved.
Show resolved Hide resolved
self.__ignore_annotation -= 1

def visit_Global(self, node: cst.Global) -> Optional[bool]:
for name_item in node.names:
self.scope.record_global_overwrite(name_item.name.value)
Expand Down
32 changes: 26 additions & 6 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,8 @@ def g():
def test_annotation_access(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
from typing import Literal, NewType, Optional, TypeVar
from a import A, B, C, D, E, F, G, H
from typing import Literal, NewType, Optional, TypeVar, Callable
from a import A, B, C, D, E, F, G, H, I, J
def x(a: A):
pass
def y(b: "B"):
Expand All @@ -1031,6 +1031,10 @@ def z(c: Literal["C"]):
FType = TypeVar("F")
GType = NewType("GType", "Optional[G]")
HType = Optional["H"]
IType = Callable[..., I]

class Test(Generic[J]):
pass
"""
)
imp = ensure_type(
Expand Down Expand Up @@ -1058,13 +1062,15 @@ def z(c: Literal["C"]):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["E"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["F"])[0]
self.assertIsInstance(assignment, Assignment)
Expand All @@ -1074,13 +1080,27 @@ def z(c: Literal["C"]):
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["H"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertTrue(references[0].is_annotation)
self.assertFalse(references[0].is_annotation)
self.assertTrue(references[0].is_type_hint)

assignment = list(scope["I"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

assignment = list(scope["J"])[0]
self.assertIsInstance(assignment, Assignment)
self.assertEqual(len(assignment.references), 1)
references = list(assignment.references)
self.assertFalse(references[0].is_annotation)

def test_node_of_scopes(self) -> None:
m, scopes = get_scope_metadata_provider(
Expand Down