diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/named_expr.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/named_expr.py index 9377e9704e721..15ac1c75775ee 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/named_expr.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/expression/named_expr.py @@ -11,3 +11,28 @@ y0 = (y1 := f(x)) f(x:=y, z=True) + +assert (x := 1) + + +def f(): + return (x := 1) + + +for x in (y := [1, 2, 3]): + pass + +async for x in (y := [1, 2, 3]): + pass + +del (x := 1) + +try: + pass +except (e := Exception): + if x := 1: + pass + +(x := 1) + +(x := 1) + (y := 2) diff --git a/crates/ruff_python_formatter/src/expression/expr_named_expr.rs b/crates/ruff_python_formatter/src/expression/expr_named_expr.rs index cdb2224bf0d3a..88bf4b2052619 100644 --- a/crates/ruff_python_formatter/src/expression/expr_named_expr.rs +++ b/crates/ruff_python_formatter/src/expression/expr_named_expr.rs @@ -32,11 +32,25 @@ impl FormatNodeRule for FormatExprNamedExpr { impl NeedsParentheses for ExprNamedExpr { fn needs_parentheses( &self, - _parent: AnyNodeRef, + parent: AnyNodeRef, _context: &PyFormatContext, ) -> OptionalParentheses { // Unlike tuples, named expression parentheses are not part of the range even when // mandatory. See [PEP 572](https://peps.python.org/pep-0572/) for details. - OptionalParentheses::Always + if parent.is_stmt_ann_assign() + || parent.is_stmt_assign() + || parent.is_stmt_aug_assign() + || parent.is_stmt_assert() + || parent.is_stmt_return() + || parent.is_except_handler_except_handler() + || parent.is_with_item() + || parent.is_stmt_delete() + || parent.is_stmt_for() + || parent.is_stmt_async_for() + { + OptionalParentheses::Always + } else { + OptionalParentheses::Never + } } } diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_38__pep_572.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_38__pep_572.py.snap deleted file mode 100644 index 061701bf161f7..0000000000000 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_38__pep_572.py.snap +++ /dev/null @@ -1,196 +0,0 @@ ---- -source: crates/ruff_python_formatter/tests/fixtures.rs -input_file: crates/ruff_python_formatter/resources/test/fixtures/black/py_38/pep_572.py ---- -## Input - -```py -(a := 1) -(a := a) -if (match := pattern.search(data)) is None: - pass -if match := pattern.search(data): - pass -[y := f(x), y**2, y**3] -filtered_data = [y for x in data if (y := f(x)) is None] -(y := f(x)) -y0 = (y1 := f(x)) -foo(x=(y := f(x))) - - -def foo(answer=(p := 42)): - pass - - -def foo(answer: (p := 42) = 5): - pass - - -lambda: (x := 1) -(x := lambda: 1) -(x := lambda: (y := 1)) -lambda line: (m := re.match(pattern, line)) and m.group(1) -x = (y := 0) -(z := (y := (x := 0))) -(info := (name, phone, *rest)) -(x := 1, 2) -(total := total + tax) -len(lines := f.readlines()) -foo(x := 3, cat="vector") -foo(cat=(category := "vector")) -if any(len(longline := l) >= 100 for l in lines): - print(longline) -if env_base := os.environ.get("PYTHONUSERBASE", None): - return env_base -if self._is_special and (ans := self._check_nans(context=context)): - return ans -foo(b := 2, a=1) -foo((b := 2), a=1) -foo(c=(b := 2), a=1) - -while x := f(x): - pass -while x := f(x): - pass -``` - -## Black Differences - -```diff ---- Black -+++ Ruff -@@ -2,7 +2,7 @@ - (a := a) - if (match := pattern.search(data)) is None: - pass --if match := pattern.search(data): -+if (match := pattern.search(data)): - pass - [y := f(x), y**2, y**3] - filtered_data = [y for x in data if (y := f(x)) is None] -@@ -33,7 +33,7 @@ - foo(cat=(category := "vector")) - if any(len(longline := l) >= 100 for l in lines): - print(longline) --if env_base := os.environ.get("PYTHONUSERBASE", None): -+if (env_base := os.environ.get("PYTHONUSERBASE", None)): - return env_base - if self._is_special and (ans := self._check_nans(context=context)): - return ans -@@ -41,7 +41,7 @@ - foo((b := 2), a=1) - foo(c=(b := 2), a=1) - --while x := f(x): -+while (x := f(x)): - pass --while x := f(x): -+while (x := f(x)): - pass -``` - -## Ruff Output - -```py -(a := 1) -(a := a) -if (match := pattern.search(data)) is None: - pass -if (match := pattern.search(data)): - pass -[y := f(x), y**2, y**3] -filtered_data = [y for x in data if (y := f(x)) is None] -(y := f(x)) -y0 = (y1 := f(x)) -foo(x=(y := f(x))) - - -def foo(answer=(p := 42)): - pass - - -def foo(answer: (p := 42) = 5): - pass - - -lambda: (x := 1) -(x := lambda: 1) -(x := lambda: (y := 1)) -lambda line: (m := re.match(pattern, line)) and m.group(1) -x = (y := 0) -(z := (y := (x := 0))) -(info := (name, phone, *rest)) -(x := 1, 2) -(total := total + tax) -len(lines := f.readlines()) -foo(x := 3, cat="vector") -foo(cat=(category := "vector")) -if any(len(longline := l) >= 100 for l in lines): - print(longline) -if (env_base := os.environ.get("PYTHONUSERBASE", None)): - return env_base -if self._is_special and (ans := self._check_nans(context=context)): - return ans -foo(b := 2, a=1) -foo((b := 2), a=1) -foo(c=(b := 2), a=1) - -while (x := f(x)): - pass -while (x := f(x)): - pass -``` - -## Black Output - -```py -(a := 1) -(a := a) -if (match := pattern.search(data)) is None: - pass -if match := pattern.search(data): - pass -[y := f(x), y**2, y**3] -filtered_data = [y for x in data if (y := f(x)) is None] -(y := f(x)) -y0 = (y1 := f(x)) -foo(x=(y := f(x))) - - -def foo(answer=(p := 42)): - pass - - -def foo(answer: (p := 42) = 5): - pass - - -lambda: (x := 1) -(x := lambda: 1) -(x := lambda: (y := 1)) -lambda line: (m := re.match(pattern, line)) and m.group(1) -x = (y := 0) -(z := (y := (x := 0))) -(info := (name, phone, *rest)) -(x := 1, 2) -(total := total + tax) -len(lines := f.readlines()) -foo(x := 3, cat="vector") -foo(cat=(category := "vector")) -if any(len(longline := l) >= 100 for l in lines): - print(longline) -if env_base := os.environ.get("PYTHONUSERBASE", None): - return env_base -if self._is_special and (ans := self._check_nans(context=context)): - return ans -foo(b := 2, a=1) -foo((b := 2), a=1) -foo(c=(b := 2), a=1) - -while x := f(x): - pass -while x := f(x): - pass -``` - - diff --git a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_39__python39.py.snap b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_39__python39.py.snap index cef4b88f67403..0bfa03ff16902 100644 --- a/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_39__python39.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/black_compatibility@py_39__python39.py.snap @@ -32,19 +32,6 @@ def f(): @relaxed_decorator[0] def f(): ... -@@ -13,8 +12,10 @@ - ... - - --@extremely_long_variable_name_that_doesnt_fit := complex.expression( -- with_long="arguments_value_that_wont_fit_at_the_end_of_the_line" -+@( -+ extremely_long_variable_name_that_doesnt_fit := complex.expression( -+ with_long="arguments_value_that_wont_fit_at_the_end_of_the_line" -+ ) - ) - def f(): - ... ``` ## Ruff Output @@ -64,10 +51,8 @@ def f(): ... -@( - extremely_long_variable_name_that_doesnt_fit := complex.expression( - with_long="arguments_value_that_wont_fit_at_the_end_of_the_line" - ) +@extremely_long_variable_name_that_doesnt_fit := complex.expression( + with_long="arguments_value_that_wont_fit_at_the_end_of_the_line" ) def f(): ... diff --git a/crates/ruff_python_formatter/tests/snapshots/format@expression__named_expr.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@expression__named_expr.py.snap index 9cf0484b4bbaa..018df04c7204a 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@expression__named_expr.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@expression__named_expr.py.snap @@ -17,6 +17,31 @@ if ( y0 = (y1 := f(x)) f(x:=y, z=True) + +assert (x := 1) + + +def f(): + return (x := 1) + + +for x in (y := [1, 2, 3]): + pass + +async for x in (y := [1, 2, 3]): + pass + +del (x := 1) + +try: + pass +except (e := Exception): + if x := 1: + pass + +(x := 1) + +(x := 1) + (y := 2) ``` ## Output @@ -32,6 +57,31 @@ if ( y0 = (y1 := f(x)) f(x := y, z=True) + +assert (x := 1) + + +def f(): + return (x := 1) + + +for x in (y := [1, 2, 3]): + pass + +async for x in (y := [1, 2, 3]): + pass + +del (x := 1) + +try: + pass +except (e := Exception): + if x := 1: + pass + +(x := 1) + +(x := 1) + (y := 2) ```