diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index fe32aa5f7804a..51a8821838d63 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -787,6 +787,10 @@ namespace { /// Simplify a key path expression into a canonical form. void resolveKeyPathExpr(KeyPathExpr *KPE); + /// Simplify constructs like `UInt32(1)` into `1 as UInt32` if + /// the type conforms to the expected literal protocol. + Expr *simplifyInitWithLiteral(Expr *E); + public: PreCheckExpression(TypeChecker &tc, DeclContext *dc) : TC(tc), DC(dc) { } @@ -989,6 +993,9 @@ namespace { return KPE; } + if (auto *simplified = simplifyInitWithLiteral(expr)) + return simplified; + return expr; } @@ -1547,6 +1554,53 @@ void PreCheckExpression::resolveKeyPathExpr(KeyPathExpr *KPE) { KPE->resolveComponents(TC.Context, components); } +Expr *PreCheckExpression::simplifyInitWithLiteral(Expr *E) { + auto *call = dyn_cast(E); + if (!call || call->getNumArguments() != 1) + return nullptr; + + auto *typeExpr = dyn_cast(call->getFn()); + if (!typeExpr) + return nullptr; + + auto *argExpr = call->getArg()->getSemanticsProvidingExpr(); + auto *number = dyn_cast(argExpr); + if (!number) + return nullptr; + + auto *protocol = TC.getLiteralProtocol(number); + if (!protocol) + return nullptr; + + Type type; + if (auto *rep = typeExpr->getTypeRepr()) { + TypeResolutionOptions options; + options |= TypeResolutionFlags::AllowUnboundGenerics; + options |= TypeResolutionFlags::InExpression; + type = TC.resolveType(rep, DC, options); + } else { + type = typeExpr->getTypeLoc().getType(); + } + + if (!type) + return nullptr; + + ConformanceCheckOptions options; + options |= ConformanceCheckFlags::InExpression; + options |= ConformanceCheckFlags::SuppressDependencyTracking; + options |= ConformanceCheckFlags::SkipConditionalRequirements; + + auto result = TC.conformsToProtocol(type, protocol, DC, options); + if (result) { + auto *expr = + new (TC.Context) CoerceExpr(argExpr, {}, typeExpr->getTypeRepr()); + expr->setImplicit(); + return expr; + } + + return nullptr; +} + /// \brief Clean up the given ill-formed expression, removing any references /// to type variables and setting error types on erroneous expression nodes. void CleanupIllFormedExpressionRAII::doIt(Expr *expr, ASTContext &Context) {