Skip to content

Commit

Permalink
detect circular macro expansion
Browse files Browse the repository at this point in the history
- `MacroApplication` now detects any freestanding macro that appears on an expansion path more than once and throws `MacroExpansionError.circularExpansion`
- added a test case in `ExpressionMacroTests`
  • Loading branch information
AppAppWorks committed Aug 6, 2024
1 parent 24a2501 commit 2e694a1
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ enum MacroExpansionError: Error, CustomStringConvertible {
case noFreestandingMacroRoles(Macro.Type)
case moreThanOneBodyMacro
case preambleWithoutBody
case circularExpansion(Macro.Type, any FreestandingMacroExpansionSyntax)

var description: String {
switch self {
Expand Down Expand Up @@ -92,6 +93,9 @@ enum MacroExpansionError: Error, CustomStringConvertible {

case .preambleWithoutBody:
return "preamble macro cannot be applied to a function with no body"

case .circularExpansion(let type, let syntax):
return "circular expansion detected: '\(syntax)' with macro implementation type '\(type)'"
}
}
}
Expand Down
24 changes: 24 additions & 0 deletions Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,9 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
/// added to top-level 'CodeBlockItemList'.
var extensions: [CodeBlockItemSyntax] = []

/// Stores the types of the freestanding macros that are currently expanding.
var expandingFreestandingMacros: [any Macro.Type] = []

init(
macroSystem: MacroSystem,
contextGenerator: @escaping (Syntax) -> Context,
Expand All @@ -687,6 +690,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
return nil
}

let macroCount = expandingFreestandingMacros.count
defer {
expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount)
}

// Expand 'MacroExpansionExpr'.
// Note that 'MacroExpansionExpr'/'MacroExpansionExprDecl' at code item
// position are handled by 'visit(_:CodeBlockItemListSyntax)'.
Expand Down Expand Up @@ -792,6 +800,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
override func visit(_ node: CodeBlockItemListSyntax) -> CodeBlockItemListSyntax {
var newItems: [CodeBlockItemSyntax] = []
func addResult(_ node: CodeBlockItemSyntax) {
let macroCount = expandingFreestandingMacros.count
defer {
expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount)
}

// Expand freestanding macro.
switch expandCodeBlockItem(node: node) {
case .success(let expanded):
Expand Down Expand Up @@ -837,6 +850,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
var newItems: [MemberBlockItemSyntax] = []

func addResult(_ node: MemberBlockItemSyntax) {
let macroCount = expandingFreestandingMacros.count
defer {
expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount)
}

// Expand freestanding macro.
switch expandMemberDecl(node: node) {
case .success(let expanded):
Expand Down Expand Up @@ -1226,7 +1244,13 @@ extension MacroApplication {
else {
return .notAMacro
}

do {
guard expandingFreestandingMacros.allSatisfy({ $0 != macro }) else {
throw MacroExpansionError.circularExpansion(macro, node)
}
expandingFreestandingMacros.append(macro)

if let expanded = try expandMacro(macro, node) {
return .success(expanded)
} else {
Expand Down
67 changes: 67 additions & 0 deletions Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,46 @@ fileprivate struct StringifyMacro: ExpressionMacro {
}
}

private struct InfiniteRecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
if let i = node.arguments.first?.expression.as(IntegerLiteralExprSyntax.self)?.representedLiteralValue {
return "\(raw: i) + #infiniteRecursion(i: \(raw: i + 1))"
} else {
return "#nested1"
}
}
}

private struct Nested1RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"(#nested2, #nested3, #infiniteRecursion(i: 1), #infiniteRecursion)"
}
}

private struct Nested2RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"(#nested3, #nested3)"
}
}

private struct Nested3RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"0"
}
}

final class ExpressionMacroTests: XCTestCase {
private let indentationWidth: Trivia = .spaces(2)

Expand Down Expand Up @@ -292,4 +332,31 @@ final class ExpressionMacroTests: XCTestCase {
macros: ["test": DiagnoseFirstArgument.self]
)
}

func testDetectCircularExpansion() {
assertMacroExpansion(
"#nested1",
expandedSource: "((0, 0), 0, 1 + #infiniteRecursion(i: 2), #nested1)",
diagnostics: [
DiagnosticSpec(
message:
"circular expansion detected: '#infiniteRecursion(i: 2)' with macro implementation type 'InfiniteRecursionMacro'",
line: 1,
column: 5
),
DiagnosticSpec(
message:
"circular expansion detected: '#nested1' with macro implementation type 'Nested1RecursionMacro'",
line: 1,
column: 1
),
],
macros: [
"nested1": Nested1RecursionMacro.self,
"nested2": Nested2RecursionMacro.self,
"nested3": Nested3RecursionMacro.self,
"infiniteRecursion": InfiniteRecursionMacro.self,
]
)
}
}

0 comments on commit 2e694a1

Please sign in to comment.