From 92cab7ef2813e3940804eff0b0e86092f8d434c9 Mon Sep 17 00:00:00 2001 From: Shivansh-007 Date: Fri, 14 Jan 2022 11:30:40 +0530 Subject: [PATCH] Use parenthesis with equality check in walrus/assigment statements Closes #449 --- src/black/linegen.py | 21 ++++++ tests/data/paren_eq_check_in_assigments.py | 83 ++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 tests/data/paren_eq_check_in_assigments.py diff --git a/src/black/linegen.py b/src/black/linegen.py index 6008c773f94..3152360d551 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -99,6 +99,27 @@ def visit_default(self, node: LN) -> Iterator[Line]: self.current_line.append(node) yield from super().visit_default(node) + def visit_comparison(self, node: Node) -> Iterator[Line]: + parent: Optional[Node] = node.parent + grandparent: Optional[Node] = parent.parent if parent else None + if ( + parent is not None + and Leaf(token.EQEQUAL, "==") in node.children + and ( + parent.type == syms.namedexpr_test + or ( + grandparent is not None + and grandparent.type in (syms.expr_stmt, syms.annassign) + ) + ) + ): + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + node.insert_child(0, lpar) + node.insert_child(len(node.children), rpar) + + yield from self.visit_default(node) + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. diff --git a/tests/data/paren_eq_check_in_assigments.py b/tests/data/paren_eq_check_in_assigments.py new file mode 100644 index 00000000000..2cbb64ceafb --- /dev/null +++ b/tests/data/paren_eq_check_in_assigments.py @@ -0,0 +1,83 @@ +match_count += new_value == old_value + +if on_windows := os.name == "nt": + ... + +on_windows: bool = (os.name == "nt") + +implementation_version = ( + platform.python_version() if platform.python_implementation() == "CPython" else "Unknown" +) + +is_mac = platform.system() == 'Darwin' + +s = y == 2 + y == 4 + +name1 = name2 = name3 + +name1 == name2 == name3 + +check_sockets(on_windows=os.name == "nt") + +if a := b == c: + pass + +a = [y := f(x) == True, y ** 2, y ** 3] + +a = lambda line: (m := re.match(pattern, line) == True) + +a = b in c and b == d + +a = b == c == d + +a = b == c in d + +a = b >= c == True + +a = b in c + +a = b > c + +# output + +match_count += (new_value == old_value) + +if on_windows := (os.name == "nt"): + ... + +on_windows: bool = (os.name == "nt") + +implementation_version = ( + platform.python_version() + if platform.python_implementation() == "CPython" + else "Unknown" +) + +is_mac = (platform.system() == "Darwin") + +s = (y == 2 + y == 4) + +name1 = name2 = name3 + +name1 == name2 == name3 + +check_sockets(on_windows=os.name == "nt") + +if a := (b == c): + pass + +a = [y := (f(x) == True), y ** 2, y ** 3] + +a = lambda line: (m := (re.match(pattern, line) == True)) + +a = b in c and b == d + +a = (b == c == d) + +a = (b == c in d) + +a = (b >= c == True) + +a = b in c + +a = b > c