diff --git a/crates/oxc_formatter/src/ast_nodes/impls/ast_nodes.rs b/crates/oxc_formatter/src/ast_nodes/impls/ast_nodes.rs new file mode 100644 index 0000000000000..adeacf8da2229 --- /dev/null +++ b/crates/oxc_formatter/src/ast_nodes/impls/ast_nodes.rs @@ -0,0 +1,26 @@ +//! Implementations of methods for [`AstNodes`]. + +use crate::ast_nodes::AstNodes; + +impl<'a> AstNodes<'a> { + /// Returns an iterator over all ancestor nodes in the AST, starting from self. + /// + /// The iteration includes the current node and proceeds upward through the tree, + /// terminating after yielding the root `Program` node. + /// + /// # Example hierarchy + /// ```text + /// Program + /// └─ BlockStatement + /// └─ ExpressionStatement <- self + /// ``` + /// For `self` as ExpressionStatement, this yields: [ExpressionStatement, BlockStatement, Program] + pub fn ancestors(&self) -> impl Iterator> { + // Start with the current node and walk up the tree, including Program + std::iter::successors(Some(self), |node| { + // Continue iteration until we've yielded Program (root node) + // After Program, parent() would still return Program, so stop there + if matches!(node, AstNodes::Program(_)) { None } else { Some(node.parent()) } + }) + } +} diff --git a/crates/oxc_formatter/src/ast_nodes/impls/mod.rs b/crates/oxc_formatter/src/ast_nodes/impls/mod.rs new file mode 100644 index 0000000000000..fd0d37df85b8b --- /dev/null +++ b/crates/oxc_formatter/src/ast_nodes/impls/mod.rs @@ -0,0 +1 @@ +pub mod ast_nodes; diff --git a/crates/oxc_formatter/src/ast_nodes/mod.rs b/crates/oxc_formatter/src/ast_nodes/mod.rs index 59b9c868003af..07d155811fb17 100644 --- a/crates/oxc_formatter/src/ast_nodes/mod.rs +++ b/crates/oxc_formatter/src/ast_nodes/mod.rs @@ -1,4 +1,5 @@ pub mod generated; +pub mod impls; mod iterator; mod node; diff --git a/crates/oxc_formatter/src/ast_nodes/node.rs b/crates/oxc_formatter/src/ast_nodes/node.rs index 492e95e3186b2..04f04395b3210 100644 --- a/crates/oxc_formatter/src/ast_nodes/node.rs +++ b/crates/oxc_formatter/src/ast_nodes/node.rs @@ -63,6 +63,48 @@ impl GetSpan for &AstNode<'_, T> { } } +impl AstNode<'_, T> { + /// Returns an iterator over all ancestor nodes in the AST, starting from self. + /// + /// The iteration includes the current node and proceeds upward through the tree, + /// terminating after yielding the root `Program` node. + /// + /// This is a convenience method that delegates to `self.parent.ancestors()`. + /// + /// # Example + /// ```text + /// Program + /// └─ BlockStatement + /// └─ ExpressionStatement <- self + /// ``` + /// For `self` as ExpressionStatement, this yields: [ExpressionStatement, BlockStatement, Program] + /// + /// # Usage + /// ```ignore + /// // Find the first ancestor that matches a condition + /// let parent = self.ancestors() + /// .find(|p| matches!(p, AstNodes::ForStatement(_))) + /// .unwrap(); + /// + /// // Get the nth ancestor + /// let great_grandparent = self.ancestors().nth(3); + /// + /// // Check if any ancestor is a specific type + /// let in_arrow_fn = self.ancestors() + /// .any(|p| matches!(p, AstNodes::ArrowFunctionExpression(_))); + /// ``` + pub fn ancestors(&self) -> impl Iterator> { + self.parent.ancestors() + } + + /// Returns the grandparent node (parent's parent). + /// + /// This is a convenience method equivalent to `self.parent.parent()`. + pub fn grand_parent(&self) -> &AstNodes<'_> { + self.parent.parent() + } +} + impl<'a> AstNode<'a, Program<'a>> { pub fn new(inner: &'a Program<'a>, parent: &'a AstNodes<'a>, allocator: &'a Allocator) -> Self { AstNode { inner, parent, allocator, following_span: None } diff --git a/crates/oxc_formatter/src/parentheses/expression.rs b/crates/oxc_formatter/src/parentheses/expression.rs index e4eb88c4f3d61..73339ae14841c 100644 --- a/crates/oxc_formatter/src/parentheses/expression.rs +++ b/crates/oxc_formatter/src/parentheses/expression.rs @@ -84,20 +84,20 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, IdentifierReference<'a>> { matches!(self.parent, AstNodes::ForOfStatement(stmt) if !stmt.r#await && stmt.left.span().contains_inclusive(self.span)) } "let" => { - let mut parent = self.parent; - loop { + // Walk up ancestors to find the relevant context for `let` keyword + for parent in self.ancestors() { match parent { - AstNodes::Program(_) | AstNodes::ExpressionStatement(_) => return false, + AstNodes::ExpressionStatement(_) => return false, AstNodes::ForOfStatement(stmt) => { return stmt.left.span().contains_inclusive(self.span); } AstNodes::TSSatisfiesExpression(expr) => { return expr.expression.span() == self.span(); } - _ => parent = parent.parent(), + _ => {} } } - unreachable!() + false } name => { // @@ -131,7 +131,7 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, IdentifierReference<'a>> { matches!( parent, AstNodes::ExpressionStatement(stmt) if !matches!( - stmt.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) + stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression() ) ) @@ -392,8 +392,9 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, BinaryExpression<'a>> { /// Add parentheses if the `in` is inside of a `for` initializer (see tests). fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool { - let mut parent = expr.parent; - loop { + let mut ancestors = expr.ancestors(); + + while let Some(parent) = ancestors.next() { match parent { AstNodes::ExpressionStatement(stmt) => { let grand_parent = parent.parent(); @@ -404,7 +405,13 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool { grand_grand_parent, AstNodes::ArrowFunctionExpression(arrow) if arrow.expression() ) { - parent = grand_grand_parent; + // Skip ahead to grand_grand_parent by consuming ancestors + // until we reach it + for ancestor in ancestors.by_ref() { + if core::ptr::eq(ancestor, grand_grand_parent) { + break; + } + } continue; } } @@ -423,11 +430,11 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool { AstNodes::Program(_) => { return false; } - _ => { - parent = parent.parent(); - } + _ => {} } } + + false } impl<'a> NeedsParentheses<'a> for AstNode<'a, PrivateInExpression<'a>> { @@ -546,25 +553,20 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, AssignmentExpression<'a>> { // - `a = 1, b = 2` in for loops don't need parens // - `(a = 1, b = 2)` elsewhere usually need parens AstNodes::SequenceExpression(sequence) => { - let mut current_parent = self.parent; - loop { - match current_parent { - AstNodes::SequenceExpression(_) | AstNodes::ParenthesizedExpression(_) => { - current_parent = current_parent.parent(); - } - AstNodes::ForStatement(for_stmt) => { - let is_initializer = for_stmt - .init - .as_ref() - .is_some_and(|init| init.span().contains_inclusive(self.span())); - let is_update = for_stmt.update.as_ref().is_some_and(|update| { - update.span().contains_inclusive(self.span()) - }); - return !(is_initializer || is_update); - } - _ => break, + // Skip through SequenceExpression and ParenthesizedExpression ancestors + if let Some(ancestor) = self.ancestors().find(|p| { + !matches!(p, AstNodes::SequenceExpression(_) | AstNodes::ParenthesizedExpression(_)) + }) && let AstNodes::ForStatement(for_stmt) = ancestor { + let is_initializer = for_stmt + .init + .as_ref() + .is_some_and(|init| init.span().contains_inclusive(self.span())); + let is_update = for_stmt.update.as_ref().is_some_and(|update| { + update.span().contains_inclusive(self.span()) + }); + return !(is_initializer || is_update); } - } + true } // `interface { [a = 1]; }` and `class { [a = 1]; }` not need parens @@ -620,8 +622,8 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, SequenceExpression<'a>> { } } -impl<'a> NeedsParentheses<'a> for AstNode<'a, AwaitExpression<'a>> { - fn needs_parentheses(&self, f: &Formatter<'_, 'a>) -> bool { +impl NeedsParentheses<'_> for AstNode<'_, AwaitExpression<'_>> { + fn needs_parentheses(&self, f: &Formatter<'_, '_>) -> bool { if f.comments().is_type_cast_node(self) { return false; } @@ -977,14 +979,15 @@ pub enum FirstInStatementMode { /// the left most node or reached a statement. fn is_first_in_statement( mut current_span: Span, - mut parent: &AstNodes<'_>, + parent: &AstNodes<'_>, mode: FirstInStatementMode, ) -> bool { - let mut is_not_first_iteration = false; - loop { - match parent { + for (index, ancestor) in parent.ancestors().enumerate() { + let is_not_first_iteration = index > 0; + + match ancestor { AstNodes::ExpressionStatement(stmt) => { - if matches!(stmt.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression) + if matches!(stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression) { if mode == FirstInStatementMode::ExpressionStatementOrArrow { if is_not_first_iteration @@ -1051,9 +1054,7 @@ fn is_first_in_statement( } _ => break, } - current_span = parent.span(); - parent = parent.parent(); - is_not_first_iteration = true; + current_span = ancestor.span(); } false diff --git a/crates/oxc_formatter/src/utils/call_expression.rs b/crates/oxc_formatter/src/utils/call_expression.rs index 3125ee0006b66..d52009824d619 100644 --- a/crates/oxc_formatter/src/utils/call_expression.rs +++ b/crates/oxc_formatter/src/utils/call_expression.rs @@ -37,7 +37,7 @@ pub fn is_test_call_expression(call: &AstNode>) -> bool { match (args.next(), args.next(), args.next()) { (Some(argument), None, None) if arguments.len() == 1 => { if is_angular_test_wrapper(call) && { - if let AstNodes::CallExpression(call) = call.parent.parent() { + if let AstNodes::CallExpression(call) = call.grand_parent() { is_test_call_expression(call) } else { false diff --git a/crates/oxc_formatter/src/utils/member_chain/mod.rs b/crates/oxc_formatter/src/utils/member_chain/mod.rs index 6b25b249f56f6..c366aecbc80b7 100644 --- a/crates/oxc_formatter/src/utils/member_chain/mod.rs +++ b/crates/oxc_formatter/src/utils/member_chain/mod.rs @@ -96,7 +96,7 @@ impl<'a, 'b> MemberChain<'a, 'b> { is_factory(&identifier.name) || // If an identifier has a name that is shorter than the tab with, then we join it with the "head" (matches!(parent, AstNodes::ExpressionStatement(stmt) if { - if let AstNodes::ArrowFunctionExpression(arrow) = stmt.parent.parent() { + if let AstNodes::ArrowFunctionExpression(arrow) = stmt.grand_parent() { !arrow.expression } else { true diff --git a/crates/oxc_formatter/src/write/call_arguments.rs b/crates/oxc_formatter/src/write/call_arguments.rs index 8b801bf7ac5f4..294d1bfeab720 100644 --- a/crates/oxc_formatter/src/write/call_arguments.rs +++ b/crates/oxc_formatter/src/write/call_arguments.rs @@ -115,7 +115,7 @@ impl<'a> Format<'a> for AstNode<'a, ArenaVec<'a, Argument<'a>>> { }); if has_empty_line - || (!matches!(self.parent.parent(), AstNodes::Decorator(_)) + || (!matches!(self.grand_parent(), AstNodes::Decorator(_)) && is_function_composition_args(self)) { return format_all_args_broken_out(self, true, f); diff --git a/crates/oxc_formatter/src/write/class.rs b/crates/oxc_formatter/src/write/class.rs index 717a1265b524c..76d331da0a965 100644 --- a/crates/oxc_formatter/src/write/class.rs +++ b/crates/oxc_formatter/src/write/class.rs @@ -425,7 +425,7 @@ impl<'a> Format<'a> for FormatClass<'a, '_> { } }); - if matches!(extends.parent.parent(), AstNodes::AssignmentExpression(_)) { + if matches!(extends.grand_parent(), AstNodes::AssignmentExpression(_)) { if has_trailing_comments { write!(f, [text("("), &content, text(")")]) } else { diff --git a/crates/oxc_formatter/src/write/jsx/element.rs b/crates/oxc_formatter/src/write/jsx/element.rs index e1870d46f7a0f..3d08c98c69f71 100644 --- a/crates/oxc_formatter/src/write/jsx/element.rs +++ b/crates/oxc_formatter/src/write/jsx/element.rs @@ -181,10 +181,10 @@ impl<'a> Format<'a> for AnyJsxTagWithChildren<'a, '_> { /// ; /// ``` pub fn should_expand(mut parent: &AstNodes<'_>) -> bool { - if matches!(parent, AstNodes::ExpressionStatement(_)) { + if let AstNodes::ExpressionStatement(stmt) = parent { // If the parent is a JSXExpressionContainer, we need to check its parent // to determine if it should expand. - parent = parent.parent().parent(); + parent = stmt.grand_parent(); } let maybe_jsx_expression_child = match parent { AstNodes::ArrowFunctionExpression(arrow) if arrow.expression => match arrow.parent { @@ -192,7 +192,7 @@ pub fn should_expand(mut parent: &AstNodes<'_>) -> bool { AstNodes::Argument(argument) if matches!(argument.parent, AstNodes::CallExpression(_)) => { - argument.parent.parent() + argument.grand_parent() } // Callee AstNodes::CallExpression(call) => call.parent, diff --git a/crates/oxc_formatter/src/write/mod.rs b/crates/oxc_formatter/src/write/mod.rs index 25d290120a964..762d89f7238a9 100644 --- a/crates/oxc_formatter/src/write/mod.rs +++ b/crates/oxc_formatter/src/write/mod.rs @@ -429,22 +429,22 @@ impl<'a> FormatWrite<'a> for AstNode<'a, AwaitExpression<'a>> { }; if is_callee_or_object { - let mut parent = self.parent.parent(); - loop { - match parent { - AstNodes::AwaitExpression(_) - | AstNodes::BlockStatement(_) + let mut await_expression = None; + for ancestor in self.ancestors().skip(1) { + match ancestor { + AstNodes::BlockStatement(_) | AstNodes::FunctionBody(_) | AstNodes::SwitchCase(_) | AstNodes::Program(_) | AstNodes::TSModuleBlock(_) => break, - _ => parent = parent.parent(), + AstNodes::AwaitExpression(expr) => await_expression = Some(expr), + _ => {} } } let indented = format_with(|f| write!(f, [soft_block_indent(&format_inner)])); - return if let AstNodes::AwaitExpression(expr) = parent { + return if let Some(expr) = await_expression.take() { if !expr.needs_parentheses(f) && ExpressionLeftSide::leftmost(expr.argument()).span() != self.span() { diff --git a/crates/oxc_formatter/src/write/parameters.rs b/crates/oxc_formatter/src/write/parameters.rs index 94603ba863b00..4c9ae8212a8a2 100644 --- a/crates/oxc_formatter/src/write/parameters.rs +++ b/crates/oxc_formatter/src/write/parameters.rs @@ -49,10 +49,10 @@ impl<'a> FormatWrite<'a> for AstNode<'a, FormalParameters<'a>> { let layout = if !self.has_parameter() && this_param.is_none() { ParameterLayout::NoParameters } else if can_hug || { - // `self.parent`: Function - // `self.parent.parent()`: Argument - // `self.parent.parent().parent()` CallExpression - if let AstNodes::CallExpression(call) = self.parent.parent().parent() { + // `self`: Function + // `self.ancestors().nth(1)`: Argument + // `self.ancestors().nth(2)`: CallExpression + if let Some(AstNodes::CallExpression(call)) = self.ancestors().nth(2) { is_test_call_expression(call) } else { false diff --git a/crates/oxc_formatter/src/write/sequence_expression.rs b/crates/oxc_formatter/src/write/sequence_expression.rs index 8ceb7342d00f3..174701f56588e 100644 --- a/crates/oxc_formatter/src/write/sequence_expression.rs +++ b/crates/oxc_formatter/src/write/sequence_expression.rs @@ -33,7 +33,7 @@ impl<'a> FormatWrite<'a> for AstNode<'a, SequenceExpression<'a>> { if matches!(self.parent, AstNodes::ForStatement(_)) || (matches!(self.parent, AstNodes::ExpressionStatement(statement) if - !matches!(statement.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression))) + !matches!(statement.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression))) { write!(f, [indent(&rest)]) } else { diff --git a/crates/oxc_formatter/src/write/type_parameters.rs b/crates/oxc_formatter/src/write/type_parameters.rs index f49d4cd410728..a3fb02d073e0e 100644 --- a/crates/oxc_formatter/src/write/type_parameters.rs +++ b/crates/oxc_formatter/src/write/type_parameters.rs @@ -69,7 +69,7 @@ impl<'a> Format<'a> for AstNode<'a, Vec<'a, TSTypeParameter<'a>>> { let trailing_separator = if self.len() == 1 // This only concern sources that allow JSX or a restricted standard variant. && f.context().source_type().is_jsx() - && matches!(self.parent.parent(), AstNodes::ArrowFunctionExpression(_)) + && matches!(self.grand_parent(), AstNodes::ArrowFunctionExpression(_)) // Ignore Type parameter with an `extends` clause or a default type. && !self.first().is_some_and(|t| t.constraint().is_some() || t.default().is_some()) { @@ -113,7 +113,7 @@ impl<'a> Format<'a> for FormatTSTypeParameters<'a, '_> { write!( f, [group(&format_args!("<", format_once(|f| { - if matches!( self.decl.parent.parent().parent(), AstNodes::CallExpression(call) if is_test_call_expression(call)) + if matches!(self.decl.ancestors().nth(2), Some(AstNodes::CallExpression(call)) if is_test_call_expression(call)) { f.join_nodes_with_space().entries_with_trailing_separator(params, ",", TrailingSeparator::Omit).finish() } else { @@ -204,7 +204,7 @@ fn is_arrow_function_variable_type_argument<'a>( // `node.parent` is `TSTypeReference` matches!( - &node.parent.parent(), + &node.grand_parent(), AstNodes::TSTypeAnnotation(type_annotation) if matches!( &type_annotation.parent, diff --git a/crates/oxc_formatter/src/write/union_type.rs b/crates/oxc_formatter/src/write/union_type.rs index c193a3a4ff23d..3b92008fb1f04 100644 --- a/crates/oxc_formatter/src/write/union_type.rs +++ b/crates/oxc_formatter/src/write/union_type.rs @@ -46,12 +46,10 @@ impl<'a> FormatWrite<'a> for AstNode<'a, TSUnionType<'a>> { let leading_comments = f.context().comments().comments_before(self.span().start); let has_leading_comments = !leading_comments.is_empty(); let mut union_type_at_top = self; - while let AstNodes::TSUnionType(parent) = union_type_at_top.parent { - if parent.types().len() == 1 { - union_type_at_top = parent; - } else { - break; - } + while let AstNodes::TSUnionType(parent) = union_type_at_top.parent + && parent.types().len() == 1 + { + union_type_at_top = parent; } let should_indent = { diff --git a/crates/oxc_formatter/src/write/variable_declaration.rs b/crates/oxc_formatter/src/write/variable_declaration.rs index 4dd14d65a2372..a9d3d0c6743ba 100644 --- a/crates/oxc_formatter/src/write/variable_declaration.rs +++ b/crates/oxc_formatter/src/write/variable_declaration.rs @@ -52,7 +52,7 @@ impl<'a> Format<'a> for AstNode<'a, Vec<'a, VariableDeclarator<'a>>> { let length = self.len(); let is_parent_for_loop = matches!( - self.parent.parent(), + self.grand_parent(), AstNodes::ForStatement(_) | AstNodes::ForInStatement(_) | AstNodes::ForOfStatement(_) );