Skip to content

Commit 9fc99f3

Browse files
committed
feat(formatter): introduce AstNode<ExpressionStatement>::is_arrow_function_body
1 parent 6095569 commit 9fc99f3

File tree

6 files changed

+32
-70
lines changed

6 files changed

+32
-70
lines changed

crates/oxc_formatter/src/ast_nodes/node.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use core::fmt;
22
use std::ops::Deref;
33

44
use oxc_allocator::Allocator;
5-
use oxc_ast::ast::Program;
5+
use oxc_ast::ast::{ExpressionStatement, Program};
66
use oxc_span::{GetSpan, Span};
77

88
use super::AstNodes;
@@ -110,3 +110,17 @@ impl<'a> AstNode<'a, Program<'a>> {
110110
AstNode { inner, parent, allocator, following_span: None }
111111
}
112112
}
113+
114+
impl<'a> AstNode<'a, ExpressionStatement<'a>> {
115+
/// Check if this ExpressionStatement is the body of an arrow function expression
116+
///
117+
/// Example:
118+
/// `() => expression;`
119+
/// ^^^^^^^^^^ This ExpressionStatement is the body of an arrow function
120+
///
121+
/// `() => { return expression; }`
122+
/// ^^^^^^^^^^^^^^^^^^^^ This ExpressionStatement is NOT the body of an arrow function
123+
pub fn is_arrow_function_body(&self) -> bool {
124+
matches!(self.parent.parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
125+
}
126+
}

crates/oxc_formatter/src/parentheses/expression.rs

Lines changed: 12 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,7 @@ impl NeedsParentheses<'_> for AstNode<'_, IdentifierReference<'_>> {
131131

132132
matches!(
133133
parent, AstNodes::ExpressionStatement(stmt) if
134-
!matches!(
135-
stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow)
136-
if arrow.expression()
137-
)
134+
!stmt.is_arrow_function_body()
138135
)
139136
}
140137
}
@@ -211,15 +208,7 @@ impl NeedsParentheses<'_> for AstNode<'_, StringLiteral<'_>> {
211208

