Skip to content

Commit

Permalink
implement removing excessive trivia using swift-syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Jul 17, 2024
1 parent a6143c5 commit 0a9b591
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 27 deletions.
17 changes: 11 additions & 6 deletions Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ extension EnumeratorMacroType: MemberMacro {
guard let rendered else {
return nil
}
let noEmptyLines = rendered.split(separator: "\n").joined(separator: "\n")
return (noEmptyLines, syntax)
return (rendered, syntax)
} catch {
let message: MacroError
let errorSyntax: SyntaxProtocol
Expand Down Expand Up @@ -168,11 +167,17 @@ extension EnumeratorMacroType: MemberMacro {
($0, result.codeSyntax)
}
}
let postProcessedSyntaxes = syntaxes.compactMap {
let postProcessedSyntaxes = syntaxes.compactMap {
(syntax, codeSyntax) -> DeclSyntax? in
let postProcessor = PostProcessor()
let newSyntax = postProcessor.rewrite(syntax)
guard let declSyntax = DeclSyntax(newSyntax) else {
var processedSyntax = Syntax(syntax)

let excessiveTriviaRemover = ExcessiveTriviaRemover()
processedSyntax = excessiveTriviaRemover.rewrite(processedSyntax)

let switchRewriter = SwitchRewriter()
processedSyntax = switchRewriter.rewrite(processedSyntax)

guard let declSyntax = DeclSyntax(processedSyntax) else {
context.diagnose(
Diagnostic(
node: codeSyntax,
Expand Down
45 changes: 45 additions & 0 deletions Sources/EnumeratorMacroImpl/ExcessiveTriviaRemover.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import SwiftSyntax

final class ExcessiveTriviaRemover: SyntaxRewriter {
/// Remove empty lines if there are more than 1 lines stacked together.
override func visitAny(_ node: Syntax) -> Syntax? {
var node = node

var modifiedLeadingTrivia = false
let newLeadingTrivia = node.leadingTrivia.pieces.map { piece in
if case let .newlines(count) = piece,
count > 1 {
modifiedLeadingTrivia = true
return TriviaPiece.newlines(1)
} else {
return piece
}
}
if modifiedLeadingTrivia {
node = node.with(
\.leadingTrivia,
Trivia(pieces: newLeadingTrivia)
)
}

var modifiedTrailingTrivia = false
let newTrailingTrivia = node.trailingTrivia.pieces.map { piece in
if case let .newlines(count) = piece,
count > 1 {
modifiedTrailingTrivia = true
return TriviaPiece.newlines(1)
} else {
return piece
}
}
if modifiedTrailingTrivia {
node = node.with(
\.trailingTrivia,
Trivia(pieces: newTrailingTrivia)
)
}

let modified = modifiedLeadingTrivia || modifiedTrailingTrivia
return modified ? self.rewrite(node) : nil
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import SwiftSyntax

final class PostProcessor: SyntaxRewriter {
final class SwitchRewriter: SyntaxRewriter {
override func visit(_ node: SwitchCaseSyntax) -> SwitchCaseSyntax {
self.removeUnusedLet(
self.removeUnusedArguments(
Expand Down
80 changes: 60 additions & 20 deletions Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import SwiftSyntaxMacrosTestSupport
import XCTest

final class EnumeratorMacroTests: XCTestCase {
func testCreatesCaseName() throws {
func testCreatesCaseName() {
assertMacroExpansion(
#"""
@Enumerator(
Expand All @@ -15,9 +15,6 @@ final class EnumeratorMacroTests: XCTestCase {
{{#cases}}
case .{{name}}:
"{{name}}"
{{/cases}}
}
}
Expand Down Expand Up @@ -47,7 +44,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testCreatesACopyOfSelf() throws {
func testCreatesACopyOfSelf() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -80,7 +77,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testCreatesDeclarationsForCaseChecking() throws {
func testCreatesDeclarationsForCaseChecking() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -139,7 +136,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testCreatesSubtypeWithMultiMacroArguments() throws {
func testCreatesSubtypeWithMultiMacroArguments() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -193,7 +190,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testCreatesGetCaseValueFunctions() throws {
func testCreatesGetCaseValueFunctions() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -245,7 +242,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testProperlyReadsComments() throws {
func testProperlyReadsComments() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -298,7 +295,50 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesNotAnEnum() throws {
func removesExcessiveTrivia() {
assertMacroExpansion(
#"""
@Enumerator(
"""
var caseName: String {
switch self {
{{#cases}}
case .{{name}}:
"{{name}}"
{{/cases}}
}
}
"""
)
enum TestEnum {
case a
case b
}
"""#,
expandedSource: #"""
enum TestEnum {
case a
case b
var caseName: String {
switch self {
case .a:
"a"
case .b:
"b"
}
}
}
"""#,
macros: EnumeratorMacroEntryPoint.macros
)
}


func testDiagnosesNotAnEnum() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -333,7 +373,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesNoArguments() throws {
func testDiagnosesNoArguments() {
assertMacroExpansion(
#"""
@Enumerator
Expand Down Expand Up @@ -362,7 +402,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesEmptyArguments() throws {
func testDiagnosesEmptyArguments() {
assertMacroExpansion(
#"""
@Enumerator
Expand Down Expand Up @@ -391,7 +431,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesUnacceptableArguments() throws {
func testDiagnosesUnacceptableArguments() {
assertMacroExpansion(
#"""
@Enumerator(myVariable)
Expand Down Expand Up @@ -420,7 +460,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesStringInterpolationInMustacheTemplate() throws {
func testDiagnosesStringInterpolationInMustacheTemplate() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -474,7 +514,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesBadMustacheTemplate() throws {
func testDiagnosesBadMustacheTemplate() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -513,7 +553,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testDiagnosesErroneousSwiftCode() throws {
func testDiagnosesErroneousSwiftCode() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -640,7 +680,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testRemovesUnusedLetInSwitchStatements() throws {
func testRemovesUnusedLetInSwitchStatements() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -679,7 +719,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testRemovesArgumentInSwitchStatements() throws {
func testRemovesArgumentInSwitchStatements() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -718,7 +758,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testRemovesArgumentInSwitchStatementsWithMultipleArgumentsWhereOneArgIsUsed() throws {
func testRemovesArgumentInSwitchStatementsWithMultipleArgumentsWhereOneArgIsUsed() {
assertMacroExpansion(
#"""
@Enumerator("""
Expand Down Expand Up @@ -757,7 +797,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

// func testAppliesFixIts() throws {
// func testAppliesFixIts() {
// assertMacroExpansion(
// #"""
// @Enumerator("""
Expand Down

0 comments on commit 0a9b591

Please sign in to comment.