Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/oxc_formatter/src/ast_nodes/impls/ast_nodes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//! Implementations of methods for [`AstNodes`].

use crate::ast_nodes::AstNodes;

impl<'a> AstNodes<'a> {
/// Returns an iterator over all ancestor nodes in the AST, starting from self.
///
/// The iteration includes the current node and proceeds upward through the tree,
/// terminating after yielding the root `Program` node.
///
/// # Example hierarchy
/// ```text
/// Program
/// └─ BlockStatement
/// └─ ExpressionStatement <- self
/// ```
/// For `self` as ExpressionStatement, this yields: [ExpressionStatement, BlockStatement, Program]
pub fn ancestors(&self) -> impl Iterator<Item = &AstNodes<'a>> {
// Start with the current node and walk up the tree, including Program
std::iter::successors(Some(self), |node| {
// Continue iteration until we've yielded Program (root node)
// After Program, parent() would still return Program, so stop there
if matches!(node, AstNodes::Program(_)) { None } else { Some(node.parent()) }
})
}
}
1 change: 1 addition & 0 deletions crates/oxc_formatter/src/ast_nodes/impls/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod ast_nodes;
1 change: 1 addition & 0 deletions crates/oxc_formatter/src/ast_nodes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod generated;
pub mod impls;
mod iterator;
mod node;

Expand Down
42 changes: 42 additions & 0 deletions crates/oxc_formatter/src/ast_nodes/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,48 @@ impl<T: GetSpan> GetSpan for &AstNode<'_, T> {
}
}

impl<T> AstNode<'_, T> {
/// Returns an iterator over all ancestor nodes in the AST, starting from self.
///
/// The iteration includes the current node and proceeds upward through the tree,
/// terminating after yielding the root `Program` node.
///
/// This is a convenience method that delegates to `self.parent.ancestors()`.
///
/// # Example
/// ```text
/// Program
/// └─ BlockStatement
/// └─ ExpressionStatement <- self
/// ```
/// For `self` as ExpressionStatement, this yields: [ExpressionStatement, BlockStatement, Program]
///
/// # Usage
/// ```ignore
/// // Find the first ancestor that matches a condition
/// let parent = self.ancestors()
/// .find(|p| matches!(p, AstNodes::ForStatement(_)))
/// .unwrap();
///
/// // Get the nth ancestor
/// let great_grandparent = self.ancestors().nth(3);
///
/// // Check if any ancestor is a specific type
/// let in_arrow_fn = self.ancestors()
/// .any(|p| matches!(p, AstNodes::ArrowFunctionExpression(_)));
/// ```
pub fn ancestors(&self) -> impl Iterator<Item = &AstNodes<'_>> {
self.parent.ancestors()
}

/// Returns the grandparent node (parent's parent).
///
/// This is a convenience method equivalent to `self.parent.parent()`.
pub fn grand_parent(&self) -> &AstNodes<'_> {
self.parent.parent()
}
}

