@@ -1010,20 +1010,20 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10101010 Stmt *visitBraceStmt (BraceStmt *BS);
10111011
10121012 Stmt *visitReturnStmt (ReturnStmt *RS) {
1013- auto TheFunc = AnyFunctionRef::fromDeclContext (DC);
1013+ // First, let's do a pre-check, and bail if the return is completely
1014+ // invalid.
1015+ auto &eval = getASTContext ().evaluator ;
1016+ auto *S =
1017+ evaluateOrDefault (eval, PreCheckReturnStmtRequest{RS, DC}, nullptr );
1018+
1019+ // We do a cast here as it may have been turned into a FailStmt. We should
1020+ // return that without doing anything else.
1021+ RS = dyn_cast_or_null<ReturnStmt>(S);
1022+ if (!RS)
1023+ return S;
10141024
1015- if (!TheFunc.has_value ()) {
1016- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1017- diag::return_invalid_outside_func);
1018- return nullptr ;
1019- }
1020-
1021- // If the return is in a defer, then it isn't valid either.
1022- if (isInDefer ()) {
1023- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1024- diag::jump_out_of_defer, " return" );
1025- return nullptr ;
1026- }
1025+ auto TheFunc = AnyFunctionRef::fromDeclContext (DC);
1026+ assert (TheFunc && " Should have bailed from pre-check if this is None" );
10271027
10281028 Type ResultTy = TheFunc->getBodyResultType ();
10291029 if (!ResultTy || ResultTy->hasError ())
@@ -1061,40 +1061,6 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10611061 }
10621062
10631063 Expr *E = RS->getResult ();
1064-
1065- // In an initializer, the only expression allowed is "nil", which indicates
1066- // failure from a failable initializer.
1067- if (auto ctor = dyn_cast_or_null<ConstructorDecl>(
1068- TheFunc->getAbstractFunctionDecl ())) {
1069- // The only valid return expression in an initializer is the literal
1070- // 'nil'.
1071- auto nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr ());
1072- if (!nilExpr) {
1073- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1074- diag::return_init_non_nil)
1075- .highlight (E->getSourceRange ());
1076- RS->setResult (nullptr );
1077- return RS;
1078- }
1079-
1080- // "return nil" is only permitted in a failable initializer.
1081- if (!ctor->isFailable ()) {
1082- getASTContext ().Diags .diagnose (RS->getReturnLoc (),
1083- diag::return_non_failable_init)
1084- .highlight (E->getSourceRange ());
1085- getASTContext ().Diags .diagnose (ctor->getLoc (), diag::make_init_failable,
1086- ctor->getName ())
1087- .fixItInsertAfter (ctor->getLoc (), " ?" );
1088- RS->setResult (nullptr );
1089- return RS;
1090- }
1091-
1092- // Replace the "return nil" with a new 'fail' statement.
1093- return new (getASTContext ()) FailStmt (RS->getReturnLoc (),
1094- nilExpr->getLoc (),
1095- RS->isImplicit ());
1096- }
1097-
10981064 TypeCheckExprOptions options = {};
10991065
11001066 if (LeaveBraceStmtBodyUnchecked) {
@@ -1547,6 +1513,62 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
15471513};
15481514} // end anonymous namespace
15491515
1516+ Stmt *PreCheckReturnStmtRequest::evaluate (Evaluator &evaluator, ReturnStmt *RS,
1517+ DeclContext *DC) const {
1518+ auto &ctx = DC->getASTContext ();
1519+ auto fn = AnyFunctionRef::fromDeclContext (DC);
1520+
1521+ // Not valid outside of a function.
1522+ if (!fn) {
1523+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_invalid_outside_func);
1524+ return nullptr ;
1525+ }
1526+
1527+ // If the return is in a defer, then it isn't valid either.
1528+ if (isDefer (DC)) {
1529+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::jump_out_of_defer, " return" );
1530+ return nullptr ;
1531+ }
1532+
1533+ // The rest of the checks only concern return statements with results.
1534+ if (!RS->hasResult ())
1535+ return RS;
1536+
1537+ auto *E = RS->getResult ();
1538+
1539+ // In an initializer, the only expression allowed is "nil", which indicates
1540+ // failure from a failable initializer.
1541+ if (auto *ctor =
1542+ dyn_cast_or_null<ConstructorDecl>(fn->getAbstractFunctionDecl ())) {
1543+
1544+ // The only valid return expression in an initializer is the literal
1545+ // 'nil'.
1546+ auto *nilExpr = dyn_cast<NilLiteralExpr>(E->getSemanticsProvidingExpr ());
1547+ if (!nilExpr) {
1548+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_init_non_nil)
1549+ .highlight (E->getSourceRange ());
1550+ RS->setResult (nullptr );
1551+ return RS;
1552+ }
1553+
1554+ // "return nil" is only permitted in a failable initializer.
1555+ if (!ctor->isFailable ()) {
1556+ ctx.Diags .diagnose (RS->getReturnLoc (), diag::return_non_failable_init)
1557+ .highlight (E->getSourceRange ());
1558+ ctx.Diags
1559+ .diagnose (ctor->getLoc (), diag::make_init_failable, ctor->getName ())
1560+ .fixItInsertAfter (ctor->getLoc (), " ?" );
1561+ RS->setResult (nullptr );
1562+ return RS;
1563+ }
1564+
1565+ // Replace the "return nil" with a new 'fail' statement.
1566+ return new (ctx)
1567+ FailStmt (RS->getReturnLoc (), nilExpr->getLoc (), RS->isImplicit ());
1568+ }
1569+ return RS;
1570+ }
1571+
15501572static bool isDiscardableType (Type type) {
15511573 return (type->hasError () ||
15521574 type->isUninhabited () ||
0 commit comments