@@ -216,7 +216,11 @@ impl ClauseHeader<'_> {
216216 . decorator_list
217217 . last ( )
218218 . map_or_else ( || header. start ( ) , Ranged :: end) ;
219- find_keyword ( start_position, SimpleTokenKind :: Class , source)
219+ find_keyword (
220+ StartPosition :: ClauseStart ( start_position) ,
221+ SimpleTokenKind :: Class ,
222+ source,
223+ )
220224 }
221225 ClauseHeader :: Function ( header) => {
222226 let start_position = header
@@ -228,21 +232,39 @@ impl ClauseHeader<'_> {
228232 } else {
229233 SimpleTokenKind :: Def
230234 } ;
231- find_keyword ( start_position, keyword, source)
235+ find_keyword ( StartPosition :: ClauseStart ( start_position) , keyword, source)
232236 }
233- ClauseHeader :: If ( header) => find_keyword ( header. start ( ) , SimpleTokenKind :: If , source) ,
237+ ClauseHeader :: If ( header) => find_keyword (
238+ StartPosition :: clause_start ( header) ,
239+ SimpleTokenKind :: If ,
240+ source,
241+ ) ,
234242 ClauseHeader :: ElifElse ( ElifElseClause {
235243 test : None , range, ..
236- } ) => find_keyword ( range. start ( ) , SimpleTokenKind :: Else , source) ,
244+ } ) => find_keyword (
245+ StartPosition :: clause_start ( range) ,
246+ SimpleTokenKind :: Else ,
247+ source,
248+ ) ,
237249 ClauseHeader :: ElifElse ( ElifElseClause {
238250 test : Some ( _) ,
239251 range,
240252 ..
241- } ) => find_keyword ( range. start ( ) , SimpleTokenKind :: Elif , source) ,
242- ClauseHeader :: Try ( header) => find_keyword ( header. start ( ) , SimpleTokenKind :: Try , source) ,
243- ClauseHeader :: ExceptHandler ( header) => {
244- find_keyword ( header. start ( ) , SimpleTokenKind :: Except , source)
245- }
253+ } ) => find_keyword (
254+ StartPosition :: clause_start ( range) ,
255+ SimpleTokenKind :: Elif ,
256+ source,
257+ ) ,
258+ ClauseHeader :: Try ( header) => find_keyword (
259+ StartPosition :: clause_start ( header) ,
260+ SimpleTokenKind :: Try ,
261+ source,
262+ ) ,
263+ ClauseHeader :: ExceptHandler ( header) => find_keyword (
264+ StartPosition :: clause_start ( header) ,
265+ SimpleTokenKind :: Except ,
266+ source,
267+ ) ,
246268 ClauseHeader :: TryFinally ( header) => {
247269 let last_statement = header
248270 . orelse
@@ -252,33 +274,43 @@ impl ClauseHeader<'_> {
252274 . or_else ( || header. body . last ( ) . map ( AnyNodeRef :: from) )
253275 . unwrap ( ) ;
254276
255- find_keyword ( last_statement. end ( ) , SimpleTokenKind :: Finally , source)
256- }
257- ClauseHeader :: Match ( header) => {
258- find_keyword ( header. start ( ) , SimpleTokenKind :: Match , source)
259- }
260- ClauseHeader :: MatchCase ( header) => {
261- find_keyword ( header. start ( ) , SimpleTokenKind :: Case , source)
277+ find_keyword (
278+ StartPosition :: LastStatement ( last_statement. end ( ) ) ,
279+ SimpleTokenKind :: Finally ,
280+ source,
281+ )
262282 }
283+ ClauseHeader :: Match ( header) => find_keyword (
284+ StartPosition :: clause_start ( header) ,
285+ SimpleTokenKind :: Match ,
286+ source,
287+ ) ,
288+ ClauseHeader :: MatchCase ( header) => find_keyword (
289+ StartPosition :: clause_start ( header) ,
290+ SimpleTokenKind :: Case ,
291+ source,
292+ ) ,
263293 ClauseHeader :: For ( header) => {
264294 let keyword = if header. is_async {
265295 SimpleTokenKind :: Async
266296 } else {
267297 SimpleTokenKind :: For
268298 } ;
269- find_keyword ( header. start ( ) , keyword, source)
270- }
271- ClauseHeader :: While ( header) => {
272- find_keyword ( header. start ( ) , SimpleTokenKind :: While , source)
299+ find_keyword ( StartPosition :: clause_start ( header) , keyword, source)
273300 }
301+ ClauseHeader :: While ( header) => find_keyword (
302+ StartPosition :: clause_start ( header) ,
303+ SimpleTokenKind :: While ,
304+ source,
305+ ) ,
274306 ClauseHeader :: With ( header) => {
275307 let keyword = if header. is_async {
276308 SimpleTokenKind :: Async
277309 } else {
278310 SimpleTokenKind :: With
279311 } ;
280312
281- find_keyword ( header . start ( ) , keyword, source)
313+ find_keyword ( StartPosition :: clause_start ( header ) , keyword, source)
282314 }
283315 ClauseHeader :: OrElse ( header) => match header {
284316 ElseClause :: Try ( try_stmt) => {
@@ -289,12 +321,18 @@ impl ClauseHeader<'_> {
289321 . or_else ( || try_stmt. body . last ( ) . map ( AnyNodeRef :: from) )
290322 . unwrap ( ) ;
291323
292- find_keyword ( last_statement. end ( ) , SimpleTokenKind :: Else , source)
324+ find_keyword (
325+ StartPosition :: LastStatement ( last_statement. end ( ) ) ,
326+ SimpleTokenKind :: Else ,
327+ source,
328+ )
293329 }
294330 ElseClause :: For ( StmtFor { body, .. } )
295- | ElseClause :: While ( StmtWhile { body, .. } ) => {
296- find_keyword ( body. last ( ) . unwrap ( ) . end ( ) , SimpleTokenKind :: Else , source)
297- }
331+ | ElseClause :: While ( StmtWhile { body, .. } ) => find_keyword (
332+ StartPosition :: LastStatement ( body. last ( ) . unwrap ( ) . end ( ) ) ,
333+ SimpleTokenKind :: Else ,
334+ source,
335+ ) ,
298336 } ,
299337 }
300338 }
@@ -434,16 +472,41 @@ impl Format<PyFormatContext<'_>> for FormatClauseBody<'_> {
434472 }
435473}
436474
437- /// Finds the range of `keyword` starting the search at `start_position`. Expects only comments and `(` between
438- /// the `start_position` and the `keyword` token.
475+ /// Finds the range of `keyword` starting the search at `start_position`.
476+ ///
477+ /// If the start position is at the end of the previous statement, the
478+ /// search will skip the optional semi-colon at the end of that statement.
479+ /// Other than this, we expect only trivia between the `start_position`
480+ /// and the keyword.
439481fn find_keyword (
440- start_position : TextSize ,
482+ start_position : StartPosition ,
441483 keyword : SimpleTokenKind ,
442484 source : & str ,
443485) -> FormatResult < TextRange > {
444- let mut tokenizer = SimpleTokenizer :: starts_at ( start_position, source) . skip_trivia ( ) ;
486+ let next_token = match start_position {
487+ StartPosition :: ClauseStart ( text_size) => SimpleTokenizer :: starts_at ( text_size, source)
488+ . skip_trivia ( )
489+ . next ( ) ,
490+ StartPosition :: LastStatement ( text_size) => {
491+ let mut tokenizer = SimpleTokenizer :: starts_at ( text_size, source) . skip_trivia ( ) ;
492+
493+ let mut token = tokenizer. next ( ) ;
494+
495+ // If the last statement ends with a semi-colon, skip it.
496+ if matches ! (
497+ token,
498+ Some ( SimpleToken {
499+ kind: SimpleTokenKind :: Semi ,
500+ ..
501+ } )
502+ ) {
503+ token = tokenizer. next ( ) ;
504+ }
505+ token
506+ }
507+ } ;
445508
446- match tokenizer . next ( ) {
509+ match next_token {
447510 Some ( token) if token. kind ( ) == keyword => Ok ( token. range ( ) ) ,
448511 Some ( other) => {
449512 debug_assert ! (
@@ -466,6 +529,35 @@ fn find_keyword(
466529 }
467530}
468531
532+ /// Offset directly before clause header.
533+ ///
534+ /// Can either be the beginning of the clause header
535+ /// or the end of the last statement preceding the clause.
536+ #[ derive( Clone , Copy ) ]
537+ enum StartPosition {
538+ /// The beginning of a clause header
539+ ClauseStart ( TextSize ) ,
540+ /// The end of the last statement in the suite preceding a clause.
541+ ///
542+ /// For example:
543+ /// ```python
544+ /// if cond:
545+ /// a
546+ /// b
547+ /// c;
548+ /// # ...^here
549+ /// else:
550+ /// d
551+ /// ```
552+ LastStatement ( TextSize ) ,
553+ }
554+
555+ impl StartPosition {
556+ fn clause_start ( ranged : impl Ranged ) -> Self {
557+ Self :: ClauseStart ( ranged. start ( ) )
558+ }
559+ }
560+
469561/// Returns the range of the `:` ending the clause header or `Err` if the colon can't be found.
470562fn colon_range ( after_keyword_or_condition : TextSize , source : & str ) -> FormatResult < TextRange > {
471563 let mut tokenizer = SimpleTokenizer :: starts_at ( after_keyword_or_condition, source)
0 commit comments