Skip to content

Commit

Permalink
[NFC] AST, Sema: Make TypeChecker::findReturnStatements a member of…
Browse files Browse the repository at this point in the history
… `BraceStmt`

Also rename it to `getExplicitReturnStmts` for clarity and have it
take a `SmallVector` out parameter instead as a small optimization and
to discourage use of this new method as an alternative to
`BraceStmt::hasExplicitReturnStatement`.
  • Loading branch information
AnthonyLatsis committed Sep 28, 2024
1 parent fdb9f22 commit 02665c5
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 78 deletions.
7 changes: 7 additions & 0 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class FuncDecl;
class AbstractFunctionDecl;
class Pattern;
class PatternBindingDecl;
class ReturnStmt;
class VarDecl;
class CaseStmt;
class DoCatchStmt;
Expand Down Expand Up @@ -244,6 +245,12 @@ class BraceStmt final : public Stmt,
/// statement, `false` otherwise.
bool hasExplicitReturnStmt(ASTContext &ctx) const;

/// Finds occurrences of explicit `return` statements within the brace
/// statement.
/// \param results An out container to which the results are added.
void getExplicitReturnStmts(ASTContext &ctx,
SmallVectorImpl<ReturnStmt *> &results) const;

static bool classof(const Stmt *S) { return S->getKind() == StmtKind::Brace; }
};

Expand Down
69 changes: 69 additions & 0 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,80 @@ Stmt *BraceStmt::getSingleActiveStatement() const {
return getSingleActiveElement().dyn_cast<Stmt *>();
}

/// Walks the given brace statement and calls the given function reference on
/// every occurrence of an explicit `return` statement.
///
/// \param callback A function reference that takes a `return` statement and
/// returns a boolean value indicating whether to abort the walk.
///
/// \returns `true` if the walk was aborted, `false` otherwise.
static bool walkExplicitReturnStmts(const BraceStmt *BS,
function_ref<bool(ReturnStmt *)> callback) {
class Walker : public ASTWalker {
function_ref<bool(ReturnStmt *)> callback;

public:
Walker(decltype(Walker::callback) callback) : callback(callback) {}

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}

PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
return Action::SkipNode(E);
}

PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
if (S->isImplicit()) {
return Action::SkipNode(S);
}

auto *returnStmt = dyn_cast<ReturnStmt>(S);
if (!returnStmt) {
return Action::Continue(S);
}

if (callback(returnStmt)) {
return Action::Stop();
}

// Skip children & post walk and continue.
return Action::SkipNode(S);
}

/// Ignore patterns.
PreWalkResult<Pattern *> walkToPatternPre(Pattern *pat) override {
return Action::SkipNode(pat);
}
};

Walker walker(callback);

return const_cast<BraceStmt *>(BS)->walk(walker) == nullptr;
}

bool BraceHasExplicitReturnStmtRequest::evaluate(Evaluator &evaluator,
const BraceStmt *BS) const {
return walkExplicitReturnStmts(BS, [](ReturnStmt *) { return true; });
}

bool BraceStmt::hasExplicitReturnStmt(ASTContext &ctx) const {
return evaluateOrDefault(ctx.evaluator,
BraceHasExplicitReturnStmtRequest{this}, false);
}

void BraceStmt::getExplicitReturnStmts(
ASTContext &ctx, SmallVectorImpl<ReturnStmt *> &results) const {
if (!hasExplicitReturnStmt(ctx)) {
return;
}

walkExplicitReturnStmts(this, [&results](ReturnStmt *RS) {
results.push_back(RS);
return false;
});
}

