diff --git a/include/swift/AST/Pattern.h b/include/swift/AST/Pattern.h index f85ff44d8ae06..c49241b86af7c 100644 --- a/include/swift/AST/Pattern.h +++ b/include/swift/AST/Pattern.h @@ -508,25 +508,29 @@ class EnumElementPattern : public Pattern { DeclNameRef Name; PointerUnion ElementDeclOrUnresolvedOriginalExpr; Pattern /*nullable*/ *SubPattern; + DeclContext *DC; public: EnumElementPattern(TypeExpr *ParentType, SourceLoc DotLoc, DeclNameLoc NameLoc, DeclNameRef Name, - EnumElementDecl *Element, Pattern *SubPattern) + EnumElementDecl *Element, Pattern *SubPattern, + DeclContext *DC) : Pattern(PatternKind::EnumElement), ParentType(ParentType), DotLoc(DotLoc), NameLoc(NameLoc), Name(Name), - ElementDeclOrUnresolvedOriginalExpr(Element), SubPattern(SubPattern) { + ElementDeclOrUnresolvedOriginalExpr(Element), SubPattern(SubPattern), + DC(DC) { assert(ParentType && "Missing parent type?"); } /// Create an unresolved EnumElementPattern for a `.foo` pattern relying on /// contextual type. EnumElementPattern(SourceLoc DotLoc, DeclNameLoc NameLoc, DeclNameRef Name, - Pattern *SubPattern, Expr *UnresolvedOriginalExpr) + Pattern *SubPattern, Expr *UnresolvedOriginalExpr, + DeclContext *DC) : Pattern(PatternKind::EnumElement), ParentType(nullptr), DotLoc(DotLoc), NameLoc(NameLoc), Name(Name), ElementDeclOrUnresolvedOriginalExpr(UnresolvedOriginalExpr), - SubPattern(SubPattern) {} + SubPattern(SubPattern), DC(DC) {} bool hasSubPattern() const { return SubPattern; } @@ -540,6 +544,8 @@ class EnumElementPattern : public Pattern { void setSubPattern(Pattern *p) { SubPattern = p; } + DeclContext *getDeclContext() const { return DC; } + DeclNameRef getName() const { return Name; } EnumElementDecl *getElementDecl() const { @@ -647,43 +653,59 @@ class OptionalSomePattern : public Pattern { class ExprPattern : public Pattern { llvm::PointerIntPair SubExprAndIsResolved; - /// An expression constructed during type-checking that produces a call to the - /// '~=' operator comparing the match expression on the left to the matched - /// value on the right. - Expr *MatchExpr = nullptr; + DeclContext *DC; + + /// A synthesized call to the '~=' operator comparing the match expression + /// on the left to the matched value on the right. + mutable Expr *MatchExpr = nullptr; + + /// An implicit variable used to represent the RHS value of the synthesized + /// match expression. + mutable VarDecl *MatchVar = nullptr; - /// An implicit variable used to represent the RHS value of the match. - VarDecl *MatchVar = nullptr; + ExprPattern(Expr *E, DeclContext *DC, bool isResolved) + : Pattern(PatternKind::Expr), SubExprAndIsResolved(E, isResolved), + DC(DC) {} - ExprPattern(Expr *E, bool isResolved) - : Pattern(PatternKind::Expr), SubExprAndIsResolved(E, isResolved) {} + friend class ExprPatternMatchRequest; public: /// Create a new parsed unresolved ExprPattern. - static ExprPattern *createParsed(ASTContext &ctx, Expr *E); + static ExprPattern *createParsed(ASTContext &ctx, Expr *E, DeclContext *DC); /// Create a new resolved ExprPattern. This should be used in cases /// where a user-written expression should be treated as an ExprPattern. - static ExprPattern *createResolved(ASTContext &ctx, Expr *E); + static ExprPattern *createResolved(ASTContext &ctx, Expr *E, DeclContext *DC); /// Create a new implicit resolved ExprPattern. - static ExprPattern *createImplicit(ASTContext &ctx, Expr *E); + static ExprPattern *createImplicit(ASTContext &ctx, Expr *E, DeclContext *DC); Expr *getSubExpr() const { return SubExprAndIsResolved.getPointer(); } void setSubExpr(Expr *e) { SubExprAndIsResolved.setPointer(e); } - Expr *getMatchExpr() const { return MatchExpr; } + DeclContext *getDeclContext() const { return DC; } + + /// The match expression if it has been computed, \c nullptr otherwise. + /// Should only be used by the ASTDumper and ASTWalker. + Expr *getCachedMatchExpr() const { return MatchExpr; } + + /// The match variable if it has been computed, \c nullptr otherwise. + /// Should only be used by the ASTDumper and ASTWalker. + VarDecl *getCachedMatchVar() const { return MatchVar; } + + /// A synthesized call to the '~=' operator comparing the match expression + /// on the left to the matched value on the right. + Expr *getMatchExpr() const; + + /// An implicit variable used to represent the RHS value of the synthesized + /// match expression. + VarDecl *getMatchVar() const; + void setMatchExpr(Expr *e) { - assert(isResolved() && "cannot set match fn for unresolved expr patter"); + assert(MatchExpr && "Should only update an existing MatchExpr"); MatchExpr = e; } - VarDecl *getMatchVar() const { return MatchVar; } - void setMatchVar(VarDecl *v) { - assert(isResolved() && "cannot set match var for unresolved expr patter"); - MatchVar = v; - } - SourceLoc getLoc() const; SourceRange getSourceRange() const; @@ -851,6 +873,9 @@ class ContextualPattern { }; void simple_display(llvm::raw_ostream &out, const ContextualPattern &pattern); +void simple_display(llvm::raw_ostream &out, const Pattern *pattern); + +SourceLoc extractNearestSourceLoc(const Pattern *pattern); } // end namespace swift diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 3d786d92c4dfd..33027a7556ac1 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -2219,6 +2219,74 @@ class NamingPatternRequest void cacheResult(NamedPattern *P) const; }; +class ExprPatternMatchResult { + VarDecl *MatchVar; + Expr *MatchExpr; + + friend class ExprPattern; + + // Should only be used as the default value for the request, as the caching + // logic assumes the request always produces a non-null result. + ExprPatternMatchResult(llvm::NoneType) + : MatchVar(nullptr), MatchExpr(nullptr) {} + +public: + ExprPatternMatchResult(VarDecl *matchVar, Expr *matchExpr) + : MatchVar(matchVar), MatchExpr(matchExpr) { + // Note the caching logic currently requires the request to produce a + // non-null result. + assert(matchVar && matchExpr); + } + + VarDecl *getMatchVar() const { return MatchVar; } + Expr *getMatchExpr() const { return MatchExpr; } +}; + +/// Compute the match VarDecl and expression for an ExprPattern, which applies +/// the \c ~= operator. +class ExprPatternMatchRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + ExprPatternMatchResult evaluate(Evaluator &evaluator, + const ExprPattern *EP) const; + +public: + // Separate caching. + bool isCached() const { return true; } + Optional getCachedResult() const; + void cacheResult(ExprPatternMatchResult result) const; +}; + +/// Creates a corresponding ExprPattern from the original Expr of an +/// EnumElementPattern. This needs to be a cached request to ensure we don't +/// generate multiple ExprPatterns along different constraint solver paths. +class EnumElementExprPatternRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + ExprPattern *evaluate(Evaluator &evaluator, + const EnumElementPattern *EEP) const; + +public: + // Cached. + bool isCached() const { return true; } +}; + class InterfaceTypeRequest : public SimpleRequest + getTargetFor(SyntacticElementTargetKey key) const { + auto known = targets.find(key); + if (known == targets.end()) + return None; + return known->second; + } + ConstraintLocator *getCalleeLocator(ConstraintLocator *locator, bool lookThroughApply = true) const; diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index 63c6c286436c8..14d70ba5c24cc 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -472,7 +472,7 @@ namespace { void visitExprPattern(ExprPattern *P) { printCommon(P, "pattern_expr"); OS << '\n'; - if (auto m = P->getMatchExpr()) + if (auto m = P->getCachedMatchExpr()) printRec(m); else printRec(P->getSubExpr()); diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index 52cc1e3c335f2..52f7016bc37d6 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -1980,8 +1980,8 @@ Pattern *Traversal::visitEnumElementPattern(EnumElementPattern *P) { Pattern *Traversal::visitExprPattern(ExprPattern *P) { // If the pattern has been type-checked, walk the match expression, which // includes the explicit subexpression. - if (P->getMatchExpr()) { - if (Expr *newMatch = doIt(P->getMatchExpr())) { + if (auto *match = P->getCachedMatchExpr()) { + if (Expr *newMatch = doIt(match)) { P->setMatchExpr(newMatch); return P; } diff --git a/lib/AST/Pattern.cpp b/lib/AST/Pattern.cpp index c5e4fd87462cb..4ec922e05c4ee 100644 --- a/lib/AST/Pattern.cpp +++ b/lib/AST/Pattern.cpp @@ -19,6 +19,7 @@ #include "swift/AST/ASTWalker.h" #include "swift/AST/Expr.h" #include "swift/AST/GenericEnvironment.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/AST/TypeLoc.h" #include "swift/AST/TypeRepr.h" #include "swift/Basic/Statistic.h" @@ -333,9 +334,8 @@ EnumElementDecl *OptionalSomePattern::getElementDecl() const { /// Return true if this is a non-resolved ExprPattern which is syntactically /// irrefutable. static bool isIrrefutableExprPattern(const ExprPattern *EP) { - // If the pattern has a registered match expression, it's - // a type-checked ExprPattern. - if (EP->getMatchExpr()) return false; + // If the pattern is resolved, it must be irrefutable. + if (EP->isResolved()) return false; auto expr = EP->getSubExpr(); while (true) { @@ -501,20 +501,35 @@ void IsPattern::setCastType(Type type) { TypeRepr *IsPattern::getCastTypeRepr() const { return CastType->getTypeRepr(); } -ExprPattern *ExprPattern::createParsed(ASTContext &ctx, Expr *E) { - return new (ctx) ExprPattern(E, /*isResolved*/ false); +ExprPattern *ExprPattern::createParsed(ASTContext &ctx, Expr *E, + DeclContext *DC) { + return new (ctx) ExprPattern(E, DC, /*isResolved*/ false); } -ExprPattern *ExprPattern::createResolved(ASTContext &ctx, Expr *E) { - return new (ctx) ExprPattern(E, /*isResolved*/ true); +ExprPattern *ExprPattern::createResolved(ASTContext &ctx, Expr *E, + DeclContext *DC) { + return new (ctx) ExprPattern(E, DC, /*isResolved*/ true); } -ExprPattern *ExprPattern::createImplicit(ASTContext &ctx, Expr *E) { - auto *EP = ExprPattern::createResolved(ctx, E); +ExprPattern *ExprPattern::createImplicit(ASTContext &ctx, Expr *E, + DeclContext *DC) { + auto *EP = ExprPattern::createResolved(ctx, E, DC); EP->setImplicit(); return EP; } +Expr *ExprPattern::getMatchExpr() const { + auto &eval = DC->getASTContext().evaluator; + return evaluateOrDefault(eval, ExprPatternMatchRequest{this}, None) + .getMatchExpr(); +} + +VarDecl *ExprPattern::getMatchVar() const { + auto &eval = DC->getASTContext().evaluator; + return evaluateOrDefault(eval, ExprPatternMatchRequest{this}, None) + .getMatchVar(); +} + SourceLoc EnumElementPattern::getStartLoc() const { return (ParentType && !ParentType->isImplicit()) ? ParentType->getSourceRange().Start @@ -616,5 +631,13 @@ bool ContextualPattern::allowsInference() const { void swift::simple_display(llvm::raw_ostream &out, const ContextualPattern &pattern) { - out << "(pattern @ " << pattern.getPattern() << ")"; + simple_display(out, pattern.getPattern()); +} + +void swift::simple_display(llvm::raw_ostream &out, const Pattern *pattern) { + out << "(pattern @ " << pattern << ")"; +} + +SourceLoc swift::extractNearestSourceLoc(const Pattern *pattern) { + return pattern->getLoc(); } diff --git a/lib/AST/TypeCheckRequests.cpp b/lib/AST/TypeCheckRequests.cpp index 27163193e90ee..15982b594923f 100644 --- a/lib/AST/TypeCheckRequests.cpp +++ b/lib/AST/TypeCheckRequests.cpp @@ -1011,6 +1011,25 @@ void NamingPatternRequest::cacheResult(NamedPattern *value) const { VD->NamingPattern = value; } +//----------------------------------------------------------------------------// +// ExprPatternMatchRequest caching. +//----------------------------------------------------------------------------// + +Optional +ExprPatternMatchRequest::getCachedResult() const { + auto *EP = std::get<0>(getStorage()); + if (!EP->MatchVar) + return None; + + return ExprPatternMatchResult(EP->MatchVar, EP->MatchExpr); +} + +void ExprPatternMatchRequest::cacheResult(ExprPatternMatchResult result) const { + auto *EP = std::get<0>(getStorage()); + EP->MatchVar = result.getMatchVar(); + EP->MatchExpr = result.getMatchExpr(); +} + //----------------------------------------------------------------------------// // InterfaceTypeRequest computation. //----------------------------------------------------------------------------// diff --git a/lib/IDE/ArgumentCompletion.cpp b/lib/IDE/ArgumentCompletion.cpp index c615683d1fc01..d022685e3b453 100644 --- a/lib/IDE/ArgumentCompletion.cpp +++ b/lib/IDE/ArgumentCompletion.cpp @@ -95,19 +95,18 @@ static bool isExpressionResultTypeUnconstrained(const Solution &S, Expr *E) { return true; } } - auto targetIt = S.targets.find(E); - if (targetIt == S.targets.end()) { + auto target = S.getTargetFor(E); + if (!target) return false; - } - auto target = targetIt->second; - assert(target.kind == SyntacticElementTarget::Kind::expression); - switch (target.getExprContextualTypePurpose()) { + + assert(target->kind == SyntacticElementTarget::Kind::expression); + switch (target->getExprContextualTypePurpose()) { case CTP_Unused: // If we aren't using the contextual type, its unconstrained by definition. return true; case CTP_Initialization: { // let x = is unconstrained - auto contextualType = target.getExprContextualType(); + auto contextualType = target->getExprContextualType(); return !contextualType || contextualType->is(); } default: diff --git a/lib/IDE/ExprContextAnalysis.cpp b/lib/IDE/ExprContextAnalysis.cpp index 585b97a5ed808..2aac61940a841 100644 --- a/lib/IDE/ExprContextAnalysis.cpp +++ b/lib/IDE/ExprContextAnalysis.cpp @@ -1185,11 +1185,15 @@ class ExprContextAnalyzer { switch (P->getKind()) { case PatternKind::Expr: { auto ExprPat = cast(P); - if (auto D = ExprPat->getMatchVar()) { - if (D->hasInterfaceType()) - recordPossibleType( - D->getDeclContext()->mapTypeIntoContext(D->getInterfaceType())); - } + if (!ExprPat->isResolved()) + break; + + auto D = ExprPat->getMatchVar(); + if (!D || !D->hasInterfaceType()) + break; + + auto *DC = D->getDeclContext(); + recordPossibleType(DC->mapTypeIntoContext(D->getInterfaceType())); break; } default: diff --git a/lib/IDE/PostfixCompletion.cpp b/lib/IDE/PostfixCompletion.cpp index e6bda2940cb59..7fae130f0a0e1 100644 --- a/lib/IDE/PostfixCompletion.cpp +++ b/lib/IDE/PostfixCompletion.cpp @@ -52,11 +52,9 @@ getClosureActorIsolation(const Solution &S, AbstractClosureExpr *ACE) { // the contextual type might have a global actor attribute but because no // methods from that global actor are called in the closure, the closure has // a non-actor type. - auto target = S.targets.find(dyn_cast(E)); - if (target != S.targets.end()) { - if (auto Ty = target->second.getClosureContextualType()) { + if (auto target = S.getTargetFor(dyn_cast(E))) { + if (auto Ty = target->getClosureContextualType()) return Ty; - } } if (!S.hasType(E)) { return Type(); diff --git a/lib/IDE/TypeCheckCompletionCallback.cpp b/lib/IDE/TypeCheckCompletionCallback.cpp index 618f0bc05d30c..17198bb4e5950 100644 --- a/lib/IDE/TypeCheckCompletionCallback.cpp +++ b/lib/IDE/TypeCheckCompletionCallback.cpp @@ -165,9 +165,8 @@ bool swift::ide::isContextAsync(const constraints::Solution &S, // closure that doesn't contain any async calles. Thus the closure is // type-checked as non-async, but it might get converted to an async // closure based on its contextual type - auto target = S.targets.find(dyn_cast(DC)); - if (target != S.targets.end()) { - if (auto ContextTy = target->second.getClosureContextualType()) { + if (auto target = S.getTargetFor(dyn_cast(DC))) { + if (auto ContextTy = target->getClosureContextualType()) { if (auto ContextFuncTy = S.simplifyType(ContextTy)->getAs()) { return ContextFuncTy->isAsync(); diff --git a/lib/Parse/ParsePattern.cpp b/lib/Parse/ParsePattern.cpp index 1f636d4f998df..4b81c4c8abec6 100644 --- a/lib/Parse/ParsePattern.cpp +++ b/lib/Parse/ParsePattern.cpp @@ -1339,7 +1339,7 @@ ParserResult Parser::parseMatchingPattern(bool isExprBasic) { if (auto *UPE = dyn_cast(subExpr.get())) return makeParserResult(status, UPE->getSubPattern()); - auto *EP = ExprPattern::createParsed(Context, subExpr.get()); + auto *EP = ExprPattern::createParsed(Context, subExpr.get(), CurDeclContext); return makeParserResult(status, EP); } diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp index a7e2e2f4ec511..e4ae0ebbb03bb 100644 --- a/lib/Parse/ParseStmt.cpp +++ b/lib/Parse/ParseStmt.cpp @@ -1095,7 +1095,8 @@ static void parseGuardedPattern(Parser &P, GuardedPattern &result, // Do some special-case code completion for the start of the pattern. if (P.Tok.is(tok::code_complete)) { auto CCE = new (P.Context) CodeCompletionExpr(P.Tok.getLoc()); - result.ThePattern = ExprPattern::createParsed(P.Context, CCE); + result.ThePattern = + ExprPattern::createParsed(P.Context, CCE, P.CurDeclContext); if (P.IDECallbacks) { switch (parsingContext) { case GuardedPatternContext::Case: diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 52235041ef2b9..20443f0e513ea 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -9928,29 +9928,15 @@ static bool inferEnumMemberThroughTildeEqualsOperator( if (!pattern->hasUnresolvedOriginalExpr()) return true; - auto &DC = cs.DC; + auto *DC = pattern->getDeclContext(); auto &ctx = cs.getASTContext(); - // Slots for expression and variable are going to be filled via - // synthesizing ~= operator application. + // Retrieve a corresponding ExprPattern which we can solve with ~=. auto *EP = - ExprPattern::createResolved(ctx, pattern->getUnresolvedOriginalExpr()); - - auto tildeEqualsApplication = - TypeChecker::synthesizeTildeEqualsOperatorApplication(EP, DC, enumTy); - - if (!tildeEqualsApplication) - return true; - - VarDecl *matchVar; - Expr *matchCall; - - std::tie(matchVar, matchCall) = *tildeEqualsApplication; - - cs.setType(matchVar, enumTy); - cs.setType(EP, enumTy); + llvm::cantFail(ctx.evaluator(EnumElementExprPatternRequest{pattern})); // result of ~= operator is always a `Bool`. + auto *matchCall = EP->getMatchExpr(); auto target = SyntacticElementTarget::forExprPattern( matchCall, DC, EP, ctx.getBoolDecl()->getDeclaredInterfaceType()); @@ -9967,6 +9953,8 @@ static bool inferEnumMemberThroughTildeEqualsOperator( return true; } } + cs.setType(EP->getMatchVar(), enumTy); + cs.setType(EP, enumTy); if (cs.generateConstraints(target)) return true; @@ -9976,12 +9964,7 @@ static bool inferEnumMemberThroughTildeEqualsOperator( cs.addConstraint(ConstraintKind::Conversion, cs.getType(EP->getSubExpr()), elementTy, cs.getConstraintLocator(EP)); - // Store the $match variable and binary expression for solution application. - EP->setMatchVar(matchVar); - EP->setMatchExpr(matchCall); - cs.setTargetFor(pattern, target); - return false; } diff --git a/lib/Sema/DerivedConformanceCodable.cpp b/lib/Sema/DerivedConformanceCodable.cpp index 0e6925ae076b5..1a393655be8cd 100644 --- a/lib/Sema/DerivedConformanceCodable.cpp +++ b/lib/Sema/DerivedConformanceCodable.cpp @@ -976,7 +976,7 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl, DC->mapTypeIntoContext( targetElt->getParentEnum()->getDeclaredInterfaceType()), C), - SourceLoc(), DeclNameLoc(), DeclNameRef(), targetElt, subpattern); + SourceLoc(), DeclNameLoc(), DeclNameRef(), targetElt, subpattern, DC); pat->setImplicit(); auto labelItem = CaseLabelItem(pat); diff --git a/lib/Sema/DerivedConformanceCodingKey.cpp b/lib/Sema/DerivedConformanceCodingKey.cpp index fc6b473bc229e..b3509b15e944d 100644 --- a/lib/Sema/DerivedConformanceCodingKey.cpp +++ b/lib/Sema/DerivedConformanceCodingKey.cpp @@ -214,7 +214,8 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl, void *) { for (auto *elt : elements) { auto *baseTE = TypeExpr::createImplicit(enumType, C); auto *pat = new (C) EnumElementPattern(baseTE, SourceLoc(), DeclNameLoc(), - DeclNameRef(), elt, nullptr); + DeclNameRef(), elt, nullptr, + /*DC*/ strValDecl); pat->setImplicit(); auto labelItem = CaseLabelItem(pat); @@ -277,7 +278,7 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl, void *) { for (auto *elt : elements) { auto *litExpr = new (C) StringLiteralExpr(elt->getNameStr(), SourceRange(), /*Implicit=*/true); - auto *litPat = ExprPattern::createImplicit(C, litExpr); + auto *litPat = ExprPattern::createImplicit(C, litExpr, /*DC*/ initDecl); auto labelItem = CaseLabelItem(litPat); diff --git a/lib/Sema/DerivedConformanceComparable.cpp b/lib/Sema/DerivedConformanceComparable.cpp index 46aca385093a9..84ac1ec947284 100644 --- a/lib/Sema/DerivedConformanceComparable.cpp +++ b/lib/Sema/DerivedConformanceComparable.cpp @@ -116,9 +116,9 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', ltDecl, lhsPayloadVars); auto *lhsBaseTE = TypeExpr::createImplicit(enumType, C); - auto lhsElemPat = - new (C) EnumElementPattern(lhsBaseTE, SourceLoc(), DeclNameLoc(), - DeclNameRef(), elt, lhsSubpattern); + auto lhsElemPat = new (C) + EnumElementPattern(lhsBaseTE, SourceLoc(), DeclNameLoc(), DeclNameRef(), + elt, lhsSubpattern, /*DC*/ ltDecl); lhsElemPat->setImplicit(); // .(let r0, let r1, ...) @@ -126,9 +126,9 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', ltDecl, rhsPayloadVars); auto *rhsBaseTE = TypeExpr::createImplicit(enumType, C); - auto rhsElemPat = - new (C) EnumElementPattern(rhsBaseTE, SourceLoc(), DeclNameLoc(), - DeclNameRef(), elt, rhsSubpattern); + auto rhsElemPat = new (C) + EnumElementPattern(rhsBaseTE, SourceLoc(), DeclNameLoc(), DeclNameRef(), + elt, rhsSubpattern, /*DC*/ ltDecl); rhsElemPat->setImplicit(); auto hasBoundDecls = !lhsPayloadVars.empty(); diff --git a/lib/Sema/DerivedConformanceEquatableHashable.cpp b/lib/Sema/DerivedConformanceEquatableHashable.cpp index 6e8bca934c92a..4c88cffb24fa2 100644 --- a/lib/Sema/DerivedConformanceEquatableHashable.cpp +++ b/lib/Sema/DerivedConformanceEquatableHashable.cpp @@ -181,9 +181,9 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, auto lhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'l', eqDecl, lhsPayloadVars); auto *lhsBaseTE = TypeExpr::createImplicit(enumType, C); - auto lhsElemPat = - new (C) EnumElementPattern(lhsBaseTE, SourceLoc(), DeclNameLoc(), - DeclNameRef(), elt, lhsSubpattern); + auto lhsElemPat = new (C) + EnumElementPattern(lhsBaseTE, SourceLoc(), DeclNameLoc(), DeclNameRef(), + elt, lhsSubpattern, /*DC*/ eqDecl); lhsElemPat->setImplicit(); // .(let r0, let r1, ...) @@ -191,9 +191,9 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, auto rhsSubpattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'r', eqDecl, rhsPayloadVars); auto *rhsBaseTE = TypeExpr::createImplicit(enumType, C); - auto rhsElemPat = - new (C) EnumElementPattern(rhsBaseTE, SourceLoc(), DeclNameLoc(), - DeclNameRef(), elt, rhsSubpattern); + auto rhsElemPat = new (C) + EnumElementPattern(rhsBaseTE, SourceLoc(), DeclNameLoc(), DeclNameRef(), + elt, rhsSubpattern, /*DC*/ eqDecl); rhsElemPat->setImplicit(); auto hasBoundDecls = !lhsPayloadVars.empty(); @@ -702,9 +702,10 @@ deriveBodyHashable_enum_hasAssociatedValues_hashInto( auto payloadPattern = DerivedConformance::enumElementPayloadSubpattern(elt, 'a', hashIntoDecl, payloadVars); - auto pat = new (C) EnumElementPattern( - TypeExpr::createImplicit(enumType, C), SourceLoc(), DeclNameLoc(), - DeclNameRef(elt->getBaseIdentifier()), elt, payloadPattern); + auto pat = new (C) + EnumElementPattern(TypeExpr::createImplicit(enumType, C), SourceLoc(), + DeclNameLoc(), DeclNameRef(elt->getBaseIdentifier()), + elt, payloadPattern, /*DC*/ hashIntoDecl); pat->setImplicit(); auto labelItem = CaseLabelItem(pat); diff --git a/lib/Sema/DerivedConformanceRawRepresentable.cpp b/lib/Sema/DerivedConformanceRawRepresentable.cpp index 3f669280b0b65..35ce04e82c806 100644 --- a/lib/Sema/DerivedConformanceRawRepresentable.cpp +++ b/lib/Sema/DerivedConformanceRawRepresentable.cpp @@ -114,7 +114,8 @@ deriveBodyRawRepresentable_raw(AbstractFunctionDecl *toRawDecl, void *) { for (auto elt : enumDecl->getAllElements()) { auto pat = new (C) EnumElementPattern(TypeExpr::createImplicit(enumType, C), SourceLoc(), - DeclNameLoc(), DeclNameRef(), elt, nullptr); + DeclNameLoc(), DeclNameRef(), elt, nullptr, + /*DC*/ toRawDecl); pat->setImplicit(); auto labelItem = CaseLabelItem(pat); @@ -323,7 +324,7 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl, void *) { stringExprs.push_back(litExpr); litExpr = IntegerLiteralExpr::createFromUnsigned(C, Idx, SourceLoc()); } - auto *litPat = ExprPattern::createImplicit(C, litExpr); + auto *litPat = ExprPattern::createImplicit(C, litExpr, /*DC*/ initDecl); /// Statements in the body of this case. SmallVector stmts; diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index a51476c9ded10..fc5389c82522f 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -737,7 +737,8 @@ DeclRefExpr *DerivedConformance::convertEnumToIndex(SmallVectorImpl &st // generate: case .: auto pat = new (C) EnumElementPattern(TypeExpr::createImplicit(enumType, C), SourceLoc(), - DeclNameLoc(), DeclNameRef(), elt, nullptr); + DeclNameLoc(), DeclNameRef(), elt, nullptr, + /*DC*/ funcDecl); pat->setImplicit(); pat->setType(enumType); diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 46a4619c906d0..552d76b59269b 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -928,20 +928,10 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, "typecheck-expr-pattern", EP); PrettyStackTracePattern stackTrace(Context, "type-checking", EP); - auto tildeEqualsApplication = - synthesizeTildeEqualsOperatorApplication(EP, DC, rhsType); - - if (!tildeEqualsApplication) - return true; - - VarDecl *matchVar; - Expr *matchCall; - - std::tie(matchVar, matchCall) = *tildeEqualsApplication; - - matchVar->setInterfaceType(rhsType->mapTypeOutOfContext()); + EP->getMatchVar()->setInterfaceType(rhsType->mapTypeOutOfContext()); // Result of `~=` should always be a boolean. + auto *matchCall = EP->getMatchExpr(); auto contextualTy = Context.getBoolDecl()->getDeclaredInterfaceType(); auto target = SyntacticElementTarget::forExprPattern(matchCall, DC, EP, contextualTy); @@ -951,8 +941,6 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, if (!result) return true; - // Save the synthesized $match variable in the pattern. - EP->setMatchVar(matchVar); // Save the type-checked expression in the pattern. EP->setMatchExpr(result->getAsExpr()); // Set the type on the pattern. @@ -960,56 +948,6 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, return false; } -Optional> -TypeChecker::synthesizeTildeEqualsOperatorApplication(ExprPattern *EP, - DeclContext *DC, - Type enumType) { - auto &Context = DC->getASTContext(); - // Create a 'let' binding to stand in for the RHS value. - auto *matchVar = - new (Context) VarDecl(/*IsStatic*/ false, VarDecl::Introducer::Let, - EP->getLoc(), Context.Id_PatternMatchVar, DC); - - matchVar->setImplicit(); - - // Find '~=' operators for the match. - auto matchLookup = - lookupUnqualified(DC->getModuleScopeContext(), - DeclNameRef(Context.Id_MatchOperator), - SourceLoc(), defaultUnqualifiedLookupOptions); - auto &diags = DC->getASTContext().Diags; - if (!matchLookup) { - diags.diagnose(EP->getLoc(), diag::no_match_operator); - return None; - } - - SmallVector choices; - for (auto &result : matchLookup) { - choices.push_back(result.getValueDecl()); - } - - if (choices.empty()) { - diags.diagnose(EP->getLoc(), diag::no_match_operator); - return None; - } - - // Build the 'expr ~= var' expression. - // FIXME: Compound name locations. - auto *matchOp = - TypeChecker::buildRefExpr(choices, DC, DeclNameLoc(EP->getLoc()), - /*Implicit=*/true, FunctionRefKind::Compound); - - // Note we use getEndLoc here to have the BinaryExpr source range be the same - // as the expr pattern source range. - auto *matchVarRef = new (Context) DeclRefExpr(matchVar, - DeclNameLoc(EP->getEndLoc()), - /*Implicit=*/true); - auto *matchCall = BinaryExpr::create(Context, EP->getSubExpr(), matchOp, - matchVarRef, /*implicit*/ true); - - return std::make_pair(matchVar, matchCall); -} - static Type replaceArchetypesWithTypeVariables(ConstraintSystem &cs, Type t) { llvm::DenseMap types; diff --git a/lib/Sema/TypeCheckPattern.cpp b/lib/Sema/TypeCheckPattern.cpp index 66857b54149f2..9b15982e35b66 100644 --- a/lib/Sema/TypeCheckPattern.cpp +++ b/lib/Sema/TypeCheckPattern.cpp @@ -244,7 +244,7 @@ class ResolvePattern : public ASTVisitorgetName().getBaseName().isSpecial()) return nullptr; - return new (Context) - EnumElementPattern(ume->getDotLoc(), ume->getNameLoc(), ume->getName(), - nullptr, ume); + return new (Context) EnumElementPattern(ume->getDotLoc(), ume->getNameLoc(), + ume->getName(), nullptr, ume, DC); } // Member syntax 'T.Element' forms a pattern if 'T' is an enum and the @@ -483,7 +482,7 @@ class ResolvePattern : public ASTVisitorsetType(MetatypeType::get(ty)); return new (Context) EnumElementPattern(base, ude->getDotLoc(), ude->getNameLoc(), - ude->getName(), referencedElement, nullptr); + ude->getName(), referencedElement, nullptr, DC); } // A DeclRef 'E' that refers to an enum element forms an EnumElementPattern. @@ -496,8 +495,9 @@ class ResolvePattern : public ASTVisitorgetParentEnum()->getDeclaredTypeInContext(); auto *base = TypeExpr::createImplicit(enumTy, Context); - return new (Context) EnumElementPattern(base, SourceLoc(), de->getNameLoc(), - elt->createNameRef(), elt, nullptr); + return new (Context) + EnumElementPattern(base, SourceLoc(), de->getNameLoc(), + elt->createNameRef(), elt, nullptr, DC); } Pattern *visitUnresolvedDeclRefExpr(UnresolvedDeclRefExpr *ude) { // FIXME: This shouldn't be needed. It is only necessary because of the @@ -514,7 +514,7 @@ class ResolvePattern : public ASTVisitorgetNameLoc(), - ude->getName(), referencedElement, nullptr); + ude->getName(), referencedElement, nullptr, DC); } @@ -614,7 +614,7 @@ class ResolvePattern : public ASTVisitorgetArgs()); return new (Context) EnumElementPattern( baseTE, SourceLoc(), tailComponent->getNameLoc(), - tailComponent->getNameRef(), referencedElement, subPattern); + tailComponent->getNameRef(), referencedElement, subPattern, DC); } }; @@ -761,6 +761,45 @@ static TypeResolutionOptions applyContextualPatternOptions( return options; } +ExprPatternMatchResult +ExprPatternMatchRequest::evaluate(Evaluator &evaluator, + const ExprPattern *EP) const { + assert(EP->isResolved() && "Must only be queried once resolved"); + + auto *DC = EP->getDeclContext(); + auto &ctx = DC->getASTContext(); + + // Create a 'let' binding to stand in for the RHS value. + auto *matchVar = + new (ctx) VarDecl(/*IsStatic*/ false, VarDecl::Introducer::Let, + EP->getLoc(), ctx.Id_PatternMatchVar, DC); + matchVar->setImplicit(); + + // Build the 'expr ~= var' expression. + auto *matchOp = new (ctx) UnresolvedDeclRefExpr( + DeclNameRef(ctx.Id_MatchOperator), DeclRefKind::BinaryOperator, + DeclNameLoc(EP->getLoc())); + matchOp->setImplicit(); + + // Note we use getEndLoc here to have the BinaryExpr source range be the same + // as the expr pattern source range. + auto *matchVarRef = + new (ctx) DeclRefExpr(matchVar, DeclNameLoc(EP->getEndLoc()), + /*Implicit=*/true); + auto *matchCall = BinaryExpr::create(ctx, EP->getSubExpr(), matchOp, + matchVarRef, /*implicit*/ true); + return {matchVar, matchCall}; +} + +ExprPattern * +EnumElementExprPatternRequest::evaluate(Evaluator &evaluator, + const EnumElementPattern *EEP) const { + assert(EEP->hasUnresolvedOriginalExpr()); + auto *DC = EEP->getDeclContext(); + return ExprPattern::createResolved(DC->getASTContext(), + EEP->getUnresolvedOriginalExpr(), DC); +} + Type PatternTypeRequest::evaluate(Evaluator &evaluator, ContextualPattern pattern) const { Pattern *P = pattern.getPattern(); @@ -1270,7 +1309,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern, auto *BaseTE = TypeExpr::createImplicit(type, Context); P = new (Context) EnumElementPattern( BaseTE, NLE->getLoc(), DeclNameLoc(NLE->getLoc()), - NoneEnumElement->createNameRef(), NoneEnumElement, nullptr); + NoneEnumElement->createNameRef(), NoneEnumElement, nullptr, dc); return TypeChecker::coercePatternToType( pattern.forSubPattern(P, /*retainTopLevel=*/true), type, options); } else { @@ -1325,7 +1364,7 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern, auto *base = TypeExpr::createImplicit(extraOptTy, Context); sub = new (Context) EnumElementPattern( base, IP->getStartLoc(), DeclNameLoc(IP->getEndLoc()), - some->createNameRef(), nullptr, sub); + some->createNameRef(), nullptr, sub, dc); sub->setImplicit(); } @@ -1429,8 +1468,8 @@ Pattern *TypeChecker::coercePatternToType(ContextualPattern pattern, // If we have the original expression parse tree, try reinterpreting // it as an expr-pattern if enum element lookup failed, since `.foo` // could also refer to a static member of the context type. - P = ExprPattern::createResolved(Context, - EEP->getUnresolvedOriginalExpr()); + P = ExprPattern::createResolved( + Context, EEP->getUnresolvedOriginalExpr(), dc); return coercePatternToType( pattern.forSubPattern(P, /*retainTopLevel=*/true), type, options); diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index a3907f385a28f..91418e8f31868 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -734,12 +734,6 @@ Pattern *coercePatternToType(ContextualPattern pattern, Type type, TypeResolutionOptions options); bool typeCheckExprPattern(ExprPattern *EP, DeclContext *DC, Type type); -/// Synthesize ~= operator application used to infer enum members -/// in `case` patterns. -Optional> -synthesizeTildeEqualsOperatorApplication(ExprPattern *EP, DeclContext *DC, - Type enumType); - /// Coerce the specified parameter list of a ClosureExpr to the specified /// contextual type. void coerceParameterListToType(ParameterList *P, AnyFunctionType *FN); diff --git a/test/Constraints/patterns.swift b/test/Constraints/patterns.swift index 36a4a15ba5d1a..0104bbe045c16 100644 --- a/test/Constraints/patterns.swift +++ b/test/Constraints/patterns.swift @@ -557,3 +557,10 @@ extension [EWithIdent] { } } } + +struct TestIUOMatchOp { + static func ~= (lhs: TestIUOMatchOp, rhs: TestIUOMatchOp) -> Bool! { nil } + func foo() { + if case self = self {} + } +}