diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index c6a7e6eb443d6..a9948a4649be6 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -4267,10 +4267,19 @@ namespace { // Resolve each of the components. bool didOptionalChain = false; - auto keyPathTy = cs.getType(E)->castTo(); - Type baseTy = keyPathTy->getGenericArgs()[0]; - Type leafTy = keyPathTy->getGenericArgs()[1]; - + bool isFunctionType = false; + Type baseTy, leafTy; + Type exprType = cs.getType(E); + if (auto fnTy = exprType->getAs()) { + baseTy = fnTy->getParams()[0].getType(); + leafTy = fnTy->getResult(); + isFunctionType = true; + } else { + auto keyPathTy = exprType->castTo(); + baseTy = keyPathTy->getGenericArgs()[0]; + leafTy = keyPathTy->getGenericArgs()[1]; + } + for (unsigned i : indices(E->getComponents())) { auto &origComponent = E->getMutableComponents()[i]; @@ -4556,7 +4565,46 @@ namespace { // key path. assert(!baseTy || baseTy->hasUnresolvedType() || baseTy->getWithoutSpecifierType()->isEqual(leafTy)); - return E; + + if (!isFunctionType) + return E; + + // Construct an implicit closure which applies this KeyPath. + auto resultTy = cs.getType(E); + auto &ctx = cs.getASTContext(); + auto toFunc = exprType->getAs(); + auto argTy = toFunc->getParams()[0].getType(); + auto discriminator = AutoClosureExpr::InvalidDiscriminator; + auto closure = new (ctx) + AutoClosureExpr(E, toFunc->getResult(), discriminator, cs.DC); + auto param = new (ctx) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), + SourceLoc(), Identifier(), SourceLoc(), + ctx.getIdentifier("$0"), closure); + param->setType(argTy); + param->setInterfaceType(argTy->mapTypeOutOfContext()); + auto *paramRef = new (ctx) + DeclRefExpr(param, DeclNameLoc(E->getLoc()), /*Implicit=*/true); + paramRef->setType(argTy); + cs.cacheType(paramRef); + + if (resultTy->is()) { + auto kpDecl = cs.getASTContext().getKeyPathDecl(); + E->setType(BoundGenericType::get(kpDecl, nullptr, + {argTy, toFunc->getResult()})); + cs.cacheType(E); + } + auto *application = new (ctx) + KeyPathApplicationExpr(paramRef, E->getStartLoc(), E, E->getEndLoc(), + toFunc->getResult(), /*implicit=*/true); + cs.cacheType(application); + closure->setParameterList(ParameterList::create(ctx, {param})); + closure->setBody(application); + + if (!resultTy->is()) + resultTy = toFunc->withExtInfo(toFunc->getExtInfo().withThrows(false)); + closure->setType(resultTy); + cs.cacheType(closure); + return coerceToType(closure, exprType, cs.getConstraintLocator(E)); } Expr *visitKeyPathDotExpr(KeyPathDotExpr *E) { @@ -6260,9 +6308,9 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, case ConversionRestrictionKind::TupleToTuple: case ConversionRestrictionKind::LValueToRValue: - // Restrictions that don't need to be recorded. - // Should match recordRestriction() in CSSimplify - break; + // Restrictions that don't need to be recorded. + // Should match recordRestriction() in CSSimplify + break; case ConversionRestrictionKind::DeepEquality: { if (toType->hasUnresolvedType()) diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 8ca4734d61f79..594e95d3b3e9c 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -3082,19 +3082,13 @@ namespace { auto rvalueBase = CS.createTypeVariable(locator); CS.addConstraint(ConstraintKind::Equal, base, rvalueBase, locator); - + // The result is a KeyPath from the root to the end component. - Type kpTy; - if (didOptionalChain) { - // Optional-chaining key paths are always read-only. - kpTy = BoundGenericType::get(kpDecl, Type(), {root, rvalueBase}); - } else { - // The type of key path depends on the overloads chosen for the key - // path components. - kpTy = CS.createTypeVariable(CS.getConstraintLocator(E)); - CS.addKeyPathConstraint(kpTy, root, rvalueBase, - CS.getConstraintLocator(E)); - } + // The type of key path depends on the overloads chosen for the key + // path components, or may also end up as function type. + Type kpTy = CS.createTypeVariable(CS.getConstraintLocator(E)); + CS.addKeyPathConstraint(kpTy, root, rvalueBase, + CS.getConstraintLocator(E)); return kpTy; } diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 5915a4b241e1a..a4459ef49fb06 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2143,9 +2143,9 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind, conversionsOrFixes.push_back( ConversionRestrictionKind::LValueToRValue); - // An expression can be converted to an auto-closure function type, creating - // an implicit closure. if (auto function2 = type2->getAs()) { + // An expression can be converted to an auto-closure function type, + // creating an implicit closure. if (function2->isAutoClosure()) return matchTypes( type1, function2->getResult(), kind, subflags, @@ -4124,7 +4124,26 @@ ConstraintSystem::simplifyKeyPathConstraint(Type keyPathTy, return SolutionKind::Error; } } - + + // If we're bound to a (Root) -> Value function type, use that for context. + if (auto fnTy = keyPathTy->getAs()) { + if (fnTy->getParams().size() == 1) { + Type boundRoot = fnTy->getParams()[0].getType(); + Type boundValue = fnTy->getResult(); + + if (matchTypes(boundRoot, rootTy, ConstraintKind::Bind, subflags, locator) + .isFailure()) + return SolutionKind::Error; + + if (matchTypes(boundValue, valueTy, ConstraintKind::Bind, subflags, + locator) + .isFailure()) + return SolutionKind::Error; + + return SolutionKind::Solved; + } + } + // See if we resolved overloads for all the components involved. enum { ReadOnly, @@ -4238,11 +4257,18 @@ ConstraintSystem::simplifyKeyPathConstraint(Type keyPathTy, && capability >= Writable) kpDecl = getASTContext().getWritableKeyPathDecl(); } - - auto resolvedKPTy = BoundGenericType::get(kpDecl, nullptr, - {rootTy, valueTy}); - return matchTypes(resolvedKPTy, keyPathTy, ConstraintKind::Bind, - subflags, locator); + + Type resolvedKPTy = BoundGenericType::get(kpDecl, nullptr, {rootTy, valueTy}); + Type fnType = FunctionType::get({AnyFunctionType::Param(rootTy)}, valueTy, + AnyFunctionType::ExtInfo().withThrows(false)); + llvm::SmallVector constraints; + auto loc = locator.getBaseLocator(); + constraints.push_back(Constraint::create(*this, ConstraintKind::Bind, + keyPathTy, resolvedKPTy, loc)); + constraints.push_back( + Constraint::create(*this, ConstraintKind::Bind, keyPathTy, fnType, loc)); + addDisjunctionConstraint(constraints, locator); + return SolutionKind::Solved; } ConstraintSystem::SolutionKind diff --git a/lib/Sema/Constraint.h b/lib/Sema/Constraint.h index f2f9cd08845d4..10d7f299aef19 100644 --- a/lib/Sema/Constraint.h +++ b/lib/Sema/Constraint.h @@ -196,7 +196,8 @@ enum class ConversionRestrictionKind { MetatypeToExistentialMetatype, /// Existential metatype to metatype conversion. ExistentialMetatypeToMetatype, - /// T -> U? value to optional conversion (or to implicitly unwrapped optional). + /// T -> U? value to optional conversion (or to implicitly unwrapped + /// optional). ValueToOptional, /// T? -> U? optional to optional conversion (or unchecked to unchecked). OptionalToOptional, @@ -214,7 +215,7 @@ enum class ConversionRestrictionKind { CFTollFreeBridgeToObjC, /// Implicit conversion from an Objective-C class type to its /// toll-free-bridged CF type. - ObjCTollFreeBridgeToCF + ObjCTollFreeBridgeToCF, }; /// Return a string representation of a conversion restriction.