IsSingleValueStmtResult Stmt::mayProduceSingleValue(ASTContext &ctx) const {
return evaluateOrDefault(ctx.evaluator, IsSingleValueStmtRequest{this, &ctx},
IsSingleValueStmtResult::circularReference());
Expand Down
77 changes: 3 additions & 74 deletions lib/Sema/BuilderTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,9 @@ std::optional<BraceStmt *>
TypeChecker::applyResultBuilderBodyTransform(FuncDecl *func, Type builderType) {
// First look for any return statements, and bail if we have any.
auto &ctx = func->getASTContext();
if (auto returnStmts = findReturnStatements(func); !returnStmts.empty()) {
if (SmallVector<ReturnStmt *> returnStmts;
func->getBody()->getExplicitReturnStmts(ctx, returnStmts),
!returnStmts.empty()) {
// One or more explicit 'return' statements were encountered, which
// disables the result builder transform. Warn when we do this.
ctx.Diags.diagnose(
Expand Down Expand Up @@ -1222,79 +1224,6 @@ ConstraintSystem::matchResultBuilder(AnyFunctionRef fn, Type builderType,
return getTypeMatchSuccess();
}

/// Walks the given brace statement and calls the given function reference on
/// every occurrence of an explicit `return` statement.
///
/// \param callback A function reference that takes a `return` statement and
/// returns a boolean value indicating whether to abort the walk.
///
/// \returns `true` if the walk was aborted, `false` otherwise.
static bool walkExplicitReturnStmts(const BraceStmt *BS,
function_ref<bool(ReturnStmt *)> callback) {
class Walker : public ASTWalker {
function_ref<bool(ReturnStmt *)> callback;

public:
Walker(decltype(Walker::callback) callback) : callback(callback) {}

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Arguments;
}

PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
return Action::SkipNode(E);
}

PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
if (S->isImplicit()) {
return Action::SkipNode(S);
}

auto *returnStmt = dyn_cast<ReturnStmt>(S);
if (!returnStmt) {
return Action::Continue(S);
}

if (callback(returnStmt)) {
return Action::Stop();
}

// Skip children & post walk and continue.
return Action::SkipNode(S);
}

/// Ignore patterns.
PreWalkResult<Pattern *> walkToPatternPre(Pattern *pat) override {
return Action::SkipNode(pat);
}
};

Walker walker(callback);

return const_cast<BraceStmt *>(BS)->walk(walker) == nullptr;
}

bool BraceHasExplicitReturnStmtRequest::evaluate(Evaluator &evaluator,
const BraceStmt *BS) const {
return walkExplicitReturnStmts(BS, [](ReturnStmt *) { return true; });
}

std::vector<ReturnStmt *> TypeChecker::findReturnStatements(AnyFunctionRef fn) {
if (!fn.getBody()->hasExplicitReturnStmt(
fn.getAsDeclContext()->getASTContext())) {
return std::vector<ReturnStmt *>();
}

std::vector<ReturnStmt *> results;

walkExplicitReturnStmts(fn.getBody(), [&results](ReturnStmt *RS) {
results.push_back(RS);
return false;
});

return results;
}

ResultBuilderOpSupport TypeChecker::checkBuilderOpSupport(
Type builderType, DeclContext *dc, Identifier fnName,
ArrayRef<Identifier> argLabels, SmallVectorImpl<ValueDecl *> *allResults) {
Expand Down
3 changes: 2 additions & 1 deletion lib/Sema/CSDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8860,7 +8860,8 @@ bool ReferenceToInvalidDeclaration::diagnoseAsError() {
bool InvalidReturnInResultBuilderBody::diagnoseAsError() {
auto *closure = castToExpr<ClosureExpr>(getAnchor());

auto returnStmts = TypeChecker::findReturnStatements(closure);
SmallVector<ReturnStmt *> returnStmts;
closure->getBody()->getExplicitReturnStmts(getASTContext(), returnStmts);
assert(!returnStmts.empty());

auto loc = returnStmts.front()->getReturnLoc();
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckRequestFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ static Type inferResultBuilderType(ValueDecl *decl) {
// Check whether there are any return statements in the function's body.
// If there are, the result builder transform will be disabled,
// so don't infer a result builder.
if (!TypeChecker::findReturnStatements(funcDecl).empty())
if (funcDecl->getBody()->hasExplicitReturnStmt(dc->getASTContext()))
return Type();

// Find all of the potentially inferred result builder types.
Expand Down
3 changes: 1 addition & 2 deletions lib/Sema/TypeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,7 @@ void typeCheckASTNode(ASTNode &node, DeclContext *DC,
std::optional<BraceStmt *> applyResultBuilderBodyTransform(FuncDecl *func,
Type builderType);

/// Find the return statements within the body of the given function.
std::vector<ReturnStmt *> findReturnStatements(AnyFunctionRef fn);
bool typeCheckClosureBody(ClosureExpr *closure);

bool typeCheckTapBody(TapExpr *expr, DeclContext *DC);

Expand Down

0 comments on commit 02665c5

Please sign in to comment.