Skip to content

Commit

Permalink
detect circular macro expansion
Browse files Browse the repository at this point in the history
- `MacroApplication` now detects any freestanding macro that appears on an expansion path more than once and throws `MacroExpansionError.circularExpansion`
- added a test case in `ExpressionMacroTests`
  • Loading branch information
AppAppWorks committed Aug 7, 2024
1 parent 24a2501 commit 394075d
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 11 deletions.
4 changes: 4 additions & 0 deletions Sources/SwiftSyntaxMacroExpansion/MacroExpansion.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ enum MacroExpansionError: Error, CustomStringConvertible {
case noFreestandingMacroRoles(Macro.Type)
case moreThanOneBodyMacro
case preambleWithoutBody
case circularExpansion(Macro.Type, any FreestandingMacroExpansionSyntax)

var description: String {
switch self {
Expand Down Expand Up @@ -92,6 +93,9 @@ enum MacroExpansionError: Error, CustomStringConvertible {

case .preambleWithoutBody:
return "preamble macro cannot be applied to a function with no body"

case .circularExpansion(let type, let syntax):
return "circular expansion detected: '\(syntax)' with macro implementation type '\(type)'"
}
}
}
Expand Down
64 changes: 53 additions & 11 deletions Sources/SwiftSyntaxMacroExpansion/MacroSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,9 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
/// added to top-level 'CodeBlockItemList'.
var extensions: [CodeBlockItemSyntax] = []

/// Stores the types of the freestanding macros that are currently expanding.
var expandingFreestandingMacros: [any Macro.Type] = []

init(
macroSystem: MacroSystem,
contextGenerator: @escaping (Syntax) -> Context,
Expand All @@ -687,13 +690,20 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
return nil
}

let macroCount = expandingFreestandingMacros.count
defer {
expandingFreestandingMacros.removeLast(expandingFreestandingMacros.count - macroCount)
}

// Expand 'MacroExpansionExpr'.
// Note that 'MacroExpansionExpr'/'MacroExpansionExprDecl' at code item
// position are handled by 'visit(_:CodeBlockItemListSyntax)'.
// Only expression expansions inside other syntax nodes is handled here.
switch expandExpr(node: node) {
case .success(let expanded):
return Syntax(visit(expanded))
case .success(let expansion):
return expansion.map { expanded in
Syntax(visit(expanded))
}
case .failure:
return Syntax(node)
case .notAMacro:
Expand Down Expand Up @@ -794,9 +804,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
func addResult(_ node: CodeBlockItemSyntax) {
// Expand freestanding macro.
switch expandCodeBlockItem(node: node) {
case .success(let expanded):
for item in expanded {
addResult(item)
case .success(let expansion):
expansion.map { expanded in
for item in expanded {
addResult(item)
}
}
return
case .failure:
Expand Down Expand Up @@ -839,9 +851,11 @@ private class MacroApplication<Context: MacroExpansionContext>: SyntaxRewriter {
func addResult(_ node: MemberBlockItemSyntax) {
// Expand freestanding macro.
switch expandMemberDecl(node: node) {
case .success(let expanded):
for item in expanded {
addResult(item)
case .success(let expansion):
expansion.map { expanded in
for item in expanded {
addResult(item)
}
}
return
case .failure:
Expand Down Expand Up @@ -1204,10 +1218,33 @@ extension MacroApplication {

// MARK: Freestanding macro expansion

protocol MacroExpansion<ResultType> {
associatedtype ResultType
func map<T>(_ transform: (ResultType) throws -> T) rethrows -> T
}

extension MacroApplication {
struct StackMacroExpansion<ResultType>: MacroExpansion {
let expandedNode: ResultType
unowned let macroApplication: MacroApplication

init(expandedNode: ResultType, macro: any Macro.Type, macroApplication: MacroApplication) {
self.expandedNode = expandedNode
self.macroApplication = macroApplication
macroApplication.expandingFreestandingMacros.append(macro)
}

func map<T>(_ transform: (ResultType) throws -> T) rethrows -> T {
defer {
macroApplication.expandingFreestandingMacros.removeLast()
}
return try transform(expandedNode)
}
}

enum MacroExpansionResult<ResultType> {
/// Expansion of the macro succeeded.
case success(ResultType)
case success(any MacroExpansion<ResultType>)

/// Macro system found the macro to expand but running the expansion threw
/// an error and thus no expansion result exists.
Expand All @@ -1219,16 +1256,21 @@ extension MacroApplication {

private func expandFreestandingMacro<ExpandedMacroType: SyntaxProtocol>(
_ node: (any FreestandingMacroExpansionSyntax)?,
expandMacro: (_ macro: Macro.Type, _ node: any FreestandingMacroExpansionSyntax) throws -> ExpandedMacroType?
expandMacro: (_ macro: any Macro.Type, _ node: any FreestandingMacroExpansionSyntax) throws -> ExpandedMacroType?
) -> MacroExpansionResult<ExpandedMacroType> {
guard let node,
let macro = macroSystem.lookup(node.macroName.text)?.type
else {
return .notAMacro
}

do {
guard expandingFreestandingMacros.allSatisfy({ $0 != macro }) else {
throw MacroExpansionError.circularExpansion(macro, node)
}

if let expanded = try expandMacro(macro, node) {
return .success(expanded)
return .success(StackMacroExpansion(expandedNode: expanded, macro: macro, macroApplication: self))
} else {
return .failure
}
Expand Down
67 changes: 67 additions & 0 deletions Tests/SwiftSyntaxMacroExpansionTest/ExpressionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,46 @@ fileprivate struct StringifyMacro: ExpressionMacro {
}
}

private struct InfiniteRecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
if let i = node.arguments.first?.expression.as(IntegerLiteralExprSyntax.self)?.representedLiteralValue {
return "\(raw: i) + #infiniteRecursion(i: \(raw: i + 1))"
} else {
return "#nested1"
}
}
}

private struct Nested1RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"(#nested2, #nested3, #infiniteRecursion(i: 1), #infiniteRecursion)"
}
}

private struct Nested2RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"(#nested3, #nested3)"
}
}

private struct Nested3RecursionMacro: ExpressionMacro {
static func expansion(
of node: some FreestandingMacroExpansionSyntax,
in context: some MacroExpansionContext
) throws -> ExprSyntax {
"0"
}
}

final class ExpressionMacroTests: XCTestCase {
private let indentationWidth: Trivia = .spaces(2)

Expand Down Expand Up @@ -292,4 +332,31 @@ final class ExpressionMacroTests: XCTestCase {
macros: ["test": DiagnoseFirstArgument.self]
)
}

func testDetectCircularExpansion() {
assertMacroExpansion(
"#nested1",
expandedSource: "((0, 0), 0, 1 + #infiniteRecursion(i: 2), #nested1)",
diagnostics: [
DiagnosticSpec(
message:
"circular expansion detected: '#infiniteRecursion(i: 2)' with macro implementation type 'InfiniteRecursionMacro'",
line: 1,
column: 5
),
DiagnosticSpec(
message:
"circular expansion detected: '#nested1' with macro implementation type 'Nested1RecursionMacro'",
line: 1,
column: 1
),
],
macros: [
"nested1": Nested1RecursionMacro.self,
"nested2": Nested2RecursionMacro.self,
"nested3": Nested3RecursionMacro.self,
"infiniteRecursion": InfiniteRecursionMacro.self,
]
)
}
}

0 comments on commit 394075d

Please sign in to comment.