diff --git a/Sources/TestingMacros/Support/ConditionArgumentParsing.swift b/Sources/TestingMacros/Support/ConditionArgumentParsing.swift index 18ee9f445..d59d7600d 100644 --- a/Sources/TestingMacros/Support/ConditionArgumentParsing.swift +++ b/Sources/TestingMacros/Support/ConditionArgumentParsing.swift @@ -242,6 +242,18 @@ private func _parseCondition(from expr: ClosureExprSyntax, for macro: some Frees return Condition(expression: expr) } +/// A class that walks a syntax tree looking for optional chaining expressions +/// such as `a?.b.c`. +private final class _OptionalChainFinder: SyntaxVisitor { + /// Whether or not any optional chaining was found. + var optionalChainFound = false + + override func visit(_ node: OptionalChainingExprSyntax) -> SyntaxVisitorContinueKind { + optionalChainFound = true + return .skipChildren + } +} + /// Extract the underlying expression from an optional-chained expression as /// well as the number of question marks required to reach it. /// @@ -279,15 +291,9 @@ private func _exprFromOptionalChainedExpr(_ expr: some ExprSyntaxProtocol) -> (E // the member accesses in the expression use optional chaining and, if one // does, ensure we preserve optional chaining in the macro expansion. if questionMarkCount == 0 { - func isOptionalChained(_ expr: some ExprSyntaxProtocol) -> Bool { - if expr.is(OptionalChainingExprSyntax.self) { - return true - } else if let memberAccessBaseExpr = expr.as(MemberAccessExprSyntax.self)?.base { - return isOptionalChained(memberAccessBaseExpr) - } - return false - } - if isOptionalChained(originalExpr) { + let optionalChainFinder = _OptionalChainFinder(viewMode: .sourceAccurate) + optionalChainFinder.walk(originalExpr) + if optionalChainFinder.optionalChainFound { questionMarkCount = 1 } } diff --git a/Tests/TestingMacrosTests/ConditionMacroTests.swift b/Tests/TestingMacrosTests/ConditionMacroTests.swift index 16d178e0d..070483a78 100644 --- a/Tests/TestingMacrosTests/ConditionMacroTests.swift +++ b/Tests/TestingMacrosTests/ConditionMacroTests.swift @@ -88,6 +88,8 @@ struct ConditionMacroTests { ##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##, ##"#expect(a?.b.isB)"##: ##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##, + ##"#expect(a?.b().isB)"##: + ##"Testing.__checkPropertyAccess(a?.b().self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b()"), .__fromSyntaxNode("isB")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation.__here()).__expected()"##, ##"#expect(isolation: somewhere) {}"##: ##"Testing.__checkClosureCall(performing: {}, expression: .__fromSyntaxNode("{}"), comments: [], isRequired: false, isolation: somewhere, sourceLocation: Testing.SourceLocation.__here()).__expected()"##, ] @@ -166,6 +168,8 @@ struct ConditionMacroTests { ##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##, ##"#require(a?.b.isB)"##: ##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##, + ##"#require(a?.b().isB)"##: + ##"Testing.__checkPropertyAccess(a?.b().self, getting: { $0?.isB }, expression: .__fromPropertyAccess(.__fromSyntaxNode("a?.b()"), .__fromSyntaxNode("isB")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation.__here()).__required()"##, ##"#require(isolation: somewhere) {}"##: ##"Testing.__checkClosureCall(performing: {}, expression: .__fromSyntaxNode("{}"), comments: [], isRequired: true, isolation: somewhere, sourceLocation: Testing.SourceLocation.__here()).__required()"##, ] diff --git a/Tests/TestingTests/MiscellaneousTests.swift b/Tests/TestingTests/MiscellaneousTests.swift index 5d412f996..b9f274c8e 100644 --- a/Tests/TestingTests/MiscellaneousTests.swift +++ b/Tests/TestingTests/MiscellaneousTests.swift @@ -222,6 +222,11 @@ struct MultiLineSuite { staticMultiLineTestDecl() async {} } +@Test(.hidden) func complexOptionalChainingWithRequire() throws { + let x: String? = nil + _ = try #require(x?[...].last) +} + @Suite("Miscellaneous tests") struct MiscellaneousTests { @Test("Free function's name")