From eb8874bf864f3ded5412f15232b54020754bdbb5 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Thu, 9 Feb 2023 17:35:21 -0500 Subject: [PATCH] Improve parentheses handling in some cases (#1) --- crates/ruff_fmt/example.py | 17 +- crates/ruff_fmt/src/cst.rs | 60 +++- crates/ruff_fmt/src/format/comprehension.rs | 16 +- crates/ruff_fmt/src/format/expr.rs | 200 +++++++----- crates/ruff_fmt/src/format/helpers.rs | 43 ++- crates/ruff_fmt/src/format/stmt.rs | 79 +++-- crates/ruff_fmt/src/lib.rs | 9 +- crates/ruff_fmt/src/parentheses.rs | 169 +++++++++++ ...sts__simple_cases__empty_lines.py.snap.new | 4 +- ...ests__simple_cases__expression.py.snap.new | 286 ++++++++---------- ..._tests__simple_cases__function.py.snap.new | 6 +- ...sts__simple_cases__tupleassign.py.snap.new | 17 ++ crates/ruff_fmt/src/trivia.rs | 139 ++++++--- 13 files changed, 697 insertions(+), 348 deletions(-) create mode 100644 crates/ruff_fmt/src/parentheses.rs create mode 100644 crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__tupleassign.py.snap.new diff --git a/crates/ruff_fmt/example.py b/crates/ruff_fmt/example.py index 07bb41f3f116d6..0840932db67a27 100644 --- a/crates/ruff_fmt/example.py +++ b/crates/ruff_fmt/example.py @@ -1,8 +1,9 @@ -if ( # Foo - # Bar - a, # A - # Baz - b, # B - # Bop -): # Qux - pass +{ + x + for (x, y, a, x, y, a, x, y, a, x, y, a, x, y, a, x, y, a, x, y, a, x, y, a, x, y, a) in z +} + +x = ajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsa + +(ajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsaajdjaskldsa) +(foo) diff --git a/crates/ruff_fmt/src/cst.rs b/crates/ruff_fmt/src/cst.rs index de54dc8e0df88f..07c690e1c8e177 100644 --- a/crates/ruff_fmt/src/cst.rs +++ b/crates/ruff_fmt/src/cst.rs @@ -2,7 +2,7 @@ use rustpython_ast::{Constant, Location}; -use crate::trivia::Trivia; +use crate::trivia::{Parenthesize, Trivia}; type Ident = String; @@ -12,6 +12,7 @@ pub struct Located { pub end_location: Option, pub node: T, pub trivia: Vec, + pub parentheses: Parenthesize, } impl Located { @@ -21,6 +22,7 @@ impl Located { end_location: Some(end_location), node, trivia: Vec::new(), + parentheses: Parenthesize::Never, } } @@ -508,6 +510,7 @@ impl From for Alias { asname: alias.node.asname, }, trivia: vec![], + parentheses: Parenthesize::Never, } } } @@ -534,6 +537,7 @@ impl From for Excepthandler { body: body.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, } } } @@ -548,12 +552,14 @@ impl From for Stmt { value: Box::new((*value).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Pass => Stmt { location: stmt.location, end_location: stmt.end_location, node: StmtKind::Pass, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Return { value } => Stmt { location: stmt.location, @@ -562,6 +568,7 @@ impl From for Stmt { value: value.map(|v| (*v).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Assign { targets, @@ -576,6 +583,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::ClassDef { name, @@ -594,6 +602,7 @@ impl From for Stmt { decorator_list: decorator_list.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::FunctionDef { name, @@ -616,6 +625,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::If { test, body, orelse } => Stmt { location: stmt.location, @@ -626,6 +636,7 @@ impl From for Stmt { orelse: orelse.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Assert { test, msg } => Stmt { location: stmt.location, @@ -635,6 +646,7 @@ impl From for Stmt { msg: msg.map(|msg| Box::new((*msg).into())), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::AsyncFunctionDef { name, @@ -655,6 +667,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Delete { targets } => Stmt { location: stmt.location, @@ -663,6 +676,7 @@ impl From for Stmt { targets: targets.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::AugAssign { target, op, value } => Stmt { location: stmt.location, @@ -673,6 +687,7 @@ impl From for Stmt { value: Box::new((*value).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::AnnAssign { target, @@ -689,6 +704,7 @@ impl From for Stmt { simple, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::For { target, @@ -707,6 +723,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::AsyncFor { target, @@ -725,6 +742,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::While { test, body, orelse } => Stmt { location: stmt.location, @@ -735,6 +753,7 @@ impl From for Stmt { orelse: orelse.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::With { items, @@ -749,6 +768,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::AsyncWith { items, @@ -763,6 +783,7 @@ impl From for Stmt { type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Match { .. } => { todo!("match statement"); @@ -775,6 +796,7 @@ impl From for Stmt { cause: cause.map(|cause| Box::new((*cause).into())), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Try { body, @@ -791,6 +813,7 @@ impl From for Stmt { finalbody: finalbody.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Import { names } => Stmt { location: stmt.location, @@ -799,6 +822,7 @@ impl From for Stmt { names: names.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::ImportFrom { module, @@ -813,30 +837,35 @@ impl From for Stmt { level, }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Global { names } => Stmt { location: stmt.location, end_location: stmt.end_location, node: StmtKind::Global { names }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Nonlocal { names } => Stmt { location: stmt.location, end_location: stmt.end_location, node: StmtKind::Nonlocal { names }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Break => Stmt { location: stmt.location, end_location: stmt.end_location, node: StmtKind::Break, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::StmtKind::Continue => Stmt { location: stmt.location, end_location: stmt.end_location, node: StmtKind::Continue, trivia: vec![], + parentheses: Parenthesize::Never, }, } } @@ -852,6 +881,7 @@ impl From for Keyword { value: keyword.node.value.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, } } } @@ -867,6 +897,7 @@ impl From for Arg { type_comment: arg.node.type_comment, }, trivia: vec![], + parentheses: Parenthesize::Never, } } } @@ -907,6 +938,7 @@ impl From for Expr { ctx: ctx.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::BoolOp { op, values } => Expr { location: expr.location, @@ -916,6 +948,7 @@ impl From for Expr { values: values.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::NamedExpr { target, value } => Expr { location: expr.location, @@ -925,6 +958,7 @@ impl From for Expr { value: Box::new((*value).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::BinOp { left, op, right } => Expr { location: expr.location, @@ -935,6 +969,7 @@ impl From for Expr { right: Box::new((*right).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::UnaryOp { op, operand } => Expr { location: expr.location, @@ -944,6 +979,7 @@ impl From for Expr { operand: Box::new((*operand).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Lambda { args, body } => Expr { location: expr.location, @@ -953,6 +989,7 @@ impl From for Expr { body: Box::new((*body).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::IfExp { test, body, orelse } => Expr { location: expr.location, @@ -963,6 +1000,7 @@ impl From for Expr { orelse: Box::new((*orelse).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Dict { keys, values } => Expr { location: expr.location, @@ -972,6 +1010,7 @@ impl From for Expr { values: values.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Set { elts } => Expr { location: expr.location, @@ -980,6 +1019,7 @@ impl From for Expr { elts: elts.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::ListComp { elt, generators } => Expr { location: expr.location, @@ -989,6 +1029,7 @@ impl From for Expr { generators: generators.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::SetComp { elt, generators } => Expr { location: expr.location, @@ -998,6 +1039,7 @@ impl From for Expr { generators: generators.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::DictComp { key, @@ -1012,6 +1054,7 @@ impl From for Expr { generators: generators.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::GeneratorExp { elt, generators } => Expr { location: expr.location, @@ -1021,6 +1064,7 @@ impl From for Expr { generators: generators.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Await { value } => Expr { location: expr.location, @@ -1029,6 +1073,7 @@ impl From for Expr { value: Box::new((*value).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Yield { value } => Expr { location: expr.location, @@ -1037,6 +1082,7 @@ impl From for Expr { value: value.map(|v| Box::new((*v).into())), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::YieldFrom { value } => Expr { location: expr.location, @@ -1045,6 +1091,7 @@ impl From for Expr { value: Box::new((*value).into()), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Compare { left, @@ -1059,6 +1106,7 @@ impl From for Expr { comparators: comparators.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Call { func, @@ -1073,6 +1121,7 @@ impl From for Expr { keywords: keywords.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::FormattedValue { value, @@ -1087,6 +1136,7 @@ impl From for Expr { format_spec: format_spec.map(|f| Box::new((*f).into())), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::JoinedStr { values } => Expr { location: expr.location, @@ -1095,12 +1145,14 @@ impl From for Expr { values: values.into_iter().map(Into::into).collect(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Constant { value, kind } => Expr { location: expr.location, end_location: expr.end_location, node: ExprKind::Constant { value, kind }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Attribute { value, attr, ctx } => Expr { location: expr.location, @@ -1111,6 +1163,7 @@ impl From for Expr { ctx: ctx.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Subscript { value, slice, ctx } => Expr { location: expr.location, @@ -1121,6 +1174,7 @@ impl From for Expr { ctx: ctx.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Starred { value, ctx } => Expr { location: expr.location, @@ -1130,6 +1184,7 @@ impl From for Expr { ctx: ctx.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::List { elts, ctx } => Expr { location: expr.location, @@ -1139,6 +1194,7 @@ impl From for Expr { ctx: ctx.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Tuple { elts, ctx } => Expr { location: expr.location, @@ -1148,6 +1204,7 @@ impl From for Expr { ctx: ctx.into(), }, trivia: vec![], + parentheses: Parenthesize::Never, }, rustpython_ast::ExprKind::Slice { lower, upper, step } => Expr { location: expr.location, @@ -1158,6 +1215,7 @@ impl From for Expr { step: step.map(|s| Box::new((*s).into())), }, trivia: vec![], + parentheses: Parenthesize::Never, }, } } diff --git a/crates/ruff_fmt/src/format/comprehension.rs b/crates/ruff_fmt/src/format/comprehension.rs index c7ea2ebf48e0e6..711c11b953e978 100644 --- a/crates/ruff_fmt/src/format/comprehension.rs +++ b/crates/ruff_fmt/src/format/comprehension.rs @@ -22,13 +22,19 @@ impl Format> for FormatComprehension<'_> { let comprehension = self.item; write!(f, [soft_line_break_or_space()])?; - write!(f, [text("for ")])?; - write!(f, [comprehension.target.format()])?; - write!(f, [text(" in ")])?; - write!(f, [comprehension.iter.format()])?; + write!(f, [text("for")])?; + write!(f, [space()])?; + // TODO(charlie): If this is an unparenthesized tuple, we need to avoid expanding it. + // Should this be set on the context? + write!(f, [group(&comprehension.target.format())])?; + write!(f, [space()])?; + write!(f, [text("in")])?; + write!(f, [space()])?; + write!(f, [group(&comprehension.iter.format())])?; for if_clause in &comprehension.ifs { write!(f, [soft_line_break_or_space()])?; - write!(f, [text("if ")])?; + write!(f, [text("if")])?; + write!(f, [space()])?; write!(f, [if_clause.format()])?; } diff --git a/crates/ruff_fmt/src/format/expr.rs b/crates/ruff_fmt/src/format/expr.rs index d687aa859565c5..bc38d9c6730646 100644 --- a/crates/ruff_fmt/src/format/expr.rs +++ b/crates/ruff_fmt/src/format/expr.rs @@ -11,9 +11,9 @@ use crate::core::types::Range; use crate::cst::{ Arguments, Boolop, Cmpop, Comprehension, Expr, ExprKind, Keyword, Operator, Unaryop, }; -use crate::format::helpers::{is_self_closing, is_simple}; +use crate::format::helpers::{is_self_closing, is_simple_power, is_simple_slice}; use crate::shared_traits::AsFormat; -use crate::trivia::{Relationship, TriviaKind}; +use crate::trivia::{Parenthesize, Relationship, TriviaKind}; pub struct FormatExpr<'a> { item: &'a Expr, @@ -79,7 +79,7 @@ fn format_name( fn format_subscript( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, value: &Expr, slice: &Expr, ) -> FormatResult<()> { @@ -88,9 +88,7 @@ fn format_subscript( [ value.format(), text("["), - group(&format_args![soft_block_indent(&format_args![ - slice.format() - ])]), + group(&format_args![soft_block_indent(&slice.format())]), text("]") ] )?; @@ -102,7 +100,10 @@ fn format_tuple( expr: &Expr, elts: &[Expr], ) -> FormatResult<()> { - write!(f, [text("(")])?; + // If we're already parenthesized, avoid adding any "mandatory" parentheses. + // TODO(charlie): We also need to parenthesize tuples on the right-hand side of an + // assignment if the target is exploded. And sometimes the tuple gets exploded, like + // if the LHS is an exploded list? Lots of edge cases here. if elts.len() == 1 { write!( f, @@ -113,44 +114,83 @@ fn format_tuple( }))])] )?; } else if !elts.is_empty() { - // TODO(charlie): DRY. write!( f, - [group(&format_args![soft_block_indent(&format_with(|f| { - if expr - .trivia - .iter() - .any(|c| matches!(c.kind, TriviaKind::MagicTrailingComma)) - { - write!(f, [expand_parent()])?; + [group(&format_with(|f| { + if matches!(expr.parentheses, Parenthesize::IfExpanded) { + write!(f, [if_group_breaks(&text("("))])?; } - for (i, elt) in elts.iter().enumerate() { - write!(f, [elt.format()])?; - if i < elts.len() - 1 { - write!(f, [text(",")])?; - write!(f, [soft_line_break_or_space()])?; - } else { - write!(f, [if_group_breaks(&text(","))])?; + if matches!( + expr.parentheses, + Parenthesize::IfExpanded | Parenthesize::Always + ) { + write!( + f, + [soft_block_indent(&format_with(|f| { + // TODO(charlie): If the magic trailing comma isn't present, and the + // tuple is _already_ expanded, we're not supposed to add this. + let magic_trailing_comma = expr + .trivia + .iter() + .any(|c| matches!(c.kind, TriviaKind::MagicTrailingComma)); + if magic_trailing_comma { + write!(f, [expand_parent()])?; + } + for (i, elt) in elts.iter().enumerate() { + write!(f, [elt.format()])?; + if i < elts.len() - 1 { + write!(f, [text(",")])?; + write!(f, [soft_line_break_or_space()])?; + } else { + if magic_trailing_comma { + write!(f, [if_group_breaks(&text(","))])?; + } + } + } + Ok(()) + }))] + )?; + } else { + let magic_trailing_comma = expr + .trivia + .iter() + .any(|c| matches!(c.kind, TriviaKind::MagicTrailingComma)); + if magic_trailing_comma { + write!(f, [expand_parent()])?; } + for (i, elt) in elts.iter().enumerate() { + write!(f, [elt.format()])?; + if i < elts.len() - 1 { + write!(f, [text(",")])?; + write!(f, [soft_line_break_or_space()])?; + } else { + if magic_trailing_comma { + write!(f, [if_group_breaks(&text(","))])?; + } + } + } + } + if matches!(expr.parentheses, Parenthesize::IfExpanded) { + write!(f, [if_group_breaks(&text(")"))])?; } Ok(()) - }))])] + }))] )?; } - write!(f, [text(")")])?; Ok(()) } fn format_slice( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, lower: Option<&Expr>, upper: Option<&Expr>, step: Option<&Expr>, ) -> FormatResult<()> { - let is_simple = lower.map_or(true, is_simple) - && upper.map_or(true, is_simple) - && step.map_or(true, is_simple); + // https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#slices + let is_simple = lower.map_or(true, is_simple_slice) + && upper.map_or(true, is_simple_slice) + && step.map_or(true, is_simple_slice); if let Some(lower) = lower { write!(f, [lower.format()])?; @@ -164,7 +204,7 @@ fn format_slice( write!(f, [space()])?; } write!(f, [upper.format()])?; - if !is_simple { + if !is_simple && step.is_some() { write!(f, [space()])?; } } @@ -219,7 +259,7 @@ fn format_list( fn format_set( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, elts: &[Expr], ) -> FormatResult<()> { if elts.is_empty() { @@ -233,7 +273,7 @@ fn format_set( f, [group(&format_args![soft_block_indent(&format_with(|f| { for (i, elt) in elts.iter().enumerate() { - write!(f, [elt.format()])?; + write!(f, [group(&format_args![elt.format()])])?; if i < elts.len() - 1 { write!(f, [text(",")])?; write!(f, [soft_line_break_or_space()])?; @@ -332,7 +372,7 @@ fn format_call( fn format_list_comp( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, elt: &Expr, generators: &[Comprehension], ) -> FormatResult<()> { @@ -353,7 +393,7 @@ fn format_list_comp( fn format_set_comp( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, elt: &Expr, generators: &[Comprehension], ) -> FormatResult<()> { @@ -374,7 +414,7 @@ fn format_set_comp( fn format_dict_comp( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, key: &Expr, value: &Expr, generators: &[Comprehension], @@ -399,11 +439,10 @@ fn format_dict_comp( fn format_generator_exp( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, elt: &Expr, generators: &[Comprehension], ) -> FormatResult<()> { - write!(f, [text("(")])?; write!( f, [group(&format_args![soft_block_indent(&format_with(|f| { @@ -414,7 +453,6 @@ fn format_generator_exp( Ok(()) }))])] )?; - write!(f, [text(")")])?; Ok(()) } @@ -426,7 +464,7 @@ fn format_await( write!(f, [text("await")])?; write!(f, [space()])?; if is_self_closing(value) { - write!(f, [value.format()])?; + write!(f, [group(&format_args![value.format()])])?; } else { write!( f, @@ -445,13 +483,11 @@ fn format_yield( expr: &Expr, value: Option<&Expr>, ) -> FormatResult<()> { - // TODO(charlie): We need to insert these conditionally. - write!(f, [text("(")])?; write!(f, [text("yield")])?; if let Some(value) = value { write!(f, [space()])?; if is_self_closing(value) { - write!(f, [value.format()])?; + write!(f, [group(&format_args![value.format()])])?; } else { write!( f, @@ -463,7 +499,6 @@ fn format_yield( )?; } } - write!(f, [text(")")])?; Ok(()) } @@ -472,9 +507,6 @@ fn format_yield_from( expr: &Expr, value: &Expr, ) -> FormatResult<()> { - // TODO(charlie): We need to insert these conditionally. - write!(f, [text("(")])?; - write!( f, [group(&format_args![soft_block_indent(&format_with(|f| { @@ -497,7 +529,6 @@ fn format_yield_from( Ok(()) })),])] )?; - write!(f, [text(")")])?; Ok(()) } @@ -508,12 +539,12 @@ fn format_compare( ops: &[Cmpop], comparators: &[Expr], ) -> FormatResult<()> { - write!(f, [left.format()])?; + write!(f, [group(&format_args![left.format()])])?; for (i, op) in ops.iter().enumerate() { - write!(f, [space()])?; + write!(f, [soft_line_break_or_space()])?; write!(f, [op.format()])?; write!(f, [space()])?; - write!(f, [comparators[i].format()])?; + write!(f, [group(&format_args![comparators[i].format()])])?; } Ok(()) } @@ -539,7 +570,7 @@ fn format_constant( fn format_dict( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, keys: &[Option], values: &[Expr], ) -> FormatResult<()> { @@ -548,7 +579,7 @@ fn format_dict( // TODO(charlie): DRY. write!( f, - [group(&format_args![soft_block_indent(&format_with(|f| { + [soft_block_indent(&format_with(|f| { for (i, (k, v)) in keys.iter().zip(values).enumerate() { if let Some(k) = k { write!(f, [k.format()])?; @@ -589,7 +620,7 @@ fn format_dict( } } Ok(()) - }))])] + }))] )?; } write!(f, [text("}")])?; @@ -637,12 +668,12 @@ fn format_bool_op( let mut first = true; for value in values { if std::mem::take(&mut first) { - write!(f, [value.format()])?; + write!(f, [group(&format_args![value.format()])])?; } else { write!(f, [soft_line_break_or_space()])?; write!(f, [op.format()])?; write!(f, [space()])?; - write!(f, [value.format()])?; + write!(f, [group(&format_args![value.format()])])?; } } @@ -675,11 +706,19 @@ fn format_bin_op( op: &Operator, right: &Expr, ) -> FormatResult<()> { - write!(f, [left.format()])?; - write!(f, [soft_line_break_or_space()])?; + // https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#line-breaks-binary-operators + let is_simple = + matches!(op, Operator::Pow) && (is_simple_power(left) && is_simple_power(right)); + + write!(f, [group(&format_args![left.format()])])?; + if !is_simple { + write!(f, [soft_line_break_or_space()])?; + } write!(f, [op.format()])?; - write!(f, [space()])?; - write!(f, [right.format()])?; + if !is_simple { + write!(f, [space()])?; + } + write!(f, [group(&format_args![right.format()])])?; // Apply any inline comments. let mut first = true; @@ -705,18 +744,35 @@ fn format_bin_op( fn format_unary_op( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, op: &Unaryop, operand: &Expr, ) -> FormatResult<()> { write!(f, [op.format()])?; - write!(f, [operand.format()])?; + // TODO(charlie): Do this in the normalization pass. + if !matches!(op, Unaryop::Not) + && matches!( + operand.node, + ExprKind::BoolOp { .. } | ExprKind::Compare { .. } | ExprKind::BinOp { .. } + ) + { + let parenthesized = matches!(operand.parentheses, Parenthesize::Always); + if !parenthesized { + write!(f, [text("(")])?; + } + write!(f, [operand.format()])?; + if !parenthesized { + write!(f, [text(")")])?; + } + } else { + write!(f, [operand.format()])?; + } Ok(()) } fn format_lambda( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, args: &Arguments, body: &Expr, ) -> FormatResult<()> { @@ -733,36 +789,26 @@ fn format_lambda( fn format_if_exp( f: &mut Formatter>, - _expr: &Expr, + expr: &Expr, test: &Expr, body: &Expr, orelse: &Expr, ) -> FormatResult<()> { - write!(f, [body.format()])?; + write!(f, [group(&format_args![body.format()])])?; write!(f, [soft_line_break_or_space()])?; write!(f, [text("if")])?; write!(f, [space()])?; - write!(f, [test.format()])?; + write!(f, [group(&format_args![test.format()])])?; write!(f, [soft_line_break_or_space()])?; write!(f, [text("else")])?; write!(f, [space()])?; - write!(f, [orelse.format()])?; + write!(f, [group(&format_args![orelse.format()])])?; Ok(()) } impl Format> for FormatExpr<'_> { fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { - let parenthesize = !matches!( - self.item.node, - ExprKind::Tuple { .. } | ExprKind::GeneratorExp { .. } - ) && self.item.trivia.iter().any(|trivia| { - matches!( - (trivia.relationship, trivia.kind), - (Relationship::Leading, TriviaKind::LeftParen) - ) - }); - - if parenthesize { + if matches!(self.item.parentheses, Parenthesize::Always) { write!(f, [text("(")])?; } @@ -851,7 +897,7 @@ impl Format> for FormatExpr<'_> { } } - if parenthesize { + if matches!(self.item.parentheses, Parenthesize::Always) { write!(f, [text(")")])?; } diff --git a/crates/ruff_fmt/src/format/helpers.rs b/crates/ruff_fmt/src/format/helpers.rs index 4d3e8c21721ba3..63f216a5fa91bf 100644 --- a/crates/ruff_fmt/src/format/helpers.rs +++ b/crates/ruff_fmt/src/format/helpers.rs @@ -11,9 +11,12 @@ pub fn is_self_closing(expr: &Expr) -> bool { | ExprKind::DictComp { .. } | ExprKind::GeneratorExp { .. } | ExprKind::Call { .. } + | ExprKind::Name { .. } + | ExprKind::Constant { .. } | ExprKind::Subscript { .. } => true, + ExprKind::Lambda { body, .. } => is_self_closing(body), ExprKind::BinOp { left, right, .. } => { - matches!(left.node, ExprKind::Constant { .. }) + matches!(left.node, ExprKind::Constant { .. } | ExprKind::Name { .. }) && matches!( right.node, ExprKind::Tuple { .. } @@ -28,19 +31,35 @@ pub fn is_self_closing(expr: &Expr) -> bool { | ExprKind::Subscript { .. } ) } + ExprKind::BoolOp { values, .. } => values.last().map_or(false, |expr| { + matches!( + expr.node, + ExprKind::Tuple { .. } + | ExprKind::List { .. } + | ExprKind::Set { .. } + | ExprKind::Dict { .. } + | ExprKind::ListComp { .. } + | ExprKind::SetComp { .. } + | ExprKind::DictComp { .. } + | ExprKind::GeneratorExp { .. } + | ExprKind::Call { .. } + | ExprKind::Subscript { .. } + ) + }), + ExprKind::UnaryOp { operand, .. } => is_self_closing(operand), _ => false, } } /// Return `true` if an [`Expr`] adheres to Black's definition of a non-complex /// expression, in the context of a slice operation. -pub fn is_simple(expr: &Expr) -> bool { +pub fn is_simple_slice(expr: &Expr) -> bool { match &expr.node { ExprKind::UnaryOp { op, operand } => { if matches!(op, Unaryop::Not) { false } else { - is_simple(operand) + is_simple_slice(operand) } } ExprKind::Constant { .. } => true, @@ -48,3 +67,21 @@ pub fn is_simple(expr: &Expr) -> bool { _ => false, } } + +/// Return `true` if an [`Expr`] adheres to Black's definition of a non-complex +/// expression, in the context of a power operation. +pub fn is_simple_power(expr: &Expr) -> bool { + match &expr.node { + ExprKind::UnaryOp { op, operand } => { + if matches!(op, Unaryop::Not) { + false + } else { + is_simple_slice(operand) + } + } + ExprKind::Constant { .. } => true, + ExprKind::Name { .. } => true, + ExprKind::Attribute { .. } => true, + _ => false, + } +} diff --git a/crates/ruff_fmt/src/format/stmt.rs b/crates/ruff_fmt/src/format/stmt.rs index 6468476f3b62be..6e0e02dd7a7bb7 100644 --- a/crates/ruff_fmt/src/format/stmt.rs +++ b/crates/ruff_fmt/src/format/stmt.rs @@ -10,7 +10,7 @@ use crate::cst::{Alias, Arguments, Expr, ExprKind, Keyword, Stmt, StmtKind, With use crate::format::builders::{block, join_names}; use crate::format::helpers::is_self_closing; use crate::shared_traits::AsFormat; -use crate::trivia::{Relationship, TriviaKind}; +use crate::trivia::{Parenthesize, Relationship, TriviaKind}; fn format_break(f: &mut Formatter>) -> FormatResult<()> { write!(f, [text("break")]) @@ -49,7 +49,7 @@ fn format_continue(f: &mut Formatter>) -> FormatResult<()> fn format_global(f: &mut Formatter>, names: &[String]) -> FormatResult<()> { write!(f, [text("global")])?; if !names.is_empty() { - write!(f, [text(" ")])?; + write!(f, [space()])?; join_names(f, names)?; } Ok(()) @@ -58,7 +58,7 @@ fn format_global(f: &mut Formatter>, names: &[String]) -> F fn format_nonlocal(f: &mut Formatter>, names: &[String]) -> FormatResult<()> { write!(f, [text("nonlocal")])?; if !names.is_empty() { - write!(f, [text(" ")])?; + write!(f, [space()])?; join_names(f, names)?; } Ok(()) @@ -67,10 +67,10 @@ fn format_nonlocal(f: &mut Formatter>, names: &[String]) -> fn format_delete(f: &mut Formatter>, targets: &[Expr]) -> FormatResult<()> { write!(f, [text("del")])?; if targets.len() == 1 { - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [targets[0].format()])?; } else if !targets.is_empty() { - write!(f, [text(" ")])?; + write!(f, [space()])?; write!( f, [group(&format_args![ @@ -108,7 +108,7 @@ fn format_class_def( write!(f, [hard_line_break()])?; } write!(f, [text("class")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [dynamic_text(name, TextSize::default())])?; if !bases.is_empty() || !keywords.is_empty() { write!(f, [text("(")])?; @@ -168,10 +168,10 @@ fn format_func_def( } if async_ { write!(f, [text("async")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; } write!(f, [text("def")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [dynamic_text(name, TextSize::default())])?; write!(f, [text("(")])?; write!( @@ -233,17 +233,21 @@ fn format_assign( write!(f, [target.format()])?; } write!(f, [text(" = ")])?; - write!( - f, - [group(&format_args![ - if_group_breaks(&text("(")), - soft_block_indent(&format_with(|f| { - write!(f, [value.format()])?; - Ok(()) - })), - if_group_breaks(&text(")")), - ])] - )?; + if is_self_closing(value) { + write!(f, [group(&format_args![value.format()])])?; + } else { + write!( + f, + [group(&format_args![ + if_group_breaks(&text("(")), + soft_block_indent(&format_with(|f| { + write!(f, [value.format()])?; + Ok(()) + })), + if_group_breaks(&text(")")), + ])] + )?; + } // Apply any inline comments. let mut first = true; @@ -312,7 +316,7 @@ fn format_for( _type_comment: Option<&str>, ) -> FormatResult<()> { write!(f, [text("for")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [target.format()])?; write!(f, [text(" in ")])?; write!(f, [group(&format_args![iter.format()])])?; @@ -329,7 +333,7 @@ fn format_while( orelse: &[Stmt], ) -> FormatResult<()> { write!(f, [text("while")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; if is_self_closing(test) { write!(f, [test.format()])?; } else { @@ -358,7 +362,7 @@ fn format_if( orelse: &[Stmt], ) -> FormatResult<()> { write!(f, [text("if")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; if is_self_closing(test) { write!(f, [test.format()])?; } else { @@ -398,7 +402,7 @@ fn format_raise( ) -> FormatResult<()> { write!(f, [text("raise")])?; if let Some(exc) = exc { - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [exc.format()])?; if let Some(cause) = cause { write!(f, [text(" from ")])?; @@ -414,7 +418,7 @@ fn format_return( ) -> FormatResult<()> { write!(f, [text("return")])?; if let Some(value) = value { - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [value.format()])?; } Ok(()) @@ -427,7 +431,7 @@ fn format_assert( msg: Option<&Expr>, ) -> FormatResult<()> { write!(f, [text("assert")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; if is_self_closing(test) { write!(f, [test.format()])?; } else { @@ -453,7 +457,7 @@ fn format_import( names: &[Alias], ) -> FormatResult<()> { write!(f, [text("import")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; write!( f, @@ -486,7 +490,7 @@ fn format_import_from( level: Option<&usize>, ) -> FormatResult<()> { write!(f, [text("from")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; if let Some(level) = level { for _ in 0..*level { @@ -496,10 +500,10 @@ fn format_import_from( if let Some(module) = module { write!(f, [dynamic_text(module, TextSize::default())])?; } - write!(f, [text(" ")])?; + write!(f, [space()])?; write!(f, [text("import")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; write!( f, @@ -529,8 +533,17 @@ fn format_expr( stmt: &Stmt, expr: &Expr, ) -> FormatResult<()> { - if is_self_closing(expr) { - write!(f, [expr.format()])?; + if matches!(stmt.parentheses, Parenthesize::Always) { + write!( + f, + [group(&format_args![ + text("("), + soft_block_indent(&format_args![expr.format()]), + text(")"), + ])] + )?; + } else if is_self_closing(expr) { + write!(f, [group(&format_args![expr.format()])])?; } else { write!( f, @@ -574,10 +587,10 @@ fn format_with_( ) -> FormatResult<()> { if async_ { write!(f, [text("async")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; } write!(f, [text("with")])?; - write!(f, [text(" ")])?; + write!(f, [space()])?; write!( f, [group(&format_args![ diff --git a/crates/ruff_fmt/src/lib.rs b/crates/ruff_fmt/src/lib.rs index 2dc10692f65c38..95680df20ef45c 100644 --- a/crates/ruff_fmt/src/lib.rs +++ b/crates/ruff_fmt/src/lib.rs @@ -8,6 +8,7 @@ use crate::core::locator::Locator; use crate::core::rustpython_helpers; use crate::cst::Stmt; use crate::newlines::normalize_newlines; +use crate::parentheses::normalize_parentheses; mod attachment; pub mod builders; @@ -17,6 +18,7 @@ mod core; mod cst; mod format; mod newlines; +mod parentheses; pub mod shared_traits; #[cfg(test)] mod test; @@ -38,6 +40,7 @@ pub fn fmt(contents: &str) -> Result { // Attach trivia. attach(&mut python_cst, trivia); normalize_newlines(&mut python_cst); + normalize_parentheses(&mut python_cst); let elements = format!( ASTFormatContext::new( @@ -62,7 +65,6 @@ mod tests { use crate::fmt; use crate::test::test_resource_path; - #[test_case(Path::new("simple_cases/tupleassign.py"); "tupleassign")] #[test_case(Path::new("simple_cases/class_blank_parentheses.py"); "class_blank_parentheses")] #[test_case(Path::new("simple_cases/class_methods_new_line.py"); "class_methods_new_line")] #[test_case(Path::new("simple_cases/beginning_backslash.py"); "beginning_backslash")] @@ -76,10 +78,13 @@ mod tests { Ok(()) } + // Passing apart from one deviation in RHS tuple assignment. + #[test_case(Path::new("simple_cases/tupleassign.py"); "tupleassign")] + // Lots of deviations, _mostly_ related to string normalization and wrapping. + #[test_case(Path::new("simple_cases/expression.py"); "expression")] // #[test_case(Path::new("simple_cases/comments.py"); "comments")] // #[test_case(Path::new("simple_cases/function.py"); "function")] // #[test_case(Path::new("simple_cases/empty_lines.py"); "empty_lines")] - // #[test_case(Path::new("simple_cases/expression.py"); "expression")] fn failing(path: &Path) -> Result<()> { let snapshot = format!("{}", path.display()); let content = std::fs::read_to_string(test_resource_path( diff --git a/crates/ruff_fmt/src/parentheses.rs b/crates/ruff_fmt/src/parentheses.rs new file mode 100644 index 00000000000000..c3903b3ac796c5 --- /dev/null +++ b/crates/ruff_fmt/src/parentheses.rs @@ -0,0 +1,169 @@ +use crate::core::visitor; +use crate::core::visitor::Visitor; +use crate::cst::{Expr, ExprKind, Stmt, StmtKind}; +use crate::trivia::{Parenthesize, TriviaKind}; + +/// Modify an [`Expr`] to infer parentheses, rather than respecting any user-provided trivia. +fn use_inferred_parens(expr: &mut Expr) { + // Remove parentheses, unless it's a generator expression, in which case, keep them. + if !matches!(expr.node, ExprKind::GeneratorExp { .. }) { + expr.trivia + .retain(|trivia| !matches!(trivia.kind, TriviaKind::Parentheses)); + } + + // If it's a tuple, add parentheses if it's a singleton; otherwise, we only need parentheses + // if the tuple expands. + if let ExprKind::Tuple { elts, .. } = &expr.node { + expr.parentheses = if elts.len() > 1 { + Parenthesize::IfExpanded + } else { + Parenthesize::Always + }; + } +} + +struct ParenthesesNormalizer {} + +impl<'a> Visitor<'a> for ParenthesesNormalizer { + fn visit_stmt(&mut self, stmt: &'a mut Stmt) { + // Always remove parentheses around statements, unless it's an expression statement, + // in which case, remove parentheses around the expression. + let before = stmt.trivia.len(); + stmt.trivia + .retain(|trivia| !matches!(trivia.kind, TriviaKind::Parentheses)); + let after = stmt.trivia.len(); + if let StmtKind::Expr { value } = &mut stmt.node { + if before != after { + stmt.parentheses = Parenthesize::Always; + value.parentheses = Parenthesize::Never; + } + } + + // In a variety of contexts, remove parentheses around sub-expressions. Right now, the + // pattern is consistent (and repeated), but it may not end up that way. + // https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#parentheses + match &mut stmt.node { + StmtKind::FunctionDef { .. } => {} + StmtKind::AsyncFunctionDef { .. } => {} + StmtKind::ClassDef { .. } => {} + StmtKind::Return { value } => { + if let Some(value) = value { + use_inferred_parens(value); + } + } + StmtKind::Delete { .. } => {} + StmtKind::Assign { targets, value, .. } => { + for target in targets { + use_inferred_parens(target); + } + use_inferred_parens(value); + } + StmtKind::AugAssign { value, .. } => { + use_inferred_parens(value); + } + StmtKind::AnnAssign { value, .. } => { + if let Some(value) = value { + use_inferred_parens(value); + } + } + StmtKind::For { target, iter, .. } | StmtKind::AsyncFor { target, iter, .. } => { + use_inferred_parens(target); + use_inferred_parens(iter); + } + StmtKind::While { test, .. } => { + use_inferred_parens(test); + } + StmtKind::If { test, .. } => { + use_inferred_parens(test); + } + StmtKind::With { .. } => {} + StmtKind::AsyncWith { .. } => {} + StmtKind::Match { .. } => {} + StmtKind::Raise { .. } => {} + StmtKind::Try { .. } => {} + StmtKind::Assert { test, msg } => { + use_inferred_parens(test); + if let Some(msg) = msg { + use_inferred_parens(msg); + } + } + StmtKind::Import { .. } => {} + StmtKind::ImportFrom { .. } => {} + StmtKind::Global { .. } => {} + StmtKind::Nonlocal { .. } => {} + StmtKind::Expr { .. } => {} + StmtKind::Pass => {} + StmtKind::Break => {} + StmtKind::Continue => {} + } + + visitor::walk_stmt(self, stmt); + } + + fn visit_expr(&mut self, expr: &'a mut Expr) { + // Always retain parentheses around expressions. + let before = expr.trivia.len(); + expr.trivia + .retain(|trivia| !matches!(trivia.kind, TriviaKind::Parentheses)); + let after = expr.trivia.len(); + if before != after { + expr.parentheses = Parenthesize::Always; + } + + match &mut expr.node { + ExprKind::BoolOp { .. } => {} + ExprKind::NamedExpr { .. } => {} + ExprKind::BinOp { .. } => {} + ExprKind::UnaryOp { .. } => {} + ExprKind::Lambda { .. } => {} + ExprKind::IfExp { .. } => {} + ExprKind::Dict { .. } => {} + ExprKind::Set { .. } => {} + ExprKind::ListComp { .. } => {} + ExprKind::SetComp { .. } => {} + ExprKind::DictComp { .. } => {} + ExprKind::GeneratorExp { .. } => {} + ExprKind::Await { .. } => {} + ExprKind::Yield { .. } => {} + ExprKind::YieldFrom { .. } => {} + ExprKind::Compare { .. } => {} + ExprKind::Call { .. } => {} + ExprKind::FormattedValue { .. } => {} + ExprKind::JoinedStr { .. } => {} + ExprKind::Constant { .. } => {} + ExprKind::Attribute { .. } => {} + ExprKind::Subscript { value, slice, .. } => { + // If the slice isn't manually parenthesized, ensure that we _never_ parenthesize + // the value. + if !slice + .trivia + .iter() + .any(|trivia| matches!(trivia.kind, TriviaKind::Parentheses)) + { + value.parentheses = Parenthesize::Never; + } + } + ExprKind::Starred { .. } => {} + ExprKind::Name { .. } => {} + ExprKind::List { .. } => {} + ExprKind::Tuple { .. } => {} + ExprKind::Slice { .. } => {} + } + + visitor::walk_expr(self, expr); + } +} + +/// Normalize parentheses in a Python CST. +/// +/// It's not always possible to determine the correct parentheses to use during formatting +/// from the node (and trivia) alone; sometimes, we need to know the parent node. This +/// visitor normalizes parentheses via a top-down traversal, which simplifies the formatting +/// code later on. +/// +/// TODO(charlie): It's weird that we have both `TriviaKind::Parenthese` (which aren't used +/// during formatting) and `Parenthesize` (which are used during formatting). +pub fn normalize_parentheses(python_cst: &mut [Stmt]) { + let mut normalizer = ParenthesesNormalizer {}; + normalizer.visit_body(python_cst); +} diff --git a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__empty_lines.py.snap.new b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__empty_lines.py.snap.new index 3604cf40555eaf..768fd9270ccaa9 100644 --- a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__empty_lines.py.snap.new +++ b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__empty_lines.py.snap.new @@ -1,6 +1,6 @@ --- -source: ruff_fmt/src/lib.rs -assertion_line: 75 +source: crates/ruff_fmt/src/lib.rs +assertion_line: 89 expression: printed.as_code() --- """Docstring.""" diff --git a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__expression.py.snap.new b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__expression.py.snap.new index f2b0e874e88ab2..8ab07a8fdcedde 100644 --- a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__expression.py.snap.new +++ b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__expression.py.snap.new @@ -1,6 +1,6 @@ --- source: crates/ruff_fmt/src/lib.rs -assertion_line: 88 +assertion_line: 94 expression: printed.as_code() --- ... @@ -17,63 +17,59 @@ True or False True or False or None True and False True and False and None +(Name1 and Name2) or Name3 Name1 and Name2 or Name3 -Name1 and Name2 or Name3 -Name1 or Name2 and Name3 +Name1 or (Name2 and Name3) Name1 or Name2 and Name3 +(Name1 and Name2) or (Name3 and Name4) Name1 and Name2 or Name3 and Name4 -Name1 and Name2 or Name3 and Name4 -Name1 or Name2 and Name3 or Name4 +Name1 or (Name2 and Name3) or Name4 Name1 or Name2 and Name3 or Name4 v1 << 2 1 >> v2 1 % finished -1 + v2 - v3 * 4 ^ 5 ** v6 / 7 // 8 -1 + v2 - v3 * 4 ^ 5 ** v6 / 7 // 8 +1 + v2 - v3 * 4 ^ 5**v6 / 7 // 8 +((1 + v2) - (v3 * 4)) ^ (((5**v6) / 7) // 8) not great ~great +value -1 ~int and not v1 ^ 123 + v2 | True -~int and not v1 ^ 123 + v2 | True -+really ** -confusing ** ~operator ** -precedence +(~int) and (not ((v1 ^ (123 + v2)) | True)) ++(really ** -(confusing ** ~(operator**-precedence))) flags & ~select.EPOLLIN and waiters.write_task is not None lambda arg: None lambda a=True: a lambda a, b, c=True: a -lambda a, b, c=True, *, d=1 << v2, e='str': a -lambda a, b, c=True, *vararg, d=v1 << 2, e='str', **kwargs: a + b +lambda a, b, c=True, *, d=(1 << v2), e='str': a +lambda a, b, c=True, *vararg, d=(v1 << 2), e='str', **kwargs: a + b manylambdas = lambda x=lambda y=lambda z=1: z: y(): x() -foo = ( - lambda port_id, - ignore_missing,: {"port1": port1_resource, "port2": port2_resource}[port_id] -) +foo = lambda port_id, +ignore_missing,: { + "port1": port1_resource, + "port2": port2_resource, +}[port_id] 1 if True else 2 str or None if True else str or bytes or None -str or None if True else str or bytes or None -str or None if 1 if True else 2 else str or bytes or None -str or None if 1 if True else 2 else str or bytes or None +(str or None) if True else (str or bytes or None) +str or None if (1 if True else 2) else str or bytes or None +(str or None) if (1 if True else 2) else (str or bytes or None) ( - super_long_variable_name - or None - if 1 - if super_long_test_name - else 2 - else str - or bytes - or None + (super_long_variable_name or None) + if (1 if super_long_test_name else 2) + else (str or bytes or None) ) -{'2.7': dead, '3.7': long_live or die_hard} -{'2.7': dead, '3.7': long_live or die_hard, **{'3.6': verygood}} +{'2.7': dead, '3.7': (long_live or die_hard)} +{'2.7': dead, '3.7': (long_live or die_hard), **{'3.6': verygood}} {**a, **b, **c} -{'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} -({'a': 'b'}, True or False, +value, 'string', b'bytes') or None +{'2.7', '3.6', '3.7', '3.8', '3.9', ('4.0' if gilectomy else '3.10')} +({'a': 'b'}, (True or False), (+value), 'string', b'bytes') or None () (1,) (1, 2) (1, 2, 3) [] -[1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] +[1, 2, 3, 4, 5, 6, 7, 8, 9, (10 or A), (11 or B), (12 or C)] [ 1, 2, @@ -98,23 +94,21 @@ str or None if 1 if True else 2 else str or bytes or None *more, ] {i for i in (1, 2, 3)} -{i ** 2 for i in (1, 2, 3)} -{i ** 2 for (i, _) in ((1, 'a'), (2, 'b'), (3, 'c'))} -{i ** 2 + j for i in (1, 2, 3) for j in (1, 2, 3)} +{(i**2) for i in (1, 2, 3)} +{(i**2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))} +{((i**2) + j) for i in (1, 2, 3) for j in (1, 2, 3)} [i for i in (1, 2, 3)] -[i ** 2 for i in (1, 2, 3)] -[i ** 2 for (i, _) in ((1, 'a'), (2, 'b'), (3, 'c'))] -[i ** 2 + j for i in (1, 2, 3) for j in (1, 2, 3)] +[(i**2) for i in (1, 2, 3)] +[(i**2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))] +[((i**2) + j) for i in (1, 2, 3) for j in (1, 2, 3)] {i: 0 for i in (1, 2, 3)} -{i: j for (i, j) in ((1, 'a'), (2, 'b'), (3, 'c'))} -{a: b * 2 for (a, b) in dictionary.items()} -{a: b * -2 for (a, b) in dictionary.items()} +{i: j for i, j in ((1, 'a'), (2, 'b'), (3, 'c'))} +{a: b * 2 for a, b in dictionary.items()} +{a: b * -2 for a, b in dictionary.items()} { k: v - for ( - k, - v, - ) in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension + for k, + v in this_is_a_very_long_variable_which_will_cause_a_trailing_comma_which_breaks_the_comprehension } Python3 > Python2 > COBOL Life is Life @@ -140,27 +134,25 @@ call.me(maybe) 1.0.real ....__class__ list[str] -dict[(str, int)] -tuple[(str, ...)] -tuple[(str, int, float, dict[(str, int)])] +dict[str, int] +tuple[str, ...] +tuple[str, int, float, dict[str, int]] tuple[ - ( - str, - int, - float, - dict[(str, int)], - ) + str, + int, + float, + dict[str, int], ] very_long_variable_name_filters: t.List[ - (t.Tuple[(str, t.Union[(str, t.List[t.Optional[str]])])],) + t.Tuple[str, t.Union[str, t.List[t.Optional[str]]]], ] -xxxx_xxxxx_xxxx_xxx: Callable[(..., List[SomeClass])] = ( +xxxx_xxxxx_xxxx_xxx: Callable[..., List[SomeClass]] = ( classmethod # type: ignore(sync(async_xxxx_xxx_xxxx_xxxxx_xxxx_xxx.__func__)) ) -xxxx_xxx_xxxx_xxxxx_xxxx_xxx: Callable[(..., List[SomeClass])] = ( +xxxx_xxx_xxxx_xxxxx_xxxx_xxx: Callable[..., List[SomeClass]] = ( classmethod # type: ignore(sync(async_xxxx_xxx_xxxx_xxxxx_xxxx_xxx.__func__)) ) -xxxx_xxx_xxxx_xxxxx_xxxx_xxx: Callable[(..., List[SomeClass])] = ( +xxxx_xxx_xxxx_xxxxx_xxxx_xxx: Callable[..., List[SomeClass]] = ( classmethod(sync(async_xxxx_xxx_xxxx_xxxxx_xxxx_xxx.__func__)) ) slice[0] @@ -171,35 +163,35 @@ slice[:-1] slice[1:] slice[::-1] slice[d :: d + 1] -slice[(:c, c - 1)] -numpy[(:, 0:1)] -numpy[(:, :-1)] -numpy[(0, :)] -numpy[(:, i)] -numpy[(0, :2)] -numpy[(:N, 0)] -numpy[(:2, :4)] -numpy[(2:4, 1:5)] -numpy[(4:, 2:)] -numpy[(:, (0, 1, 2, 5))] -numpy[(0, [0])] -numpy[(:, [i])] -numpy[(1 : c + 1 , c)] -numpy[(-c + 1 :, d)] -numpy[(:, l[-2])] -numpy[(:, ::-1)] -numpy[(np.newaxis, :)] -str or None if sys.version_info[0] > (3,) else str or bytes or None +slice[:c, c - 1] +numpy[:, 0:1] +numpy[:, :-1] +numpy[0, :] +numpy[:, i] +numpy[0, :2] +numpy[:N, 0] +numpy[:2, :4] +numpy[2:4, 1:5] +numpy[4:, 2:] +numpy[:, (0, 1, 2, 5)] +numpy[0, [0]] +numpy[:, [i]] +numpy[1 : c + 1, c] +numpy[-(c + 1) :, d] +numpy[:, l[-2]] +numpy[:, ::-1] +numpy[np.newaxis, :] +(str or None) if (sys.version_info[0] > (3,)) else (str or bytes or None) {'2.7': dead, '3.7': long_live or die_hard} {'2.7', '3.6', '3.7', '3.8', '3.9', '4.0' if gilectomy else '3.10'} [1, 2, 3, 4, 5, 6, 7, 8, 9, 10 or A, 11 or B, 12 or C] -SomeName +(SomeName) SomeName (Good, Bad, Ugly) (i for i in (1, 2, 3)) -(i ** 2 for i in (1, 2, 3)) -(i ** 2 for (i, _) in ((1, 'a'), (2, 'b'), (3, 'c'))) -(i ** 2 + j for i in (1, 2, 3) for j in (1, 2, 3)) +((i**2) for i in (1, 2, 3)) +((i**2) for i, _ in ((1, 'a'), (2, 'b'), (3, 'c'))) +(((i**2) + j) for i in (1, 2, 3) for j in (1, 2, 3)) (*starred,) { "id": "1", @@ -215,48 +207,45 @@ b = (1,) c = 1 d = (1,) + a + (2,) e = (1,).count(1) -f = (1, *range(10)) -g = (1, *"ten") +f = 1, *range(10) +g = 1, *"ten" what_is_up_with_those_new_coord_names = ( - coord_names - + set(vars_to_create) + (coord_names + set(vars_to_create)) + set(vars_to_remove) ) what_is_up_with_those_new_coord_names = ( - coord_names - | set(vars_to_create) + (coord_names | set(vars_to_create)) - set(vars_to_remove) ) -result = ( - session.query(models.Customer.id).filter( - models.Customer.account_id == account_id, - models.Customer.email == email_address, - ).order_by(models.Customer.id.asc()).all() -) -result = ( - session.query(models.Customer.id).filter( - models.Customer.account_id == account_id, - models.Customer.email == email_address, - ).order_by( - models.Customer.id.asc(), - ).all() -) +result = session.query(models.Customer.id).filter( + models.Customer.account_id == account_id, + models.Customer.email == email_address, +).order_by(models.Customer.id.asc()).all() +result = session.query(models.Customer.id).filter( + models.Customer.account_id == account_id, + models.Customer.email == email_address, +).order_by( + models.Customer.id.asc(), +).all() Ø = set() authors.łukasz.say_thanks() -mapping = ( - {A: 0.25 * 10.0 / 12, B: 0.1 * 10.0 / 12, C: 0.1 * 10.0 / 12, D: 0.1 * 10.0 / 12} -) +mapping = { + A: 0.25 * (10.0 / 12), + B: 0.1 * (10.0 / 12), + C: 0.1 * (10.0 / 12), + D: 0.1 * (10.0 / 12), +} def gen(): - (yield from outside_of_generator) - a = (yield) - b = (yield) - c = (yield) + yield from outside_of_generator + a = yield + b = yield + c = yield async def f(): - await some.complicated[0].call(with_args=True or 1 is not 1) + await some.complicated[0].call(with_args=(True or (1 is not 1))) print(*[] or [1]) @@ -268,7 +257,7 @@ assert ( and not requirements.fit_in_a_single_line(force=False) ), "Short message" assert parens is TooMany -for (x,) in ((1,), (2,), (3,)): +for (x,) in (1,), (2,), (3,): ... for y in (): ... @@ -276,7 +265,7 @@ for z in (i for i in (1, 2, 3)): ... for i in call(): ... -for j in 1 + 2 + 3: +for j in 1 + (2 + 3): ... while this and that: ... @@ -285,20 +274,24 @@ for ( addr_type, addr_proto, addr_canonname, - addr_sockaddr, + addr_sockaddr ) in socket.getaddrinfo('google.com', 'http'): pass a = ( - aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp in qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz + aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp + in qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz ) a = ( - aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp not in qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz + aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp + not in qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz ) a = ( - aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp is qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz + aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp + is qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz ) a = ( - aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp is not qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz + aaaa.bbbb.cccc.dddd.eeee.ffff.gggg.hhhh.iiii.jjjj.kkkk.llll.mmmm.nnnn.oooo.pppp + is not qqqq.rrrr.ssss.tttt.uuuu.vvvv.xxxx.yyyy.zzzz ) if ( threading.current_thread() != threading.main_thread() @@ -337,62 +330,28 @@ if ( ): return True if ( - ~aaaa.a - + aaaa.b - - aaaa.c - * aaaa.d - / aaaa.e - | aaaa.f - & aaaa.g - % aaaa.h - ^ aaaa.i - << aaaa.k - >> aaaa.l - ** aaaa.m - // aaaa.n + ~aaaa.a + aaaa.b - aaaa.c * aaaa.d / aaaa.e + | aaaa.f & aaaa.g % aaaa.h ^ aaaa.i << aaaa.k >> aaaa.l**aaaa.m // aaaa.n ): return True if ( - ~aaaaaaaa.a - + aaaaaaaa.b - - aaaaaaaa.c - @ aaaaaaaa.d - / aaaaaaaa.e - | aaaaaaaa.f - & aaaaaaaa.g - % aaaaaaaa.h - ^ aaaaaaaa.i - << aaaaaaaa.k - >> aaaaaaaa.l - ** aaaaaaaa.m - // aaaaaaaa.n + ~aaaaaaaa.a + aaaaaaaa.b - aaaaaaaa.c @ aaaaaaaa.d / aaaaaaaa.e + | aaaaaaaa.f & aaaaaaaa.g % aaaaaaaa.h + ^ aaaaaaaa.i << aaaaaaaa.k >> aaaaaaaa.l**aaaaaaaa.m // aaaaaaaa.n ): return True if ( - ~aaaaaaaaaaaaaaaa.a - + aaaaaaaaaaaaaaaa.b - - aaaaaaaaaaaaaaaa.c - * aaaaaaaaaaaaaaaa.d - @ aaaaaaaaaaaaaaaa.e - | aaaaaaaaaaaaaaaa.f - & aaaaaaaaaaaaaaaa.g - % aaaaaaaaaaaaaaaa.h - ^ aaaaaaaaaaaaaaaa.i - << aaaaaaaaaaaaaaaa.k - >> aaaaaaaaaaaaaaaa.l - ** aaaaaaaaaaaaaaaa.m - // aaaaaaaaaaaaaaaa.n + ~aaaaaaaaaaaaaaaa.a + aaaaaaaaaaaaaaaa.b + - aaaaaaaaaaaaaaaa.c * aaaaaaaaaaaaaaaa.d @ aaaaaaaaaaaaaaaa.e + | aaaaaaaaaaaaaaaa.f & aaaaaaaaaaaaaaaa.g % aaaaaaaaaaaaaaaa.h + ^ aaaaaaaaaaaaaaaa.i << aaaaaaaaaaaaaaaa.k + >> aaaaaaaaaaaaaaaa.l**aaaaaaaaaaaaaaaa.m // aaaaaaaaaaaaaaaa.n ): return True ( - aaaaaaaaaaaaaaaa - + aaaaaaaaaaaaaaaa - - aaaaaaaaaaaaaaaa - * aaaaaaaaaaaaaaaa - + aaaaaaaaaaaaaaaa - / aaaaaaaaaaaaaaaa - + aaaaaaaaaaaaaaaa - + aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa + - aaaaaaaaaaaaaaaa * (aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa) + / (aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa) ) aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa ( @@ -403,8 +362,7 @@ aaaaaaaaaaaaaaaa + aaaaaaaaaaaaaaaa bbbb >> bbbb * bbbb ( aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa - ^ bbbb.a - & aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa + ^ bbbb.a & aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ^ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa ) last_call()# standalone comment at ENDMARKER diff --git a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__function.py.snap.new b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__function.py.snap.new index abefec0b00c650..80c56ab7187dcd 100644 --- a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__function.py.snap.new +++ b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__function.py.snap.new @@ -1,6 +1,6 @@ --- -source: ruff_fmt/src/lib.rs -assertion_line: 78 +source: crates/ruff_fmt/src/lib.rs +assertion_line: 89 expression: printed.as_code() --- #!/usr/bin/env python3 @@ -51,7 +51,7 @@ def function_signature_stress_test( def spaces(a=1, b=(), c=[], d={}, e=True, f=-1, g=1 if False else 2, h="", i=r''): offset = attr.ib(default=attr.Factory(lambda: _r.uniform(10000, 200000))) - assert task._cancel_stack[: len(old_stack)] == old_stack + assert task._cancel_stack[: len(old_stack) ] == old_stack def spaces_types( diff --git a/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__tupleassign.py.snap.new b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__tupleassign.py.snap.new new file mode 100644 index 00000000000000..dd71f95b5144f8 --- /dev/null +++ b/crates/ruff_fmt/src/snapshots/ruff_fmt__tests__simple_cases__tupleassign.py.snap.new @@ -0,0 +1,17 @@ +--- +source: crates/ruff_fmt/src/lib.rs +assertion_line: 94 +expression: printed.as_code() +--- +# This is a standalone comment. +( + sdfjklsdfsjldkflkjsf, + sdfjsdfjlksdljkfsdlkf, + sdfsdjfklsdfjlksdljkf, + sdsfsdfjskdflsfsdf +) = 1, 2, 3 + +# This is as well. +(this_will_be_wrapped_in_parens,) = struct.unpack(b"12345678901234567890") + +(a,) = call() diff --git a/crates/ruff_fmt/src/trivia.rs b/crates/ruff_fmt/src/trivia.rs index 3e0236ef4966c5..284dcea9ad49b1 100644 --- a/crates/ruff_fmt/src/trivia.rs +++ b/crates/ruff_fmt/src/trivia.rs @@ -33,8 +33,7 @@ pub enum TriviaTokenKind { InlineComment, MagicTrailingComma, EmptyLine, - LeftParen, - RightParen, + Parentheses, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -50,8 +49,7 @@ pub enum TriviaKind { InlineComment(Range), MagicTrailingComma, EmptyLine, - LeftParen, - RightParen, + Parentheses, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -61,6 +59,16 @@ pub enum Relationship { Dangling, } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Parenthesize { + /// Always parenthesize the statement or expression. + Always, + /// Never parenthesize the statement or expression. + Never, + /// Parenthesize the statement or expression if it expands. + IfExpanded, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub struct Trivia { pub kind: TriviaKind, @@ -86,12 +94,8 @@ impl Trivia { kind: TriviaKind::InlineComment(Range::new(token.start, token.end)), relationship, }, - TriviaTokenKind::LeftParen => Self { - kind: TriviaKind::LeftParen, - relationship, - }, - TriviaTokenKind::RightParen => Self { - kind: TriviaKind::RightParen, + TriviaTokenKind::Parentheses => Self { + kind: TriviaKind::Parentheses, relationship, }, } @@ -101,11 +105,12 @@ impl Trivia { pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec { let mut tokens = vec![]; let mut prev_tok: Option<(&Location, &Tok, &Location)> = None; + let mut prev_non_newline_tok: Option<(&Location, &Tok, &Location)> = None; let mut prev_semantic_tok: Option<(&Location, &Tok, &Location)> = None; let mut parens = vec![]; for (start, tok, end) in lxr.iter().flatten() { // Add empty lines. - if let Some((prev, ..)) = prev_tok { + if let Some((prev, ..)) = prev_non_newline_tok { for row in prev.row() + 1..start.row() { tokens.push(TriviaToken { start: Location::new(row, 0), @@ -120,7 +125,7 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec { tokens.push(TriviaToken { start: *start, end: *end, - kind: if prev_tok.map_or(true, |(prev, ..)| prev.row() < start.row()) { + kind: if prev_non_newline_tok.map_or(true, |(prev, ..)| prev.row() < start.row()) { TriviaTokenKind::StandaloneComment } else { TriviaTokenKind::InlineComment @@ -129,7 +134,10 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec { } // Add magic trailing commas. - if matches!(tok, Tok::Rpar | Tok::Rsqb | Tok::Rbrace) { + if matches!( + tok, + Tok::Rpar | Tok::Rsqb | Tok::Rbrace | Tok::Equal | Tok::Newline + ) { if let Some((prev_start, prev_tok, prev_end)) = prev_semantic_tok { if prev_tok == &Tok::Comma { tokens.push(TriviaToken { @@ -142,15 +150,10 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec { } if matches!(tok, Tok::Lpar) { - if prev_semantic_tok.map_or(true, |(_, prev_tok, _)| { + if prev_tok.map_or(true, |(_, prev_tok, _)| { !matches!(prev_tok, Tok::Name { .. }) }) { parens.push((start, true)); - // tokens.push(TriviaToken { - // start: *start, - // end: *end, - // kind: TriviaTokenKind::LeftParen, - // }); } else { parens.push((start, false)); } @@ -160,14 +163,16 @@ pub fn extract_trivia_tokens(lxr: &[LexResult]) -> Vec { tokens.push(TriviaToken { start: *start, end: *end, - kind: TriviaTokenKind::LeftParen, + kind: TriviaTokenKind::Parentheses, }); } } + prev_tok = Some((start, tok, end)); + // Track the most recent non-whitespace token. if !matches!(tok, Tok::Newline | Tok::NonLogicalNewline,) { - prev_tok = Some((start, tok, end)); + prev_non_newline_tok = Some((start, tok, end)); } // Track the most recent semantic token. @@ -426,13 +431,19 @@ fn sorted_child_nodes_inner<'a>(node: &Node<'a>, result: &mut Vec>) { ExprKind::UnaryOp { operand, .. } => { result.push(Node::Expr(operand)); } - ExprKind::Lambda { body, .. } => { + ExprKind::Lambda { body, args, .. } => { // TODO(charlie): Arguments. + for expr in &args.defaults { + result.push(Node::Expr(expr)); + } + for expr in &args.kw_defaults { + result.push(Node::Expr(expr)); + } result.push(Node::Expr(body)); } ExprKind::IfExp { test, body, orelse } => { - result.push(Node::Expr(test)); result.push(Node::Expr(body)); + result.push(Node::Expr(test)); result.push(Node::Expr(orelse)); } ExprKind::Dict { keys, values } => { @@ -597,6 +608,7 @@ pub fn decorate_token<'a>( token: &TriviaToken, node: &Node<'a>, enclosing_node: Option>, + enclosed_node: Option>, cache: &mut FxHashMap>>, ) -> ( Option>, @@ -610,6 +622,7 @@ pub fn decorate_token<'a>( let mut preceding_node = None; let mut following_node = None; + let mut enclosed_node = enclosed_node; let mut left = 0; let mut right = child_nodes.len(); @@ -632,9 +645,41 @@ pub fn decorate_token<'a>( Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), }; + if let Some(existing) = &enclosed_node { + // Special-case: if we're dealing with a statement that's a single expression, + // we want to treat the expression as the enclosed node. + let existing_start = match &existing { + Node::Stmt(node) => node.location, + Node::Expr(node) => node.location, + Node::Alias(node) => node.location, + Node::Excepthandler(node) => node.location, + Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), + }; + let existing_end = match &existing { + Node::Stmt(node) => node.end_location.unwrap(), + Node::Expr(node) => node.end_location.unwrap(), + Node::Alias(node) => node.end_location.unwrap(), + Node::Excepthandler(node) => node.end_location.unwrap(), + Node::Mod(..) => unreachable!("Node::Mod cannot be a child node"), + }; + if start == existing_start && end == existing_end { + enclosed_node = Some(child.clone()); + } + } else { + if token.start <= start && token.end >= end { + enclosed_node = Some(child.clone()); + } + } + // The comment is completely contained by this child node. if token.start >= start && token.end <= end { - return decorate_token(token, &child.clone(), Some(child.clone()), cache); + return decorate_token( + token, + &child.clone(), + Some(child.clone()), + enclosed_node, + cache, + ); } if end <= token.start { @@ -657,11 +702,15 @@ pub fn decorate_token<'a>( continue; } - // Return the enclosed node. - return (None, None, None, Some(child.clone())); + return (None, None, None, enclosed_node); } - (preceding_node, following_node, enclosing_node, None) + ( + preceding_node, + following_node, + enclosing_node, + enclosed_node, + ) } #[derive(Debug, Default)] @@ -712,7 +761,8 @@ pub fn decorate_trivia(tokens: Vec, python_ast: &[Stmt]) -> TriviaI let mut cache = FxHashMap::default(); for token in &tokens { let (preceding_node, following_node, enclosing_node, enclosed_node) = - decorate_token(token, &Node::Mod(python_ast), None, &mut cache); + decorate_token(token, &Node::Mod(python_ast), None, None, &mut cache); + stack.push(( preceding_node, following_node, @@ -724,7 +774,7 @@ pub fn decorate_trivia(tokens: Vec, python_ast: &[Stmt]) -> TriviaI let mut trivia_index = TriviaIndex::default(); for (index, token) in tokens.into_iter().enumerate() { - let (preceding_node, following_node, enclosing_node, ..) = &stack[index]; + let (preceding_node, following_node, enclosing_node, enclosed_node) = &stack[index]; match token.kind { TriviaTokenKind::EmptyLine | TriviaTokenKind::StandaloneComment => { if let Some(following_node) = following_node { @@ -788,27 +838,16 @@ pub fn decorate_trivia(tokens: Vec, python_ast: &[Stmt]) -> TriviaI unreachable!("Attach token to the ast: {:?}", token); } } - TriviaTokenKind::LeftParen => { - // if let Some(enclosed_node) = enclosed_node { - // add_comment( - // Trivia::from_token(&token, Relationship::Leading), - // enclosed_node, - // &mut trivia_index, - // ); - // } else { - // unreachable!("Attach token to the ast: {:?}", token); - // } - } - TriviaTokenKind::RightParen => { - // if let Some(preceding) = preceding { - // add_comment( - // Trivia::from_token(&token, Relationship::Trailing), - // preceding, - // &mut trivia_index, - // ); - // } else { - // unreachable!("Attach token to the ast: {:?}", token); - // } + TriviaTokenKind::Parentheses => { + if let Some(enclosed_node) = enclosed_node { + add_comment( + Trivia::from_token(&token, Relationship::Leading), + enclosed_node, + &mut trivia_index, + ); + } else { + unreachable!("Attach token to the ast: {:?}", token); + } } } }