diff --git a/Sources/SwiftFormat/Pipelines+Generated.swift b/Sources/SwiftFormat/Pipelines+Generated.swift index 6701a0a96..80c2e1ad1 100644 --- a/Sources/SwiftFormat/Pipelines+Generated.swift +++ b/Sources/SwiftFormat/Pipelines+Generated.swift @@ -161,7 +161,7 @@ class LintPipeline: SyntaxVisitor { return .visitChildren } - override func visit(_ node: IfStmtSyntax) -> SyntaxVisitorContinueKind { + override func visit(_ node: IfExprSyntax) -> SyntaxVisitorContinueKind { visitIfEnabled(NoParensAroundConditions.visit, for: node) return .visitChildren } @@ -271,7 +271,7 @@ class LintPipeline: SyntaxVisitor { return .visitChildren } - override func visit(_ node: SwitchStmtSyntax) -> SyntaxVisitorContinueKind { + override func visit(_ node: SwitchExprSyntax) -> SyntaxVisitorContinueKind { visitIfEnabled(NoParensAroundConditions.visit, for: node) return .visitChildren } diff --git a/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift b/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift index fa592e4cf..a02dc66c1 100644 --- a/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift +++ b/Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift @@ -480,7 +480,7 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { return .visitChildren } - override func visit(_ node: IfStmtSyntax) -> SyntaxVisitorContinueKind { + override func visit(_ node: IfExprSyntax) -> SyntaxVisitorContinueKind { // There may be a consistent breaking group around this node, see `CodeBlockItemSyntax`. This // group is necessary so that breaks around and inside of the conditions aren't forced to break // when the if-stmt spans multiple lines. @@ -515,7 +515,7 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { // any newlines between `else` and the open brace or a following `if`. if let tokenAfterElse = elseKeyword.nextToken(viewMode: .all), tokenAfterElse.leadingTrivia.hasLineComment { after(node.elseKeyword, tokens: .break(.same, size: 1)) - } else if let elseBody = node.elseBody, elseBody.is(IfStmtSyntax.self) { + } else if let elseBody = node.elseBody, elseBody.is(IfExprSyntax.self) { after(node.elseKeyword, tokens: .space) } } @@ -673,7 +673,7 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { return .visitChildren } - override func visit(_ node: SwitchStmtSyntax) -> SyntaxVisitorContinueKind { + override func visit(_ node: SwitchExprSyntax) -> SyntaxVisitorContinueKind { before(node.switchKeyword, tokens: .open) after(node.switchKeyword, tokens: .space) before(node.leftBrace, tokens: .break(.reset)) @@ -1457,7 +1457,8 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor { // This group applies to a top-level if-stmt so that all of the bodies will have the same // breaking behavior. - if let ifStmt = node.item.as(IfStmtSyntax.self) { + if let exprStmt = node.item.as(ExpressionStmtSyntax.self), + let ifStmt = exprStmt.expression.as(IfExprSyntax.self) { before(ifStmt.conditions.firstToken, tokens: .open(.consistent)) after(ifStmt.lastToken, tokens: .close) } diff --git a/Sources/SwiftFormatRules/NoParensAroundConditions.swift b/Sources/SwiftFormatRules/NoParensAroundConditions.swift index a2aa8df1e..92b43d7ad 100644 --- a/Sources/SwiftFormatRules/NoParensAroundConditions.swift +++ b/Sources/SwiftFormatRules/NoParensAroundConditions.swift @@ -52,7 +52,7 @@ public final class NoParensAroundConditions: SyntaxFormatRule { ) } - public override func visit(_ node: IfStmtSyntax) -> StmtSyntax { + public override func visit(_ node: IfExprSyntax) -> ExprSyntax { let conditions = visit(node.conditions) var result = node.withIfKeyword(node.ifKeyword.withOneTrailingSpace()) .withConditions(conditions) @@ -60,7 +60,7 @@ public final class NoParensAroundConditions: SyntaxFormatRule { if let elseBody = node.elseBody { result = result.withElseBody(visit(elseBody)) } - return StmtSyntax(result) + return ExprSyntax(result) } public override func visit(_ node: ConditionElementSyntax) -> ConditionElementSyntax { @@ -72,14 +72,14 @@ public final class NoParensAroundConditions: SyntaxFormatRule { return node.withCondition(.expression(extractExpr(tup))) } - /// FIXME(hbh): Parsing for SwitchStmtSyntax is not implemented. - public override func visit(_ node: SwitchStmtSyntax) -> StmtSyntax { + /// FIXME(hbh): Parsing for SwitchExprSyntax is not implemented. + public override func visit(_ node: SwitchExprSyntax) -> ExprSyntax { guard let tup = node.expression.as(TupleExprSyntax.self), tup.elementList.firstAndOnly != nil else { return super.visit(node) } - return StmtSyntax( + return ExprSyntax( node.withExpression(extractExpr(tup)).withCases(visit(node.cases))) } diff --git a/Sources/SwiftFormatRules/UseEarlyExits.swift b/Sources/SwiftFormatRules/UseEarlyExits.swift index c2c0f62ac..c748734f4 100644 --- a/Sources/SwiftFormatRules/UseEarlyExits.swift +++ b/Sources/SwiftFormatRules/UseEarlyExits.swift @@ -54,11 +54,12 @@ public final class UseEarlyExits: SyntaxFormatRule { let result = CodeBlockItemListSyntax( codeBlockItems.flatMap { (codeBlockItem: CodeBlockItemSyntax) -> [CodeBlockItemSyntax] in - // The `elseBody` of an `IfStmtSyntax` will be a `CodeBlockSyntax` if it's an `else` block, - // or another `IfStmtSyntax` if it's an `else if` block. We only want to handle the former. - guard let ifStatement = codeBlockItem.item.as(IfStmtSyntax.self), - let elseBody = ifStatement.elseBody?.as(CodeBlockSyntax.self), - codeBlockEndsWithEarlyExit(elseBody) + // The `elseBody` of an `IfExprSyntax` will be a `CodeBlockSyntax` if it's an `else` block, + // or another `IfExprSyntax` if it's an `else if` block. We only want to handle the former. + guard let exprStmt = codeBlockItem.item.as(ExpressionStmtSyntax.self), + let ifStatement = exprStmt.expression.as(IfExprSyntax.self), + let elseBody = ifStatement.elseBody?.as(CodeBlockSyntax.self), + codeBlockEndsWithEarlyExit(elseBody) else { return [codeBlockItem] } diff --git a/Sources/SwiftFormatRules/UseWhereClausesInForLoops.swift b/Sources/SwiftFormatRules/UseWhereClausesInForLoops.swift index bc58af4b8..5916ff86f 100644 --- a/Sources/SwiftFormatRules/UseWhereClausesInForLoops.swift +++ b/Sources/SwiftFormatRules/UseWhereClausesInForLoops.swift @@ -52,22 +52,26 @@ public final class UseWhereClausesInForLoops: SyntaxFormatRule { forInStmt: ForInStmtSyntax ) -> ForInStmtSyntax { switch Syntax(firstStmt).as(SyntaxEnum.self) { - case .ifStmt(let ifStmt) - where ifStmt.conditions.count == 1 - && ifStmt.elseKeyword == nil - && forInStmt.body.statements.count == 1: - // Extract the condition of the IfStmt. - let conditionElement = ifStmt.conditions.first! - guard let condition = conditionElement.condition.as(ExprSyntax.self) else { + case .expressionStmt(let exprStmt): + switch Syntax(exprStmt.expression).as(SyntaxEnum.self) { + case .ifExpr(let ifExpr) + where ifExpr.conditions.count == 1 + && ifExpr.elseKeyword == nil + && forInStmt.body.statements.count == 1: + // Extract the condition of the IfExpr. + let conditionElement = ifExpr.conditions.first! + guard let condition = conditionElement.condition.as(ExprSyntax.self) else { + return forInStmt + } + diagnose(.useWhereInsteadOfIf, on: ifExpr) + return updateWithWhereCondition( + node: forInStmt, + condition: condition, + statements: ifExpr.body.statements + ) + default: return forInStmt } - diagnose(.useWhereInsteadOfIf, on: ifStmt) - return updateWithWhereCondition( - node: forInStmt, - condition: condition, - statements: ifStmt.body.statements - ) - case .guardStmt(let guardStmt) where guardStmt.conditions.count == 1 && guardStmt.body.statements.count == 1 diff --git a/Tests/SwiftFormatPrettyPrintTests/IfStmtTests.swift b/Tests/SwiftFormatPrettyPrintTests/IfStmtTests.swift index 51fad551e..a67dfc754 100644 --- a/Tests/SwiftFormatPrettyPrintTests/IfStmtTests.swift +++ b/Tests/SwiftFormatPrettyPrintTests/IfStmtTests.swift @@ -153,6 +153,74 @@ final class IfStmtTests: PrettyPrintTestCase { assertPrettyPrintEqual(input: input, expected: expected, linelength: 20, configuration: config) } + func testIfExpression1() { + let input = + """ + func foo() -> Int { + if var1 < var2 { + 23 + } + else if d < e { + 24 + } + else { + 0 + } + } + """ + + let expected = + """ + func foo() -> Int { + if var1 < var2 { + 23 + } else if d < e { + 24 + } else { + 0 + } + } + + """ + + assertPrettyPrintEqual(input: input, expected: expected, linelength: 23) + } + + func testIfExpression2() { + let input = + """ + func foo() -> Int { + let x = if var1 < var2 { + 23 + } + else if d < e { + 24 + } + else { + 0 + } + return x + } + """ + + let expected = + """ + func foo() -> Int { + let x = if var1 < var2 { + 23 + } else if d < e { + 24 + } else { + 0 + } + return x + } + + """ + + assertPrettyPrintEqual(input: input, expected: expected, linelength: 26) + } + func testMatchingPatternConditions() { let input = """ diff --git a/Tests/SwiftFormatPrettyPrintTests/SwitchStmtTests.swift b/Tests/SwiftFormatPrettyPrintTests/SwitchStmtTests.swift index 365e8ce36..35726d224 100644 --- a/Tests/SwiftFormatPrettyPrintTests/SwitchStmtTests.swift +++ b/Tests/SwiftFormatPrettyPrintTests/SwitchStmtTests.swift @@ -205,6 +205,77 @@ final class SwitchStmtTests: PrettyPrintTestCase { assertPrettyPrintEqual(input: input, expected: expected, linelength: 45) } + func testSwitchExpression1() { + let input = + """ + func foo() -> Int { + switch value1 + value2 + value3 + value4 { + case "a": + 0 + case "b": + 1 + default: + 2 + } + } + """ + + let expected = + """ + func foo() -> Int { + switch value1 + value2 + value3 + + value4 + { + case "a": + 0 + case "b": + 1 + default: + 2 + } + } + + """ + + assertPrettyPrintEqual(input: input, expected: expected, linelength: 35) + } + + + func testSwitchExpression2() { + let input = + """ + func foo() -> Int { + let x = switch value1 + value2 + value3 + value4 { + case "a": + 0 + case "b": + 1 + default: + 2 + } + return x + } + """ + + let expected = + """ + func foo() -> Int { + let x = switch value1 + value2 + value3 + value4 { + case "a": + 0 + case "b": + 1 + default: + 2 + } + return x + } + + """ + + assertPrettyPrintEqual(input: input, expected: expected, linelength: 52) + } + func testUnknownDefault() { let input = """ diff --git a/Tests/SwiftFormatRulesTests/NoParensAroundConditionsTests.swift b/Tests/SwiftFormatRulesTests/NoParensAroundConditionsTests.swift index 8b0b53493..865082ed3 100644 --- a/Tests/SwiftFormatRulesTests/NoParensAroundConditionsTests.swift +++ b/Tests/SwiftFormatRulesTests/NoParensAroundConditionsTests.swift @@ -135,4 +135,29 @@ final class NoParensAroundConditionsTests: LintOrFormatRuleTestCase { if foo.someCall({ if x {} }) {} """) } + + func testParensAroundIfAndSwitchExprs() { + XCTAssertFormatting( + NoParensAroundConditions.self, + input: """ + let x = if (x) {} + let y = switch (4) { default: break } + func foo() { + return if (x) {} + } + func bar() { + return switch (4) { default: break } + } + """, + expected: """ + let x = if x {} + let y = switch 4 { default: break } + func foo() { + return if x {} + } + func bar() { + return switch 4 { default: break } + } + """) + } }