From c586ef0987681f2f649a423fec7b1afeac6b6039 Mon Sep 17 00:00:00 2001 From: David Liu Date: Mon, 19 Jul 2021 22:57:50 -0400 Subject: [PATCH 1/3] Improve variable lookup to ignore exclusive statements --- ChangeLog | 4 + astroid/node_classes.py | 15 ++- tests/unittest_lookup.py | 270 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 3 deletions(-) diff --git a/ChangeLog b/ChangeLog index ca355f1b39..8faed02351 100644 --- a/ChangeLog +++ b/ChangeLog @@ -13,6 +13,10 @@ Release date: TBA * Added support to infer return type of ``typing.cast()`` +* Fix variable lookup's handling of exclusive statements + + Closes PyCQA/pylint#3711 + What's New in astroid 2.6.5? ============================ diff --git a/astroid/node_classes.py b/astroid/node_classes.py index e9b2fcfaf2..0226b096de 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1211,17 +1211,26 @@ def _filter_stmts(self, stmts, frame, offset): if not (optional_assign or are_exclusive(_stmts[pindex], node)): del _stmt_parents[pindex] del _stmts[pindex] + + # If self and node are exclusive, then we can ignore node + if are_exclusive(self, node): + continue + if isinstance(node, AssignName): + # Remove all previously stored assignments if: + # 1. node's statement always assigns + # 2. node has the same parent as self (i.e., they're in the same block) if not optional_assign and stmt.parent is mystmt.parent: _stmts = [] _stmt_parents = [] elif isinstance(node, DelName): + # Remove all previously stored assignments _stmts = [] _stmt_parents = [] continue - if not are_exclusive(self, node): - _stmts.append(node) - _stmt_parents.append(stmt.parent) + # Add the new assignment + _stmts.append(node) + _stmt_parents.append(stmt.parent) return _stmts diff --git a/tests/unittest_lookup.py b/tests/unittest_lookup.py index f8b0b66b21..1d6e77b598 100644 --- a/tests/unittest_lookup.py +++ b/tests/unittest_lookup.py @@ -474,5 +474,275 @@ def run1(): self.assertEqual(len(stmts), 0) +class LookupControlFlowTest(unittest.TestCase): + """Tests for lookup capabilities and control flow""" + + def test_consecutive_assign(self): + """When multiple assignment statements are in the same block, only the last one + is returned. + """ + code = """ + x = 10 + x = 100 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 3) + + def test_assign_after_use(self): + """An assignment statement appearing after the variable is not returned.""" + code = """ + print(x) + x = 10 + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 0) + + def test_del_removes_prior(self): + """Delete statement removes any prior assignments""" + code = """ + x = 10 + del x + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 0) + + def test_del_no_effect_after(self): + """Delete statement doesn't remove future assignments""" + code = """ + x = 10 + del x + x = 100 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 4) + + def test_if_assign(self): + """Assignment in if statement is added to lookup results, but does not replace + prior assignments. + """ + code = """ + def f(b): + x = 10 + if b: + x = 100 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 2) + self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 5]) + + def test_if_assigns_same_branch(self): + """When if branch has multiple assignment statements, only the last one + is added. + """ + code = """ + def f(b): + x = 10 + if b: + x = 100 + x = 1000 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 2) + self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 6]) + + def test_if_assigns_different_branch(self): + """When different branches have assignment statements, the last one + in each branch is added. + """ + code = """ + def f(b): + x = 10 + if b == 1: + x = 100 + x = 1000 + elif b == 2: + x = 3 + elif b == 3: + x = 4 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 4) + self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 6, 8, 10]) + + def test_assign_exclusive(self): + """When the variable appears inside a branch of an if statement, + no assignment statements from other branches are returned. + """ + code = """ + def f(b): + x = 10 + if b == 1: + x = 100 + x = 1000 + elif b == 2: + x = 3 + elif b == 3: + x = 4 + else: + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 3) + + def test_assign_not_exclusive(self): + """When the variable appears inside a branch of an if statement, + only the last assignment statement in the same branch is returned. + """ + code = """ + def f(b): + x = 10 + if b == 1: + x = 100 + x = 1000 + elif b == 2: + x = 3 + elif b == 3: + x = 4 + print(x) + else: + x = 5 + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 10) + + def test_if_else(self): + """When an assignment statement appears in both an if and else branch, both + are added. This does NOT replace an assignment statement appearing before the + if statement. (See issue #213) + """ + code = """ + def f(b): + x = 10 + if b: + x = 100 + else: + x = 1000 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 3) + self.assertCountEqual([stmt.lineno for stmt in stmts], [3, 5, 7]) + + def test_if_variable_in_condition_1(self): + """Test lookup works correctly when a variable appears in an if condition.""" + code = """ + x = 10 + if x > 10: + print('a') + elif x > 0: + print('b') + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 2) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 2) + + def test_if_variable_in_condition_2(self): + """Test lookup works correctly when a variable appears in an if condition, + and the variable is reassigned in each branch. + + This is based on PyCQA/pylint issue #3711. + """ + code = """ + x = 10 + if x > 10: + x = 100 + elif x > 0: + x = 200 + elif x > -10: + x = 300 + else: + x = 400 + """ + astroid = builder.parse(code) + x_names = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"] + + # All lookups should refer only to the initial x = 10. + for x_name in x_names: + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 2) + + def test_del_not_exclusive(self): + """A delete statement in an if statement branch removes all previous + assignment statements when the delete statement is not exclusive with + the variable (e.g., when the variable is used below the if statement). + """ + code = """ + def f(b): + x = 10 + if b == 1: + x = 100 + elif b == 2: + del x + elif b == 3: + x = 4 # Only this assignment statement is returned + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 9) + + def test_del_exclusive(self): + """A delete statement in an if statement branch that is exclusive with the + variable does not remove previous assignment statements. + """ + code = """ + def f(b): + x = 10 + if b == 1: + x = 100 + elif b == 2: + del x + else: + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 3) + + if __name__ == "__main__": unittest.main() From 7f07d0dc9ed80e1d78e3f7e09dbe393970d61795 Mon Sep 17 00:00:00 2001 From: David Liu Date: Tue, 20 Jul 2021 16:36:21 -0400 Subject: [PATCH 2/3] Improve variable lookup to handle function parameters being overwritten --- ChangeLog | 4 ++ astroid/node_classes.py | 17 ++++-- tests/unittest_inference.py | 18 ++++++ tests/unittest_lookup.py | 108 +++++++++++++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 5 deletions(-) diff --git a/ChangeLog b/ChangeLog index 8faed02351..92919c4d88 100644 --- a/ChangeLog +++ b/ChangeLog @@ -17,6 +17,10 @@ Release date: TBA Closes PyCQA/pylint#3711 +* Fix variable lookup's handling of function parameters + + Closes PyCQA/astroid#180 + What's New in astroid 2.6.5? ============================ diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 0226b096de..042bbb8cf3 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1216,10 +1216,10 @@ def _filter_stmts(self, stmts, frame, offset): if are_exclusive(self, node): continue + # An AssignName node overrides previous assignments if: + # 1. node's statement always assigns + # 2. node and self are in the same block (i.e., has the same parent as self) if isinstance(node, AssignName): - # Remove all previously stored assignments if: - # 1. node's statement always assigns - # 2. node has the same parent as self (i.e., they're in the same block) if not optional_assign and stmt.parent is mystmt.parent: _stmts = [] _stmt_parents = [] @@ -1230,7 +1230,16 @@ def _filter_stmts(self, stmts, frame, offset): continue # Add the new assignment _stmts.append(node) - _stmt_parents.append(stmt.parent) + if isinstance(node, Arguments) or isinstance(node.parent, Arguments): + # Special case for _stmt_parents when node is a function parameter; + # in this case, stmt is the enclosing FunctionDef, which is what we + # want to add to _stmt_parents, not stmt.parent. This case occurs when + # node is an Arguments node (representing varargs or kwargs parameter), + # and when node.parent is an Arguments node (other parameters). + # See issue #180. + _stmt_parents.append(stmt) + else: + _stmt_parents.append(stmt.parent) return _stmts diff --git a/tests/unittest_inference.py b/tests/unittest_inference.py index 591696f9e8..5c09fbda12 100644 --- a/tests/unittest_inference.py +++ b/tests/unittest_inference.py @@ -4805,6 +4805,24 @@ def test(*args): return args inferred = next(node.infer()) self.assertEqual(inferred, util.Uninferable) + def test_args_overwritten(self): + # https://github.com/PyCQA/astroid/issues/180 + node = extract_node( + """ + next = 42 + def wrapper(next=next): + next = 24 + def test(): + return next + return test + wrapper()() #@ + """ + ) + inferred = node.inferred() + self.assertEqual(len(inferred), 1) + self.assertIsInstance(inferred[0], nodes.Const, inferred[0]) + self.assertEqual(inferred[0].value, 24) + class SliceTest(unittest.TestCase): def test_slice(self): diff --git a/tests/unittest_lookup.py b/tests/unittest_lookup.py index 1d6e77b598..86f9e8cc95 100644 --- a/tests/unittest_lookup.py +++ b/tests/unittest_lookup.py @@ -18,7 +18,7 @@ import functools import unittest -from astroid import builder, nodes, scoped_nodes +from astroid import builder, nodes, scoped_nodes, test_utils from astroid.exceptions import ( AttributeInferenceError, InferenceError, @@ -743,6 +743,112 @@ def f(b): self.assertEqual(len(stmts), 1) self.assertEqual(stmts[0].lineno, 3) + def test_assign_after_param(self): + """When an assignment statement overwrites a function parameter, only the + assignment is returned, even when the variable and assignment do not have + the same parent. + """ + code = """ + def f1(x): + x = 100 + print(x) + + def f2(x): + x = 100 + if True: + print(x) + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + def test_assign_after_kwonly_param(self): + """When an assignment statement overwrites a function keyword-only parameter, + only the assignment is returned, even when the variable and assignment do + not have the same parent. + """ + code = """ + def f1(*, x): + x = 100 + print(x) + + def f2(*, x): + x = 100 + if True: + print(x) + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + @test_utils.require_version(minver="3.8") + def test_assign_after_posonly_param(self): + """When an assignment statement overwrites a function positional-only parameter, + only the assignment is returned, even when the variable and assignment do + not have the same parent. + """ + code = """ + def f1(x, /): + x = 100 + print(x) + + def f2(x, /): + x = 100 + if True: + print(x) + """ + astroid = builder.parse(code) + x_name1, x_name2 = ( + n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x" + ) + _, stmts1 = x_name1.lookup("x") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + _, stmts2 = x_name2.lookup("x") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + def test_assign_after_args_param(self): + """When an assignment statement overwrites a function parameter, only the + assignment is returned. + """ + code = """ + def f(*args, **kwargs): + args = [100] + kwargs = {} + if True: + print(args, kwargs) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "args"][0] + _, stmts1 = x_name.lookup("args") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 3) + + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "kwargs"][ + 0 + ] + _, stmts2 = x_name.lookup("kwargs") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 4) + if __name__ == "__main__": unittest.main() From 35beea0a2628b005df358c294a41f7af6a8312fc Mon Sep 17 00:00:00 2001 From: David Liu Date: Tue, 20 Jul 2021 17:12:14 -0400 Subject: [PATCH 3/3] Improve variable lookup to handle except clause variable scope --- ChangeLog | 2 + astroid/node_classes.py | 13 ++- tests/unittest_lookup.py | 168 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 1 deletion(-) diff --git a/ChangeLog b/ChangeLog index 92919c4d88..01c884a8ad 100644 --- a/ChangeLog +++ b/ChangeLog @@ -21,6 +21,8 @@ Release date: TBA Closes PyCQA/astroid#180 +* Fix variable lookup's handling of except clause variables + What's New in astroid 2.6.5? ============================ diff --git a/astroid/node_classes.py b/astroid/node_classes.py index 042bbb8cf3..00c1f33485 100644 --- a/astroid/node_classes.py +++ b/astroid/node_classes.py @@ -1220,7 +1220,18 @@ def _filter_stmts(self, stmts, frame, offset): # 1. node's statement always assigns # 2. node and self are in the same block (i.e., has the same parent as self) if isinstance(node, AssignName): - if not optional_assign and stmt.parent is mystmt.parent: + if isinstance(stmt, ExceptHandler): + # If node's statement is an ExceptHandler, then it is the variable + # bound to the caught exception. If self is not contained within + # the exception handler block, node should override previous assignments; + # otherwise, node should be ignored, as an exception variable + # is local to the handler block. + if stmt.parent_of(self): + _stmts = [] + _stmt_parents = [] + else: + continue + elif not optional_assign and stmt.parent is mystmt.parent: _stmts = [] _stmt_parents = [] elif isinstance(node, DelName): diff --git a/tests/unittest_lookup.py b/tests/unittest_lookup.py index 86f9e8cc95..d1ced9bed4 100644 --- a/tests/unittest_lookup.py +++ b/tests/unittest_lookup.py @@ -849,6 +849,174 @@ def f(*args, **kwargs): self.assertEqual(len(stmts2), 1) self.assertEqual(stmts2[0].lineno, 4) + def test_except_var_in_block(self): + """When the variable bound to an exception in an except clause, it is returned + when that variable is used inside the except block. + """ + code = """ + try: + 1 / 0 + except ZeroDivisionError as e: + print(e) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "e"][0] + _, stmts = x_name.lookup("e") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 4) + + def test_except_var_in_block_overwrites(self): + """When the variable bound to an exception in an except clause, it is returned + when that variable is used inside the except block, and replaces any previous + assignments. + """ + code = """ + e = 0 + try: + 1 / 0 + except ZeroDivisionError as e: + print(e) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "e"][0] + _, stmts = x_name.lookup("e") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 5) + + def test_except_var_in_multiple_blocks(self): + """When multiple variables with the same name are bound to an exception + in an except clause, and the variable is used inside the except block, + only the assignment from the corresponding except clause is returned. + """ + code = """ + e = 0 + try: + 1 / 0 + except ZeroDivisionError as e: + print(e) + except NameError as e: + print(e) + """ + astroid = builder.parse(code) + x_names = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "e"] + + _, stmts1 = x_names[0].lookup("e") + self.assertEqual(len(stmts1), 1) + self.assertEqual(stmts1[0].lineno, 5) + + _, stmts2 = x_names[1].lookup("e") + self.assertEqual(len(stmts2), 1) + self.assertEqual(stmts2[0].lineno, 7) + + def test_except_var_after_block_single(self): + """When the variable bound to an exception in an except clause, it is NOT returned + when that variable is used after the except block. + """ + code = """ + try: + 1 / 0 + except NameError as e: + pass + print(e) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "e"][0] + _, stmts = x_name.lookup("e") + self.assertEqual(len(stmts), 0) + + def test_except_var_after_block_multiple(self): + """When the variable bound to an exception in multiple except clauses, it is NOT returned + when that variable is used after the except blocks. + """ + code = """ + try: + 1 / 0 + except NameError as e: + pass + except ZeroDivisionError as e: + pass + print(e) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "e"][0] + _, stmts = x_name.lookup("e") + self.assertEqual(len(stmts), 0) + + def test_except_assign_in_block(self): + """When a variable is assigned in an except block, it is returned + when that variable is used in the except block. + """ + code = """ + try: + 1 / 0 + except ZeroDivisionError as e: + x = 10 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 5) + + def test_except_assign_in_block_multiple(self): + """When a variable is assigned in multiple except blocks, and the variable is + used in one of the blocks, only the assignments in that block are returned. + """ + code = """ + try: + 1 / 0 + except ZeroDivisionError: + x = 10 + except NameError: + x = 100 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 7) + + def test_except_assign_after_block(self): + """When a variable is assigned in an except clause, it is returned + when that variable is used after the except block. + """ + code = """ + try: + 1 / 0 + except ZeroDivisionError: + x = 10 + except NameError: + x = 100 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 2) + self.assertCountEqual([stmt.lineno for stmt in stmts], [5, 7]) + + def test_except_assign_after_block_overwritten(self): + """When a variable is assigned in an except clause, it is not returned + when it is reassigned and used after the except block. + """ + code = """ + try: + 1 / 0 + except ZeroDivisionError: + x = 10 + except NameError: + x = 100 + x = 1000 + print(x) + """ + astroid = builder.parse(code) + x_name = [n for n in astroid.nodes_of_class(nodes.Name) if n.name == "x"][0] + _, stmts = x_name.lookup("x") + self.assertEqual(len(stmts), 1) + self.assertEqual(stmts[0].lineno, 8) + if __name__ == "__main__": unittest.main()