diff --git a/crates/ruff_linter/src/rules/flake8_simplify/rules/needless_bool.rs b/crates/ruff_linter/src/rules/flake8_simplify/rules/needless_bool.rs index 569e53bf2b3d2..8f7a4745516e1 100644 --- a/crates/ruff_linter/src/rules/flake8_simplify/rules/needless_bool.rs +++ b/crates/ruff_linter/src/rules/flake8_simplify/rules/needless_bool.rs @@ -144,7 +144,7 @@ pub(crate) fn needless_bool(checker: &mut Checker, stmt: &Stmt) { .semantic() .current_statement_parent() .and_then(|parent| traversal::suite(stmt, parent)) - .and_then(|suite| traversal::next_sibling(stmt, suite)) + .and_then(|suite| suite.next_sibling()) else { return; }; diff --git a/crates/ruff_linter/src/rules/flake8_simplify/rules/reimplemented_builtin.rs b/crates/ruff_linter/src/rules/flake8_simplify/rules/reimplemented_builtin.rs index 6bcdb9c1fc232..5b81ed48df7f2 100644 --- a/crates/ruff_linter/src/rules/flake8_simplify/rules/reimplemented_builtin.rs +++ b/crates/ruff_linter/src/rules/flake8_simplify/rules/reimplemented_builtin.rs @@ -72,8 +72,7 @@ pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) { // - `for` loop followed by `return True` or `return False`. let Some(terminal) = match_else_return(stmt).or_else(|| { let parent = checker.semantic().current_statement_parent()?; - let suite = traversal::suite(stmt, parent)?; - let sibling = traversal::next_sibling(stmt, suite)?; + let sibling = traversal::suite(stmt, parent)?.next_sibling()?; match_sibling_return(stmt, sibling) }) else { return; diff --git a/crates/ruff_linter/src/rules/refurb/rules/repeated_append.rs b/crates/ruff_linter/src/rules/refurb/rules/repeated_append.rs index 57e5867172f85..1489932e9f4ed 100644 --- a/crates/ruff_linter/src/rules/refurb/rules/repeated_append.rs +++ b/crates/ruff_linter/src/rules/refurb/rules/repeated_append.rs @@ -3,6 +3,7 @@ use rustc_hash::FxHashMap; use ast::traversal; use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::traversal::EnclosingSuite; use ruff_python_ast::{self as ast, Expr, Stmt}; use ruff_python_codegen::Generator; use ruff_python_semantic::analyze::typing::is_list; @@ -179,10 +180,10 @@ fn match_consecutive_appends<'a>( // In order to match consecutive statements, we need to go to the tree ancestor of the // given statement, find its position there, and match all 'appends' from there. - let siblings: &[Stmt] = if semantic.at_top_level() { + let suite = if semantic.at_top_level() { // If the statement is at the top level, we should go to the parent module. // Module is available in the definitions list. - semantic.definitions.python_ast()? + EnclosingSuite::new(semantic.definitions.python_ast()?, stmt)? } else { // Otherwise, go to the parent, and take its body as a sequence of siblings. semantic @@ -190,11 +191,12 @@ fn match_consecutive_appends<'a>( .and_then(|parent| traversal::suite(stmt, parent))? }; - let stmt_index = siblings.iter().position(|sibling| sibling == stmt)?; - // We shouldn't repeat the same work for many 'appends' that go in a row. Let's check // that this statement is at the beginning of such a group. - if stmt_index != 0 && match_append(semantic, &siblings[stmt_index - 1]).is_some() { + if suite + .previous_sibling() + .is_some_and(|previous_stmt| match_append(semantic, previous_stmt).is_some()) + { return None; } @@ -202,9 +204,9 @@ fn match_consecutive_appends<'a>( Some( std::iter::once(append) .chain( - siblings + suite + .next_siblings() .iter() - .skip(stmt_index + 1) .map_while(|sibling| match_append(semantic, sibling)), ) .collect(), diff --git a/crates/ruff_python_ast/src/traversal.rs b/crates/ruff_python_ast/src/traversal.rs index 1e050cfa94b1e..1803042e7775d 100644 --- a/crates/ruff_python_ast/src/traversal.rs +++ b/crates/ruff_python_ast/src/traversal.rs @@ -1,81 +1,81 @@ //! Utilities for manually traversing a Python AST. -use crate::{self as ast, ExceptHandler, Stmt, Suite}; +use crate::{self as ast, AnyNodeRef, ExceptHandler, Stmt}; -/// Given a [`Stmt`] and its parent, return the [`Suite`] that contains the [`Stmt`]. -pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> { +/// Given a [`Stmt`] and its parent, return the [`ast::Suite`] that contains the [`Stmt`]. +pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option> { // TODO: refactor this to work without a parent, ie when `stmt` is at the top level match parent { - Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => Some(body), - Stmt::ClassDef(ast::StmtClassDef { body, .. }) => Some(body), - Stmt::For(ast::StmtFor { body, orelse, .. }) => { - if body.contains(stmt) { - Some(body) - } else if orelse.contains(stmt) { - Some(orelse) - } else { - None - } - } - Stmt::While(ast::StmtWhile { body, orelse, .. }) => { - if body.contains(stmt) { - Some(body) - } else if orelse.contains(stmt) { - Some(orelse) - } else { - None - } - } + Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => EnclosingSuite::new(body, stmt), + Stmt::ClassDef(ast::StmtClassDef { body, .. }) => EnclosingSuite::new(body, stmt), + Stmt::For(ast::StmtFor { body, orelse, .. }) => [body, orelse] + .iter() + .find_map(|suite| EnclosingSuite::new(suite, stmt)), + Stmt::While(ast::StmtWhile { body, orelse, .. }) => [body, orelse] + .iter() + .find_map(|suite| EnclosingSuite::new(suite, stmt)), Stmt::If(ast::StmtIf { body, elif_else_clauses, .. - }) => { - if body.contains(stmt) { - Some(body) - } else { - elif_else_clauses - .iter() - .map(|elif_else_clause| &elif_else_clause.body) - .find(|body| body.contains(stmt)) - } - } - Stmt::With(ast::StmtWith { body, .. }) => Some(body), + }) => [body] + .into_iter() + .chain(elif_else_clauses.iter().map(|clause| &clause.body)) + .find_map(|suite| EnclosingSuite::new(suite, stmt)), + Stmt::With(ast::StmtWith { body, .. }) => EnclosingSuite::new(body, stmt), Stmt::Match(ast::StmtMatch { cases, .. }) => cases .iter() .map(|case| &case.body) - .find(|body| body.contains(stmt)), + .find_map(|body| EnclosingSuite::new(body, stmt)), Stmt::Try(ast::StmtTry { body, handlers, orelse, finalbody, .. - }) => { - if body.contains(stmt) { - Some(body) - } else if orelse.contains(stmt) { - Some(orelse) - } else if finalbody.contains(stmt) { - Some(finalbody) - } else { + }) => [body, orelse, finalbody] + .into_iter() + .chain( handlers .iter() .filter_map(ExceptHandler::as_except_handler) - .map(|handler| &handler.body) - .find(|body| body.contains(stmt)) - } - } + .map(|handler| &handler.body), + ) + .find_map(|suite| EnclosingSuite::new(suite, stmt)), _ => None, } } -/// Given a [`Stmt`] and its containing [`Suite`], return the next [`Stmt`] in the [`Suite`]. -pub fn next_sibling<'a>(stmt: &'a Stmt, suite: &'a Suite) -> Option<&'a Stmt> { - let mut iter = suite.iter(); - while let Some(sibling) = iter.next() { - if sibling == stmt { - return iter.next(); - } +pub struct EnclosingSuite<'a> { + suite: &'a [Stmt], + position: usize, +} + +impl<'a> EnclosingSuite<'a> { + pub fn new(suite: &'a [Stmt], stmt: &'a Stmt) -> Option { + let position = suite + .iter() + .position(|sibling| AnyNodeRef::ptr_eq(sibling.into(), stmt.into()))?; + + Some(EnclosingSuite { suite, position }) + } + + pub fn next_sibling(&self) -> Option<&'a Stmt> { + self.suite.get(self.position + 1) + } + + pub fn next_siblings(&self) -> &'a [Stmt] { + self.suite.get(self.position + 1..).unwrap_or_default() + } + + pub fn previous_sibling(&self) -> Option<&'a Stmt> { + self.suite.get(self.position.checked_sub(1)?) + } +} + +impl std::ops::Deref for EnclosingSuite<'_> { + type Target = [Stmt]; + + fn deref(&self) -> &Self::Target { + self.suite } - None }