From a401989b7a228034d14a772c31cd2ac0f8569003 Mon Sep 17 00:00:00 2001 From: Micha Reiser Date: Fri, 2 Jun 2023 14:52:38 +0200 Subject: [PATCH] Format StmtExpr (#4788) --- .../src/expression/mod.rs | 55 +++++++++++++++++++ ...atter__tests__black_test__comments_py.snap | 24 ++++---- ...ter__tests__black_test__expression_py.snap | 9 +-- .../src/statement/stmt_expr.rs | 8 ++- 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/crates/ruff_python_formatter/src/expression/mod.rs b/crates/ruff_python_formatter/src/expression/mod.rs index dd73ca1f6dc98..362b1daa864bb 100644 --- a/crates/ruff_python_formatter/src/expression/mod.rs +++ b/crates/ruff_python_formatter/src/expression/mod.rs @@ -1,3 +1,7 @@ +use crate::prelude::*; +use ruff_formatter::{FormatOwnedWithRule, FormatRefWithRule, FormatRule}; +use rustpython_parser::ast::Expr; + pub(crate) mod expr_attribute; pub(crate) mod expr_await; pub(crate) mod expr_bin_op; @@ -25,3 +29,54 @@ pub(crate) mod expr_tuple; pub(crate) mod expr_unary_op; pub(crate) mod expr_yield; pub(crate) mod expr_yield_from; + +#[derive(Default)] +pub struct FormatExpr; + +impl FormatRule> for FormatExpr { + fn fmt(&self, item: &Expr, f: &mut PyFormatter) -> FormatResult<()> { + match item { + Expr::BoolOp(expr) => expr.format().fmt(f), + Expr::NamedExpr(expr) => expr.format().fmt(f), + Expr::BinOp(expr) => expr.format().fmt(f), + Expr::UnaryOp(expr) => expr.format().fmt(f), + Expr::Lambda(expr) => expr.format().fmt(f), + Expr::IfExp(expr) => expr.format().fmt(f), + Expr::Dict(expr) => expr.format().fmt(f), + Expr::Set(expr) => expr.format().fmt(f), + Expr::ListComp(expr) => expr.format().fmt(f), + Expr::SetComp(expr) => expr.format().fmt(f), + Expr::DictComp(expr) => expr.format().fmt(f), + Expr::GeneratorExp(expr) => expr.format().fmt(f), + Expr::Await(expr) => expr.format().fmt(f), + Expr::Yield(expr) => expr.format().fmt(f), + Expr::YieldFrom(expr) => expr.format().fmt(f), + Expr::Compare(expr) => expr.format().fmt(f), + Expr::Call(expr) => expr.format().fmt(f), + Expr::FormattedValue(expr) => expr.format().fmt(f), + Expr::JoinedStr(expr) => expr.format().fmt(f), + Expr::Constant(expr) => expr.format().fmt(f), + Expr::Attribute(expr) => expr.format().fmt(f), + Expr::Subscript(expr) => expr.format().fmt(f), + Expr::Starred(expr) => expr.format().fmt(f), + Expr::Name(expr) => expr.format().fmt(f), + Expr::List(expr) => expr.format().fmt(f), + Expr::Tuple(expr) => expr.format().fmt(f), + Expr::Slice(expr) => expr.format().fmt(f), + } + } +} + +impl<'ast> AsFormat> for Expr { + type Format<'a> = FormatRefWithRule<'a, Expr, FormatExpr, PyFormatContext<'ast>>; + fn format(&self) -> Self::Format<'_> { + FormatRefWithRule::new(self, FormatExpr::default()) + } +} + +impl<'ast> IntoFormat> for Expr { + type Format = FormatOwnedWithRule>; + fn into_format(self) -> Self::Format { + FormatOwnedWithRule::new(self, FormatExpr::default()) + } +} diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments_py.snap index b29e94e259f76..419cf8868bd3a 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__comments_py.snap @@ -109,7 +109,7 @@ async def wat(): ```diff --- Black +++ Ruff -@@ -8,27 +8,19 @@ +@@ -8,27 +8,17 @@ Possibly also many, many lines. """ @@ -128,16 +128,18 @@ async def wat(): - -# Some comment before a function. y = 1 - ( - # some strings - y # type: ignore - ) +-( +- # some strings +- y # type: ignore +-) - - ++# some strings ++y # type: ignore def function(default=None): """Docstring comes first. -@@ -45,12 +37,8 @@ +@@ -45,12 +35,8 @@ # This return is also commented for some reason. return default @@ -150,7 +152,7 @@ async def wat(): # Another comment! # This time two lines. -@@ -73,8 +61,6 @@ +@@ -73,8 +59,6 @@ self.spam = 4 """Docstring for instance attribute spam.""" @@ -159,7 +161,7 @@ async def wat(): #'

This is pweave!

-@@ -93,4 +79,4 @@ +@@ -93,4 +77,4 @@ # Some closing comments. # Maybe Vim or Emacs directives for formatting. @@ -190,10 +192,8 @@ try: except ImportError: import slow as fast y = 1 -( - # some strings - y # type: ignore -) +# some strings +y # type: ignore def function(default=None): """Docstring comes first. diff --git a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__expression_py.snap b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__expression_py.snap index 1a8d43ff9b9d7..cf47dd1cbea38 100644 --- a/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__expression_py.snap +++ b/crates/ruff_python_formatter/src/snapshots/ruff_python_formatter__tests__black_test__expression_py.snap @@ -319,7 +319,7 @@ last_call() -) -{"2.7": dead, "3.7": (long_live or die_hard)} -{"2.7": dead, "3.7": (long_live or die_hard), **{"3.6": verygood}} -+((super_long_variable_name or None) if (1 if super_long_test_name else 2) else (str or bytes or None)) ++(super_long_variable_name or None) if (1 if super_long_test_name else 2) else (str or bytes or None) +{'2.7': dead, '3.7': (long_live or die_hard)} +{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}} {**a, **b, **c} @@ -450,8 +450,9 @@ last_call() +{'2.7': dead, '3.7': long_live or die_hard} +{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] - (SomeName) +-(SomeName) SomeName ++SomeName (Good, Bad, Ugly) (i for i in (1, 2, 3)) -((i**2) for i in (1, 2, 3)) @@ -735,7 +736,7 @@ str or None if True else str or bytes or None (str or None) if True else (str or bytes or None) str or None if (1 if True else 2) else str or bytes or None (str or None) if (1 if True else 2) else (str or bytes or None) -((super_long_variable_name or None) if (1 if super_long_test_name else 2) else (str or bytes or None)) +(super_long_variable_name or None) if (1 if super_long_test_name else 2) else (str or bytes or None) {'2.7': dead, '3.7': (long_live or die_hard)} {'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}} {**a, **b, **c} @@ -832,7 +833,7 @@ numpy[np.newaxis, :] {'2.7': dead, '3.7': long_live or die_hard} {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] -(SomeName) +SomeName SomeName (Good, Bad, Ugly) (i for i in (1, 2, 3)) diff --git a/crates/ruff_python_formatter/src/statement/stmt_expr.rs b/crates/ruff_python_formatter/src/statement/stmt_expr.rs index b19a723837682..691411c762a17 100644 --- a/crates/ruff_python_formatter/src/statement/stmt_expr.rs +++ b/crates/ruff_python_formatter/src/statement/stmt_expr.rs @@ -1,5 +1,5 @@ -use crate::{verbatim_text, FormatNodeRule, PyFormatter}; -use ruff_formatter::{write, Buffer, FormatResult}; +use crate::prelude::*; +use crate::FormatNodeRule; use rustpython_parser::ast::StmtExpr; #[derive(Default)] @@ -7,6 +7,8 @@ pub struct FormatStmtExpr; impl FormatNodeRule for FormatStmtExpr { fn fmt_fields(&self, item: &StmtExpr, f: &mut PyFormatter) -> FormatResult<()> { - write!(f, [verbatim_text(item.range)]) + let StmtExpr { value, .. } = item; + + value.format().fmt(f) } }