Skip to content

Commit

Permalink
Gathering exports in augmented assignment statements
Browse files Browse the repository at this point in the history
  • Loading branch information
Kronuz committed Dec 15, 2020
1 parent 8eee3cc commit f2c0c0a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
14 changes: 14 additions & 0 deletions libcst/codemod/visitors/_gather_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Set, Union

import libcst as cst
import libcst.matchers as m
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.helpers import get_full_name_for_node
Expand Down Expand Up @@ -53,6 +54,19 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
return True
return False

def visit_AugAssign(self, node: cst.AugAssign) -> bool:
if m.matches(
node,
m.AugAssign(
target=m.Name("__all__"),
operator=m.AddAssign,
value=m.List() | m.Tuple(),
),
):
self._is_assigned_export.add(node.value)
return True
return False

def visit_Assign(self, node: cst.Assign) -> bool:
for target_node in node.targets:
if self._handle_assign_target(target_node.target, node.value):
Expand Down
12 changes: 12 additions & 0 deletions libcst/codemod/visitors/tests/test_gather_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def test_gather_exports_simple(self) -> None:
gatherer = self.gather_exports(code)
self.assertEqual(gatherer.explicit_exported_objects, {"bar", "baz"})

def test_gather_exports_simple2(self) -> None:
code = """
from foo import bar
from biz import baz
__all__ = ["bar"]
__all__ += ["baz"]
"""

gatherer = self.gather_exports(code)
self.assertEqual(gatherer.explicit_exported_objects, {"bar", "baz"})

def test_gather_exports_simple_set(self) -> None:
code = """
from foo import bar
Expand Down

0 comments on commit f2c0c0a

Please sign in to comment.