diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 63c63e9342674..d759acdd28d48 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -44,6 +44,7 @@ class FuncDecl; class AbstractFunctionDecl; class Pattern; class PatternBindingDecl; +class ReturnStmt; class VarDecl; class CaseStmt; class DoCatchStmt; @@ -244,6 +245,12 @@ class BraceStmt final : public Stmt, /// statement, `false` otherwise. bool hasExplicitReturnStmt(ASTContext &ctx) const; + /// Finds occurrences of explicit `return` statements within the brace + /// statement. + /// \param results An out container to which the results are added. + void getExplicitReturnStmts(ASTContext &ctx, + SmallVectorImpl &results) const; + static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Brace; } }; diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index ea8e949e6a081..858533b09289a 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -317,11 +317,80 @@ Stmt *BraceStmt::getSingleActiveStatement() const { return getSingleActiveElement().dyn_cast(); } +/// Walks the given brace statement and calls the given function reference on +/// every occurrence of an explicit `return` statement. +/// +/// \param callback A function reference that takes a `return` statement and +/// returns a boolean value indicating whether to abort the walk. +/// +/// \returns `true` if the walk was aborted, `false` otherwise. +static bool walkExplicitReturnStmts(const BraceStmt *BS, + function_ref callback) { + class Walker : public ASTWalker { + function_ref callback; + + public: + Walker(decltype(Walker::callback) callback) : callback(callback) {} + + MacroWalking getMacroWalkingBehavior() const override { + return MacroWalking::Arguments; + } + + PreWalkResult walkToExprPre(Expr *E) override { + return Action::SkipNode(E); + } + + PreWalkResult walkToStmtPre(Stmt *S) override { + if (S->isImplicit()) { + return Action::SkipNode(S); + } + + auto *returnStmt = dyn_cast(S); + if (!returnStmt) { + return Action::Continue(S); + } + + if (callback(returnStmt)) { + return Action::Stop(); + } + + // Skip children & post walk and continue. + return Action::SkipNode(S); + } + + /// Ignore patterns. + PreWalkResult walkToPatternPre(Pattern *pat) override { + return Action::SkipNode(pat); + } + }; + + Walker walker(callback); + + return const_cast(BS)->walk(walker) == nullptr; +} + +bool BraceHasExplicitReturnStmtRequest::evaluate(Evaluator &evaluator, + const BraceStmt *BS) const { + return walkExplicitReturnStmts(BS, [](ReturnStmt *) { return true; }); +} + bool BraceStmt::hasExplicitReturnStmt(ASTContext &ctx) const { return evaluateOrDefault(ctx.evaluator, BraceHasExplicitReturnStmtRequest{this}, false); } +void BraceStmt::getExplicitReturnStmts( + ASTContext &ctx, SmallVectorImpl &results) const { + if (!hasExplicitReturnStmt(ctx)) { + return; + } + + walkExplicitReturnStmts(this, [&results](ReturnStmt *RS) { + results.push_back(RS); + return false; + }); +} + IsSingleValueStmtResult Stmt::mayProduceSingleValue(ASTContext &ctx) const { return evaluateOrDefault(ctx.evaluator, IsSingleValueStmtRequest{this, &ctx}, IsSingleValueStmtResult::circularReference()); diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index 65e2b3353b55d..266e620b30072 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -918,7 +918,9 @@ std::optional TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) { // First look for any return statements, and bail if we have any. auto &ctx = func->getASTContext(); - if (auto returnStmts = findReturnStatements(func); !returnStmts.empty()) { + if (SmallVector returnStmts; + func->getBody()->getExplicitReturnStmts(ctx, returnStmts), + !returnStmts.empty()) { // One or more explicit 'return' statements were encountered, which // disables the result builder transform. Warn when we do this. ctx.Diags.diagnose( @@ -1222,79 +1224,6 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType, return getTypeMatchSuccess(); } -/// Walks the given brace statement and calls the given function reference on -/// every occurrence of an explicit `return` statement. -/// -/// \param callback A function reference that takes a `return` statement and -/// returns a boolean value indicating whether to abort the walk. -/// -/// \returns `true` if the walk was aborted, `false` otherwise. -static bool walkExplicitReturnStmts(const BraceStmt *BS, - function_ref callback) { - class Walker : public ASTWalker { - function_ref callback; - - public: - Walker(decltype(Walker::callback) callback) : callback(callback) {} - - MacroWalking getMacroWalkingBehavior() const override { - return MacroWalking::Arguments; - } - - PreWalkResult walkToExprPre(Expr *E) override { - return Action::SkipNode(E); - } - - PreWalkResult walkToStmtPre(Stmt *S) override { - if (S->isImplicit()) { - return Action::SkipNode(S); - } - - auto *returnStmt = dyn_cast(S); - if (!returnStmt) { - return Action::Continue(S); - } - - if (callback(returnStmt)) { - return Action::Stop(); - } - - // Skip children & post walk and continue. - return Action::SkipNode(S); - } - - /// Ignore patterns. - PreWalkResult walkToPatternPre(Pattern *pat) override { - return Action::SkipNode(pat); - } - }; - - Walker walker(callback); - - return const_cast(BS)->walk(walker) == nullptr; -} - -bool BraceHasExplicitReturnStmtRequest::evaluate(Evaluator &evaluator, - const BraceStmt *BS) const { - return walkExplicitReturnStmts(BS, [](ReturnStmt *) { return true; }); -} - -std::vector TypeChecker::findReturnStatements(AnyFunctionRef fn) { - if (!fn.getBody()->hasExplicitReturnStmt( - fn.getAsDeclContext()->getASTContext())) { - return std::vector(); - } - - std::vector results; - - walkExplicitReturnStmts(fn.getBody(), [&results](ReturnStmt *RS) { - results.push_back(RS); - return false; - }); - - return results; -} - ResultBuilderOpSupport TypeChecker::checkBuilderOpSupport( Type builderType, DeclContext *dc, Identifier fnName, ArrayRef argLabels, SmallVectorImpl *allResults) { diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp index 28d8b59f08d67..b0d6c93825cdd 100644 --- a/lib/Sema/CSDiagnostics.cpp +++ b/lib/Sema/CSDiagnostics.cpp @@ -8860,7 +8860,8 @@ bool ReferenceToInvalidDeclaration::diagnoseAsError() { bool InvalidReturnInResultBuilderBody::diagnoseAsError() { auto *closure = castToExpr(getAnchor()); - auto returnStmts = TypeChecker::findReturnStatements(closure); + SmallVector returnStmts; + closure->getBody()->getExplicitReturnStmts(getASTContext(), returnStmts); assert(!returnStmts.empty()); auto loc = returnStmts.front()->getReturnLoc(); diff --git a/lib/Sema/TypeCheckRequestFunctions.cpp b/lib/Sema/TypeCheckRequestFunctions.cpp index 9edbf83e080ef..9f75a9d1587cb 100644 --- a/lib/Sema/TypeCheckRequestFunctions.cpp +++ b/lib/Sema/TypeCheckRequestFunctions.cpp @@ -236,7 +236,7 @@ static Type inferResultBuilderType(ValueDecl *decl) { // Check whether there are any return statements in the function's body. // If there are, the result builder transform will be disabled, // so don't infer a result builder. - if (!TypeChecker::findReturnStatements(funcDecl).empty()) + if (funcDecl->getBody()->hasExplicitReturnStmt(dc->getASTContext())) return Type(); // Find all of the potentially inferred result builder types. diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 6436b19b399a4..8fc457227bfcc 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -468,8 +468,7 @@ void typeCheckASTNode(ASTNode &node, DeclContext *DC, std::optional applyResultBuilderBodyTransform(FuncDecl *func, Type builderType); -/// Find the return statements within the body of the given function. -std::vector findReturnStatements(AnyFunctionRef fn); +bool typeCheckClosureBody(ClosureExpr *closure); bool typeCheckTapBody(TapExpr *expr, DeclContext *DC);