Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix slow qualified name perf in 0.4.2+ #698

Merged
merged 1 commit into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,16 +406,18 @@ class Scope(abc.ABC):
#: Refers to the GlobalScope.
globals: "GlobalScope"
_assignments: MutableMapping[str, Set[BaseAssignment]]
_accesses: MutableMapping[str, Set[Access]]
_assignment_count: int
_accesses_by_name: MutableMapping[str, Set[Access]]
_accesses_by_node: MutableMapping[cst.CSTNode, Set[Access]]

def __init__(self, parent: "Scope") -> None:
super().__init__()
self.parent = parent
self.globals = parent.globals
self._assignments = defaultdict(set)
self._accesses = defaultdict(set)
self._assignment_count = 0
self._accesses_by_name = defaultdict(set)
self._accesses_by_node = defaultdict(set)

def record_assignment(self, name: str, node: cst.CSTNode) -> None:
target = self._find_assignment_target(name)
Expand Down Expand Up @@ -446,7 +448,8 @@ def _find_assignment_target_parent(self, name: str) -> "Scope":
return self

def record_access(self, name: str, access: Access) -> None:
self._accesses[name].add(access)
self._accesses_by_name[name].add(access)
self._accesses_by_node[access.node].add(access)

def _getitem_from_self_or_parent(self, name: str) -> Set[BaseAssignment]:
"""Overridden by ClassScope to hide it's assignments from child scopes."""
Expand Down Expand Up @@ -545,12 +548,9 @@ def f(self) -> "c":
"""

# if this node is an access we know the assignment and we can use that name
node_accesses = {
access
for all_accesses in self._accesses.values()
for access in all_accesses
if access.node == node
}
node_accesses = (
self._accesses_by_node.get(node) if isinstance(node, cst.CSTNode) else None
)
if node_accesses:
return {
qname
Expand Down Expand Up @@ -589,7 +589,7 @@ def assignments(self) -> Assignments:
@property
def accesses(self) -> Accesses:
"""Return an :class:`~libcst.metadata.Accesses` contains all accesses in current scope."""
return Accesses(self._accesses)
return Accesses(self._accesses_by_name)


class BuiltinScope(Scope):
Expand Down
34 changes: 16 additions & 18 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,35 +1445,33 @@ def test_keyword_arg_in_call(self) -> None:

def test_global_contains_is_read_only(self) -> None:
gscope = GlobalScope()
before_assignments = list(gscope._assignments.items())
before_accesses = list(gscope._accesses.items())
before_assignments = list(gscope.assignments)
before_accesses = list(gscope.accesses)
self.assertFalse("doesnt_exist" in gscope)
self.assertEqual(list(gscope._accesses.items()), before_accesses)
self.assertEqual(list(gscope._assignments.items()), before_assignments)
self.assertEqual(list(gscope.accesses), before_accesses)
self.assertEqual(list(gscope.assignments), before_assignments)

def test_contains_is_read_only(self) -> None:
for s in [LocalScope, FunctionScope, ClassScope, ComprehensionScope]:
with self.subTest(scope=s):
gscope = GlobalScope()
scope = s(parent=gscope, node=cst.Name("lol"))
before_assignments = list(scope._assignments.items())
before_accesses = list(scope._accesses.items())
before_assignments = list(scope.assignments)
before_accesses = list(scope.accesses)
before_overwrites = list(scope._scope_overwrites.items())
before_parent_assignments = list(scope.parent._assignments.items())
before_parent_accesses = list(scope.parent._accesses.items())
before_parent_assignments = list(scope.parent.assignments)
before_parent_accesses = list(scope.parent.accesses)

self.assertFalse("doesnt_exist" in scope)
self.assertEqual(list(scope._accesses.items()), before_accesses)
self.assertEqual(list(scope._assignments.items()), before_assignments)
self.assertEqual(list(scope.accesses), before_accesses)
self.assertEqual(list(scope.assignments), before_assignments)
self.assertEqual(
list(scope._scope_overwrites.items()), before_overwrites
)
self.assertEqual(
list(scope.parent._assignments.items()), before_parent_assignments
)
self.assertEqual(
list(scope.parent._accesses.items()), before_parent_accesses
list(scope.parent.assignments), before_parent_assignments
)
self.assertEqual(list(scope.parent.accesses), before_parent_accesses)

def test_attribute_of_function_call(self) -> None:
get_scope_metadata_provider("foo().bar")
Expand All @@ -1496,11 +1494,11 @@ def test_get_qualified_names_for_is_read_only(self) -> None:
)
a = m.body[0]
scope = scopes[a]
assignments_len_before = len(scope._assignments)
accesses_len_before = len(scope._accesses)
assignments_before = list(scope.assignments)
accesses_before = list(scope.accesses)
scope.get_qualified_names_for("doesnt_exist")
self.assertEqual(len(scope._assignments), assignments_len_before)
self.assertEqual(len(scope._accesses), accesses_len_before)
self.assertEqual(list(scope.assignments), assignments_before)
self.assertEqual(list(scope.accesses), accesses_before)

def test_gen_dotted_names(self) -> None:
names = {name for name, node in _gen_dotted_names(cst.Name(value="a"))}
Expand Down