impl<'a> AstNode<'a, Program<'a>> {
pub fn new(inner: &'a Program<'a>, parent: &'a AstNodes<'a>, allocator: &'a Allocator) -> Self {
AstNode { inner, parent, allocator, following_span: None }
Expand Down
81 changes: 41 additions & 40 deletions crates/oxc_formatter/src/parentheses/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,20 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, IdentifierReference<'a>> {
matches!(self.parent, AstNodes::ForOfStatement(stmt) if !stmt.r#await && stmt.left.span().contains_inclusive(self.span))
}
"let" => {
let mut parent = self.parent;
loop {
// Walk up ancestors to find the relevant context for `let` keyword
for parent in self.ancestors() {
match parent {
AstNodes::Program(_) | AstNodes::ExpressionStatement(_) => return false,
AstNodes::ExpressionStatement(_) => return false,
AstNodes::ForOfStatement(stmt) => {
return stmt.left.span().contains_inclusive(self.span);
}
AstNodes::TSSatisfiesExpression(expr) => {
return expr.expression.span() == self.span();
}
_ => parent = parent.parent(),
_ => {}
}
}
unreachable!()
false
}
name => {
// <https://github.com/prettier/prettier/blob/7584432401a47a26943dd7a9ca9a8e032ead7285/src/language-js/needs-parens.js#L123-L133>
Expand Down Expand Up @@ -131,7 +131,7 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, IdentifierReference<'a>> {
matches!(
parent, AstNodes::ExpressionStatement(stmt) if
!matches!(
stmt.parent.parent(), AstNodes::ArrowFunctionExpression(arrow)
stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow)
if arrow.expression()
)
)
Expand Down Expand Up @@ -392,8 +392,9 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, BinaryExpression<'a>> {

/// Add parentheses if the `in` is inside of a `for` initializer (see tests).
fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
let mut parent = expr.parent;
loop {
let mut ancestors = expr.ancestors();

while let Some(parent) = ancestors.next() {
match parent {
AstNodes::ExpressionStatement(stmt) => {
let grand_parent = parent.parent();
Expand All @@ -404,7 +405,13 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
grand_grand_parent,
AstNodes::ArrowFunctionExpression(arrow) if arrow.expression()
) {
parent = grand_grand_parent;
// Skip ahead to grand_grand_parent by consuming ancestors
// until we reach it
for ancestor in ancestors.by_ref() {
if core::ptr::eq(ancestor, grand_grand_parent) {
break;
}
}
continue;
}
}
Expand All @@ -423,11 +430,11 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
AstNodes::Program(_) => {
return false;
}
_ => {
parent = parent.parent();
}
_ => {}
}
}

false
}

impl<'a> NeedsParentheses<'a> for AstNode<'a, PrivateInExpression<'a>> {
Expand Down Expand Up @@ -546,25 +553,20 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, AssignmentExpression<'a>> {
// - `a = 1, b = 2` in for loops don't need parens
// - `(a = 1, b = 2)` elsewhere usually need parens
AstNodes::SequenceExpression(sequence) => {
let mut current_parent = self.parent;
loop {
match current_parent {
AstNodes::SequenceExpression(_) | AstNodes::ParenthesizedExpression(_) => {
current_parent = current_parent.parent();
}
AstNodes::ForStatement(for_stmt) => {
let is_initializer = for_stmt
.init
.as_ref()
.is_some_and(|init| init.span().contains_inclusive(self.span()));
let is_update = for_stmt.update.as_ref().is_some_and(|update| {
update.span().contains_inclusive(self.span())
});
return !(is_initializer || is_update);
}
_ => break,
// Skip through SequenceExpression and ParenthesizedExpression ancestors
if let Some(ancestor) = self.ancestors().find(|p| {
!matches!(p, AstNodes::SequenceExpression(_) | AstNodes::ParenthesizedExpression(_))
}) && let AstNodes::ForStatement(for_stmt) = ancestor {
let is_initializer = for_stmt
.init
.as_ref()
.is_some_and(|init| init.span().contains_inclusive(self.span()));
let is_update = for_stmt.update.as_ref().is_some_and(|update| {
update.span().contains_inclusive(self.span())
});
return !(is_initializer || is_update);
}
}

true
}
// `interface { [a = 1]; }` and `class { [a = 1]; }` not need parens
Expand Down Expand Up @@ -620,8 +622,8 @@ impl<'a> NeedsParentheses<'a> for AstNode<'a, SequenceExpression<'a>> {
}
}

impl<'a> NeedsParentheses<'a> for AstNode<'a, AwaitExpression<'a>> {
fn needs_parentheses(&self, f: &Formatter<'_, 'a>) -> bool {
impl NeedsParentheses<'_> for AstNode<'_, AwaitExpression<'_>> {
fn needs_parentheses(&self, f: &Formatter<'_, '_>) -> bool {
if f.comments().is_type_cast_node(self) {
return false;
}
Expand Down Expand Up @@ -977,14 +979,15 @@ pub enum FirstInStatementMode {
/// the left most node or reached a statement.
fn is_first_in_statement(
mut current_span: Span,
mut parent: &AstNodes<'_>,
parent: &AstNodes<'_>,
mode: FirstInStatementMode,
) -> bool {
let mut is_not_first_iteration = false;
loop {
match parent {
for (index, ancestor) in parent.ancestors().enumerate() {
let is_not_first_iteration = index > 0;

match ancestor {
AstNodes::ExpressionStatement(stmt) => {
if matches!(stmt.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
if matches!(stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
{
if mode == FirstInStatementMode::ExpressionStatementOrArrow {
if is_not_first_iteration
Expand Down Expand Up @@ -1051,9 +1054,7 @@ fn is_first_in_statement(
}
_ => break,
}
current_span = parent.span();
parent = parent.parent();
is_not_first_iteration = true;
current_span = ancestor.span();
}

false
Expand Down
2 changes: 1 addition & 1 deletion crates/oxc_formatter/src/utils/call_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub fn is_test_call_expression(call: &AstNode<CallExpression<'_>>) -> bool {
match (args.next(), args.next(), args.next()) {
(Some(argument), None, None) if arguments.len() == 1 => {
if is_angular_test_wrapper(call) && {
if let AstNodes::CallExpression(call) = call.parent.parent() {
if let AstNodes::CallExpression(call) = call.grand_parent() {
is_test_call_expression(call)
} else {
false
Expand Down
2 changes: 1 addition & 1 deletion crates/oxc_formatter/src/utils/member_chain/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl<'a, 'b> MemberChain<'a, 'b> {
is_factory(&identifier.name) ||
// If an identifier has a name that is shorter than the tab with, then we join it with the "head"
(matches!(parent, AstNodes::ExpressionStatement(stmt) if {
if let AstNodes::ArrowFunctionExpression(arrow) = stmt.parent.parent() {
if let AstNodes::ArrowFunctionExpression(arrow) = stmt.grand_parent() {
!arrow.expression
} else {
true
Expand Down
2 changes: 1 addition & 1 deletion crates/oxc_formatter/src/write/call_arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl<'a> Format<'a> for AstNode<'a, ArenaVec<'a, Argument<'a>>> {
});

if has_empty_line
|| (!matches!(self.parent.parent(), AstNodes::Decorator(_))
|| (!matches!(self.grand_parent(), AstNodes::Decorator(_))
&& is_function_composition_args(self))
{
return format_all_args_broken_out(self, true, f);
Expand Down
2 changes: 1 addition & 1 deletion crates/oxc_formatter/src/write/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ impl<'a> Format<'a> for FormatClass<'a, '_> {
}
});

if matches!(extends.parent.parent(), AstNodes::AssignmentExpression(_)) {
if matches!(extends.grand_parent(), AstNodes::AssignmentExpression(_)) {
if has_trailing_comments {
write!(f, [text("("), &content, text(")")])
} else {
Expand Down
6 changes: 3 additions & 3 deletions crates/oxc_formatter/src/write/jsx/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,18 @@ impl<'a> Format<'a> for AnyJsxTagWithChildren<'a, '_> {
/// </div>;
/// ```
pub fn should_expand(mut parent: &AstNodes<'_>) -> bool {
if matches!(parent, AstNodes::ExpressionStatement(_)) {
if let AstNodes::ExpressionStatement(stmt) = parent {
// If the parent is a JSXExpressionContainer, we need to check its parent
// to determine if it should expand.
parent = parent.parent().parent();
parent = stmt.grand_parent();
}
let maybe_jsx_expression_child = match parent {
AstNodes::ArrowFunctionExpression(arrow) if arrow.expression => match arrow.parent {
// Argument
AstNodes::Argument(argument)
if matches!(argument.parent, AstNodes::CallExpression(_)) =>
{
argument.parent.parent()
argument.grand_parent()
}
// Callee
AstNodes::CallExpression(call) => call.parent,
Expand Down
14 changes: 7 additions & 7 deletions crates/oxc_formatter/src/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,22 +429,22 @@ impl<'a> FormatWrite<'a> for AstNode<'a, AwaitExpression<'a>> {
};

if is_callee_or_object {
let mut parent = self.parent.parent();
loop {
match parent {
AstNodes::AwaitExpression(_)
| AstNodes::BlockStatement(_)
let mut await_expression = None;
for ancestor in self.ancestors().skip(1) {
match ancestor {
AstNodes::BlockStatement(_)
| AstNodes::FunctionBody(_)
| AstNodes::SwitchCase(_)
| AstNodes::Program(_)
| AstNodes::TSModuleBlock(_) => break,
_ => parent = parent.parent(),
AstNodes::AwaitExpression(expr) => await_expression = Some(expr),
_ => {}
}
}

let indented = format_with(|f| write!(f, [soft_block_indent(&format_inner)]));

return if let AstNodes::AwaitExpression(expr) = parent {
return if let Some(expr) = await_expression.take() {
if !expr.needs_parentheses(f)
&& ExpressionLeftSide::leftmost(expr.argument()).span() != self.span()
{
Expand Down
8 changes: 4 additions & 4 deletions crates/oxc_formatter/src/write/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ impl<'a> FormatWrite<'a> for AstNode<'a, FormalParameters<'a>> {
let layout = if !self.has_parameter() && this_param.is_none() {
ParameterLayout::NoParameters
} else if can_hug || {
// `self.parent`: Function
// `self.parent.parent()`: Argument
// `self.parent.parent().parent()` CallExpression
if let AstNodes::CallExpression(call) = self.parent.parent().parent() {
// `self`: Function
// `self.ancestors().nth(1)`: Argument
// `self.ancestors().nth(2)`: CallExpression
if let Some(AstNodes::CallExpression(call)) = self.ancestors().nth(2) {
is_test_call_expression(call)
} else {
false
Expand Down
2 changes: 1 addition & 1 deletion crates/oxc_formatter/src/write/sequence_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl<'a> FormatWrite<'a> for AstNode<'a, SequenceExpression<'a>> {

if matches!(self.parent, AstNodes::ForStatement(_))
|| (matches!(self.parent, AstNodes::ExpressionStatement(statement) if
!matches!(statement.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)))
!matches!(statement.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)))
{
write!(f, [indent(&rest)])
} else {
Expand Down
6 changes: 3 additions & 3 deletions crates/oxc_formatter/src/write/type_parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl<'a> Format<'a> for AstNode<'a, Vec<'a, TSTypeParameter<'a>>> {
let trailing_separator = if self.len() == 1
// This only concern sources that allow JSX or a restricted standard variant.
&& f.context().source_type().is_jsx()
&& matches!(self.parent.parent(), AstNodes::ArrowFunctionExpression(_))
&& matches!(self.grand_parent(), AstNodes::ArrowFunctionExpression(_))
// Ignore Type parameter with an `extends` clause or a default type.
&& !self.first().is_some_and(|t| t.constraint().is_some() || t.default().is_some())
{
Expand Down Expand Up @@ -113,7 +113,7 @@ impl<'a> Format<'a> for FormatTSTypeParameters<'a, '_> {
write!(
f,
[group(&format_args!("<", format_once(|f| {
if matches!( self.decl.parent.parent().parent(), AstNodes::CallExpression(call) if is_test_call_expression(call))
if matches!(self.decl.ancestors().nth(2), Some(AstNodes::CallExpression(call)) if is_test_call_expression(call))
{
f.join_nodes_with_space().entries_with_trailing_separator(params, ",", TrailingSeparator::Omit).finish()
} else {
Expand Down Expand Up @@ -204,7 +204,7 @@ fn is_arrow_function_variable_type_argument<'a>(

// `node.parent` is `TSTypeReference`
matches!(
&node.parent.parent(),
&node.grand_parent(),
AstNodes::TSTypeAnnotation(type_annotation)
if matches!(
&type_annotation.parent,
Expand Down
10 changes: 4 additions & 6 deletions crates/oxc_formatter/src/write/union_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,10 @@ impl<'a> FormatWrite<'a> for AstNode<'a, TSUnionType<'a>> {
let leading_comments = f.context().comments().comments_before(self.span().start);
let has_leading_comments = !leading_comments.is_empty();
let mut union_type_at_top = self;
while let AstNodes::TSUnionType(parent) = union_type_at_top.parent {
if parent.types().len() == 1 {
union_type_at_top = parent;
} else {
break;
}
while let AstNodes::TSUnionType(parent) = union_type_at_top.parent
&& parent.types().len() == 1
{
union_type_at_top = parent;
}

let should_indent = {
Expand Down
Loading
Loading