212209
if let AstNodes::ExpressionStatement(stmt) = self.parent {
213210
// `() => "foo"`
214-
if let AstNodes::FunctionBody(arrow) = stmt.parent {
215-
if let AstNodes::ArrowFunctionExpression(arrow) = arrow.parent {
216-
!arrow.expression()
217-
} else {
218-
true
219-
}
220-
} else {
221-
true
222-
}
211+
!stmt.is_arrow_function_body()
223212
} else {
224213
false
225214
}
@@ -400,21 +389,11 @@ fn is_in_for_initializer(expr: &AstNode<'_, BinaryExpression<'_>>) -> bool {
400389
AstNodes::ExpressionStatement(stmt) => {
401390
let grand_parent = parent.parent();
402391

403-
if matches!(grand_parent, AstNodes::FunctionBody(_)) {
404-
let grand_grand_parent = grand_parent.parent();
405-
if matches!(
406-
grand_grand_parent,
407-
AstNodes::ArrowFunctionExpression(arrow) if arrow.expression()
408-
) {
409-
// Skip ahead to grand_grand_parent by consuming ancestors
410-
// until we reach it
411-
for ancestor in ancestors.by_ref() {
412-
if core::ptr::eq(ancestor, grand_grand_parent) {
413-
break;
414-
}
415-
}
416-
continue;
417-
}
392+
if stmt.is_arrow_function_body() {
393+
// Skip `FunctionBody` and `ArrowFunctionExpression`
394+
let skipped = ancestors.by_ref().nth(1);
395+
debug_assert!(matches!(skipped, Some(AstNodes::ArrowFunctionExpression(_))));
396+
continue;
418397
}
419398

420399
return false;
@@ -534,15 +513,11 @@ impl NeedsParentheses<'_> for AstNode<'_, AssignmentExpression<'_>> {
534513
// - `{ x } = obj` -> `({ x } = obj)` = needed to prevent parsing as block statement
535514
// - `() => { x } = obj` -> `() => ({ x } = obj)` = needed in arrow function body
536515
// - `() => a = b` -> `() => (a = b)` = also parens needed
537-
AstNodes::ExpressionStatement(parent) => {
538-
let parent_parent = parent.parent;
539-
if let AstNodes::FunctionBody(body) = parent_parent {
540-
let parent_parent_parent = body.parent;
541-
if matches!(parent_parent_parent, AstNodes::ArrowFunctionExpression(arrow) if arrow.expression())
542-
{
543-
return true;
544-
}
516+
AstNodes::ExpressionStatement(stmt) => {
517+
if stmt.is_arrow_function_body() {
518+
return true;
545519
}
520+
546521
matches!(self.left, AssignmentTarget::ObjectAssignmentTarget(_))
547522
&& is_first_in_statement(
548523
self.span,
@@ -588,12 +563,6 @@ impl NeedsParentheses<'_> for AstNode<'_, AssignmentExpression<'_>> {
588563
stmt.update.as_ref().is_some_and(|update| update.span() == self.span());
589564
!(is_initializer || is_update)
590565
}
591-
// Arrow functions, only need parens if assignment is the direct body:
592-
// - `() => a = b` -> `() => (a = b)` = needed
593-
// - `() => someFunc(a = b)` = no extra parens needed
594-
AstNodes::ArrowFunctionExpression(arrow) => {
595-
arrow.expression() && arrow.body.span() == self.span()
596-
}
597566
// Default: need parentheses in most other contexts
598567
// - `new (a = b)`
599568
// - `(a = b).prop`
@@ -617,8 +586,6 @@ impl NeedsParentheses<'_> for AstNode<'_, SequenceExpression<'_>> {
617586
| AstNodes::ForStatement(_)
618587
| AstNodes::ExpressionStatement(_)
619588
| AstNodes::SequenceExpression(_)
620-
// Handled as part of the arrow function formatting
621-
| AstNodes::ArrowFunctionExpression(_)
622589
)
623590
}
624591
}
@@ -988,8 +955,7 @@ fn is_first_in_statement(
988955

989956
match ancestor {
990957
AstNodes::ExpressionStatement(stmt) => {
991-
if matches!(stmt.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
992-
{
958+
if stmt.is_arrow_function_body() {
993959
if mode == FirstInStatementMode::ExpressionStatementOrArrow {
994960
if is_not_first_iteration
995961
&& matches!(

crates/oxc_formatter/src/utils/jsx.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,7 @@ pub fn get_wrap_state(parent: &AstNodes<'_>) -> WrapState {
9090
AstNodes::ExpressionStatement(stmt) => {
9191
// `() => <div></div>`
9292
// ^^^^^^^^^^^
93-
if let AstNodes::FunctionBody(body) = stmt.parent
94-
&& matches!(body.parent, AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)
95-
{
96-
WrapState::WrapOnBreak
97-
} else {
98-
WrapState::NoWrap
99-
}
93+
if stmt.is_arrow_function_body() { WrapState::WrapOnBreak } else { WrapState::NoWrap }
10094
}
10195
AstNodes::ComputedMemberExpression(member) => {
10296
if member.optional {

crates/oxc_formatter/src/utils/member_chain/mod.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,7 @@ impl<'a, 'b> MemberChain<'a, 'b> {
9595
has_computed_property ||
9696
is_factory(&identifier.name) ||
9797
// If an identifier has a name that is shorter than the tab with, then we join it with the "head"
98-
(matches!(parent, AstNodes::ExpressionStatement(stmt) if {
99-
if let AstNodes::ArrowFunctionExpression(arrow) = stmt.grand_parent() {
100-
!arrow.expression
101-
} else {
102-
true
103-
}
104-
})
98+
(matches!(parent, AstNodes::ExpressionStatement(stmt) if !stmt.is_arrow_function_body())
10599
&& has_short_name(&identifier.name, f.options().indent_width.value()))
106100
} else {
107101
matches!(node.as_ref(), Expression::ThisExpression(_))

crates/oxc_formatter/src/write/binary_like_expression.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,7 @@ impl<'a, 'b> BinaryLikeExpression<'a, 'b> {
158158
AstNodes::JSXExpressionContainer(container) => {
159159
matches!(container.parent, AstNodes::JSXAttribute(_))
160160
}
161-
AstNodes::ExpressionStatement(statement) => {
162-
if let AstNodes::FunctionBody(arrow) = statement.parent {
163-
arrow.span == self.span()
164-
} else {
165-
false
166-
}
167-
}
161+
AstNodes::ExpressionStatement(statement) => statement.is_arrow_function_body(),
168162
AstNodes::ConditionalExpression(conditional) => {
169163
!matches!(
170164
parent.parent(),

crates/oxc_formatter/src/write/sequence_expression.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ impl<'a> FormatWrite<'a> for AstNode<'a, SequenceExpression<'a>> {
3232
});
3333

3434
if matches!(self.parent, AstNodes::ForStatement(_))
35-
|| (matches!(self.parent, AstNodes::ExpressionStatement(statement) if
36-
!matches!(statement.grand_parent(), AstNodes::ArrowFunctionExpression(arrow) if arrow.expression)))
35+
|| (matches!(self.parent, AstNodes::ExpressionStatement(statement)
36+
if !statement.is_arrow_function_body()))
3737
{
3838
write!(f, [indent(&rest)])
3939
} else {

0 commit comments

Comments
 (0)