diff --git a/crates/ruff_python_formatter/resources/test/fixtures/ruff/range_formatting/clause_header.py b/crates/ruff_python_formatter/resources/test/fixtures/ruff/range_formatting/clause_header.py index d30c32f11df71..7217b91ea7f32 100644 --- a/crates/ruff_python_formatter/resources/test/fixtures/ruff/range_formatting/clause_header.py +++ b/crates/ruff_python_formatter/resources/test/fixtures/ruff/range_formatting/clause_header.py @@ -42,3 +42,11 @@ def test4( a): if b + c : # trailing clause header comment print("Not formatted" ) + +def test5(): + x = 1 + try: + a; + finally: + b + diff --git a/crates/ruff_python_formatter/src/statement/clause.rs b/crates/ruff_python_formatter/src/statement/clause.rs index 7cc82ca923b08..a5c172f4f8743 100644 --- a/crates/ruff_python_formatter/src/statement/clause.rs +++ b/crates/ruff_python_formatter/src/statement/clause.rs @@ -216,7 +216,11 @@ impl ClauseHeader<'_> { .decorator_list .last() .map_or_else(|| header.start(), Ranged::end); - find_keyword(start_position, SimpleTokenKind::Class, source) + find_keyword( + StartPosition::ClauseStart(start_position), + SimpleTokenKind::Class, + source, + ) } ClauseHeader::Function(header) => { let start_position = header @@ -228,21 +232,39 @@ impl ClauseHeader<'_> { } else { SimpleTokenKind::Def }; - find_keyword(start_position, keyword, source) + find_keyword(StartPosition::ClauseStart(start_position), keyword, source) } - ClauseHeader::If(header) => find_keyword(header.start(), SimpleTokenKind::If, source), + ClauseHeader::If(header) => find_keyword( + StartPosition::clause_start(header), + SimpleTokenKind::If, + source, + ), ClauseHeader::ElifElse(ElifElseClause { test: None, range, .. - }) => find_keyword(range.start(), SimpleTokenKind::Else, source), + }) => find_keyword( + StartPosition::clause_start(range), + SimpleTokenKind::Else, + source, + ), ClauseHeader::ElifElse(ElifElseClause { test: Some(_), range, .. - }) => find_keyword(range.start(), SimpleTokenKind::Elif, source), - ClauseHeader::Try(header) => find_keyword(header.start(), SimpleTokenKind::Try, source), - ClauseHeader::ExceptHandler(header) => { - find_keyword(header.start(), SimpleTokenKind::Except, source) - } + }) => find_keyword( + StartPosition::clause_start(range), + SimpleTokenKind::Elif, + source, + ), + ClauseHeader::Try(header) => find_keyword( + StartPosition::clause_start(header), + SimpleTokenKind::Try, + source, + ), + ClauseHeader::ExceptHandler(header) => find_keyword( + StartPosition::clause_start(header), + SimpleTokenKind::Except, + source, + ), ClauseHeader::TryFinally(header) => { let last_statement = header .orelse @@ -252,25 +274,35 @@ impl ClauseHeader<'_> { .or_else(|| header.body.last().map(AnyNodeRef::from)) .unwrap(); - find_keyword(last_statement.end(), SimpleTokenKind::Finally, source) - } - ClauseHeader::Match(header) => { - find_keyword(header.start(), SimpleTokenKind::Match, source) - } - ClauseHeader::MatchCase(header) => { - find_keyword(header.start(), SimpleTokenKind::Case, source) + find_keyword( + StartPosition::LastStatement(last_statement.end()), + SimpleTokenKind::Finally, + source, + ) } + ClauseHeader::Match(header) => find_keyword( + StartPosition::clause_start(header), + SimpleTokenKind::Match, + source, + ), + ClauseHeader::MatchCase(header) => find_keyword( + StartPosition::clause_start(header), + SimpleTokenKind::Case, + source, + ), ClauseHeader::For(header) => { let keyword = if header.is_async { SimpleTokenKind::Async } else { SimpleTokenKind::For }; - find_keyword(header.start(), keyword, source) - } - ClauseHeader::While(header) => { - find_keyword(header.start(), SimpleTokenKind::While, source) + find_keyword(StartPosition::clause_start(header), keyword, source) } + ClauseHeader::While(header) => find_keyword( + StartPosition::clause_start(header), + SimpleTokenKind::While, + source, + ), ClauseHeader::With(header) => { let keyword = if header.is_async { SimpleTokenKind::Async @@ -278,7 +310,7 @@ impl ClauseHeader<'_> { SimpleTokenKind::With }; - find_keyword(header.start(), keyword, source) + find_keyword(StartPosition::clause_start(header), keyword, source) } ClauseHeader::OrElse(header) => match header { ElseClause::Try(try_stmt) => { @@ -289,12 +321,18 @@ impl ClauseHeader<'_> { .or_else(|| try_stmt.body.last().map(AnyNodeRef::from)) .unwrap(); - find_keyword(last_statement.end(), SimpleTokenKind::Else, source) + find_keyword( + StartPosition::LastStatement(last_statement.end()), + SimpleTokenKind::Else, + source, + ) } ElseClause::For(StmtFor { body, .. }) - | ElseClause::While(StmtWhile { body, .. }) => { - find_keyword(body.last().unwrap().end(), SimpleTokenKind::Else, source) - } + | ElseClause::While(StmtWhile { body, .. }) => find_keyword( + StartPosition::LastStatement(body.last().unwrap().end()), + SimpleTokenKind::Else, + source, + ), }, } } @@ -434,16 +472,41 @@ impl Format> for FormatClauseBody<'_> { } } -/// Finds the range of `keyword` starting the search at `start_position`. Expects only comments and `(` between -/// the `start_position` and the `keyword` token. +/// Finds the range of `keyword` starting the search at `start_position`. +/// +/// If the start position is at the end of the previous statement, the +/// search will skip the optional semi-colon at the end of that statement. +/// Other than this, we expect only trivia between the `start_position` +/// and the keyword. fn find_keyword( - start_position: TextSize, + start_position: StartPosition, keyword: SimpleTokenKind, source: &str, ) -> FormatResult { - let mut tokenizer = SimpleTokenizer::starts_at(start_position, source).skip_trivia(); + let next_token = match start_position { + StartPosition::ClauseStart(text_size) => SimpleTokenizer::starts_at(text_size, source) + .skip_trivia() + .next(), + StartPosition::LastStatement(text_size) => { + let mut tokenizer = SimpleTokenizer::starts_at(text_size, source).skip_trivia(); + + let mut token = tokenizer.next(); + + // If the last statement ends with a semi-colon, skip it. + if matches!( + token, + Some(SimpleToken { + kind: SimpleTokenKind::Semi, + .. + }) + ) { + token = tokenizer.next(); + } + token + } + }; - match tokenizer.next() { + match next_token { Some(token) if token.kind() == keyword => Ok(token.range()), Some(other) => { debug_assert!( @@ -466,6 +529,35 @@ fn find_keyword( } } +/// Offset directly before clause header. +/// +/// Can either be the beginning of the clause header +/// or the end of the last statement preceding the clause. +#[derive(Clone, Copy)] +enum StartPosition { + /// The beginning of a clause header + ClauseStart(TextSize), + /// The end of the last statement in the suite preceding a clause. + /// + /// For example: + /// ```python + /// if cond: + /// a + /// b + /// c; + /// # ...^here + /// else: + /// d + /// ``` + LastStatement(TextSize), +} + +impl StartPosition { + fn clause_start(ranged: impl Ranged) -> Self { + Self::ClauseStart(ranged.start()) + } +} + /// Returns the range of the `:` ending the clause header or `Err` if the colon can't be found. fn colon_range(after_keyword_or_condition: TextSize, source: &str) -> FormatResult { let mut tokenizer = SimpleTokenizer::starts_at(after_keyword_or_condition, source) diff --git a/crates/ruff_python_formatter/tests/snapshots/format@range_formatting__clause_header.py.snap b/crates/ruff_python_formatter/tests/snapshots/format@range_formatting__clause_header.py.snap index 8f16a68d08d6a..07e89d0649165 100644 --- a/crates/ruff_python_formatter/tests/snapshots/format@range_formatting__clause_header.py.snap +++ b/crates/ruff_python_formatter/tests/snapshots/format@range_formatting__clause_header.py.snap @@ -1,7 +1,6 @@ --- source: crates/ruff_python_formatter/tests/fixtures.rs input_file: crates/ruff_python_formatter/resources/test/fixtures/ruff/range_formatting/clause_header.py -snapshot_kind: text --- ## Input ```python @@ -49,6 +48,14 @@ def test4( a): if b + c : # trailing clause header comment print("Not formatted" ) + +def test5(): + x = 1 + try: + a; + finally: + b + ``` ## Output @@ -96,4 +103,11 @@ if a + b: # trailing clause header comment if b + c: # trailing clause header comment print("Not formatted" ) + +def test5(): + x = 1 + try: + a + finally: + b ```