diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 8acd37d62ef49..abd0d72a2c5a5 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor); /// Similar to `getDistributedSerializationRequirementType`, however, from the /// perspective of a concrete function. This way we're able to get the /// serialization requirement for specific members, also in protocols. -Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member); +Type getSerializationRequirementTypesForMember( + ValueDecl *member, llvm::SmallPtrSet &serializationRequirements); /// Get specific 'SerializationRequirement' as defined in 'nominal' /// type, which must conform to the passed 'protocol' which is expected @@ -97,7 +98,7 @@ getDistributedSerializationRequirementProtocols( /// If so, we can emit slightly nicer diagnostics. bool checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements); + Type type); /// Get the `SerializationRequirement`, explode it into the specific /// protocol requirements and insert them into `requirements`. @@ -114,15 +115,6 @@ getDistributedSerializationRequirements( ProtocolDecl *protocol, llvm::SmallPtrSetImpl &requirementProtos); -/// Given any set of generic requirements, locate those which are about the -/// `SerializationRequirement`. Those need to be applied in the parameter and -/// return type checking of distributed targets. -llvm::SmallPtrSet -extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements); - } -// ==== ------------------------------------------------------------------------ - #endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */ diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 111c6bc43625b..274d41701d878 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member) llvm_unreachable("Unable to fetch ActorSystem type!"); } -Type swift::getConcreteReplacementForMemberSerializationRequirement( - ValueDecl *member) { +Type swift::getSerializationRequirementTypesForMember( + ValueDecl *member, + llvm::SmallPtrSet &serializationRequirements) { auto &C = member->getASTContext(); auto *DC = member->getDeclContext(); auto DA = C.getDistributedActorDecl(); @@ -106,8 +107,10 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl()); } - /// === Maybe the value is declared in a protocol? - if (auto protocol = DC->getSelfProtocolDecl()) { + auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) + ->getDeclaredInterfaceType(); + + if (DC->getSelfProtocolDecl() || isa(DC)) { GenericSignature signature; if (auto *genericContext = member->getAsGenericContext()) { signature = genericContext->getGenericSignature(); @@ -115,8 +118,17 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( signature = DC->getGenericSignatureOfContext(); } - auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) - ->getDeclaredInterfaceType(); + // Also store all `SerializationRequirement : SomeProtocol` requirements + for (auto requirement: signature.getRequirements()) { + if (requirement.getFirstType()->isEqual(SerReqAssocType) && + requirement.getKind() == RequirementKind::Conformance) { + if (auto nominal = requirement.getSecondType()->getAnyNominal()) { + if (auto protocol = dyn_cast(nominal)) { + serializationRequirements.insert(protocol); + } + } + } + } // Note that this may be null, e.g. if we're a distributed func inside // a protocol that did not declare a specific actor system requirement. @@ -355,15 +367,24 @@ swift::getDistributedSerializationRequirements( bool swift::checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements) { + Type type) { + if (!type) + return false; + + if (type->hasError()) + return false; + auto encodable = C.getProtocol(KnownProtocolKind::Encodable); auto decodable = C.getProtocol(KnownProtocolKind::Decodable); - if (allRequirements.size() != 2) + auto layout = type->getExistentialLayout(); + auto protocols = layout.getProtocols(); + + if (protocols.size() != 2) return false; - return allRequirements.count(encodable) && - allRequirements.count(decodable); + return std::count(protocols.begin(), protocols.end(), encodable) == 1 && + std::count(protocols.begin(), protocols.end(), decodable) == 1; } /******************************************************************************/ @@ -1214,34 +1235,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const return true; } -llvm::SmallPtrSet -swift::extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements) { - llvm::SmallPtrSet serializationReqs; - auto DA = C.getDistributedActorDecl(); - auto daSerializationReqAssocType = - DA->getAssociatedType(C.Id_SerializationRequirement); - - for (auto req : allRequirements) { - // FIXME: Seems unprincipled - if (req.getKind() != RequirementKind::SameType && - req.getKind() != RequirementKind::Conformance) - continue; - - if (auto dependentMemberType = - req.getFirstType()->getAs()) { - if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { - auto layout = req.getSecondType()->getExistentialLayout(); - for (auto p : layout.getProtocols()) { - serializationReqs.insert(p); - } - } - } - } - - return serializationReqs; -} - /******************************************************************************/ /********************** Distributed Functions *********************************/ /******************************************************************************/ diff --git a/lib/Sema/CodeSynthesisDistributedActor.cpp b/lib/Sema/CodeSynthesisDistributedActor.cpp index f5cceeacd846c..c466d30c3ac67 100644 --- a/lib/Sema/CodeSynthesisDistributedActor.cpp +++ b/lib/Sema/CodeSynthesisDistributedActor.cpp @@ -842,9 +842,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator, if (!distributedTarget->isDistributed()) return nullptr; } - assert(distributedTarget); + // This evaluation type-check by now was already computed and cached; + // We need to check in order to avoid emitting a THUNK for a distributed func + // which had errors; as the thunk then may also cause un-addressable issues and confusion. + if (swift::checkDistributedFunction(distributedTarget)) { + return nullptr; + } + auto &C = distributedTarget->getASTContext(); if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) { diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index d5d916ab3494b..4db76bd3a9804 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -2071,7 +2071,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) { return (prop && prop->isFinal() && isa(prop->getDeclContext()) && - cast(prop->getDeclContext())->isActor() && + cast(prop->getDeclContext())->isAnyActor() && !prop->isStatic() && prop->getName() == ctx.Id_unownedExecutor && prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl()); diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 7a57d23b631de..8a2310660ae66 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -376,10 +376,18 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements( static bool checkDistributedTargetResultType( ModuleDecl *module, ValueDecl *valueDecl, - const llvm::SmallPtrSetImpl &serializationRequirements, + Type serializationRequirement, + llvm::SmallPtrSet serializationRequirements, bool diagnose) { auto &C = valueDecl->getASTContext(); + if (serializationRequirement && serializationRequirement->hasError()) { + return false; + } + if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) { + return false; // error of the type would be diagnosed elsewhere + } + Type resultType; if (auto func = dyn_cast(valueDecl)) { resultType = func->mapTypeIntoContext(func->getResultInterfaceType()); @@ -392,18 +400,27 @@ static bool checkDistributedTargetResultType( if (resultType->isVoid()) return false; + + // Collect extra "SerializationRequirement: SomeProtocol" requirements + if (serializationRequirement && !serializationRequirement->hasError()) { + auto srl = serializationRequirement->getExistentialLayout(); + for (auto s: srl.getProtocols()) { + serializationRequirements.insert(s); + } + } + auto isCodableRequirement = checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); + C, serializationRequirement); - for(auto serializationReq : serializationRequirements) { + for (auto serializationReq: serializationRequirements) { auto conformance = TypeChecker::conformsToProtocol(resultType, serializationReq, module); if (conformance.isInvalid()) { if (diagnose) { llvm::StringRef conformanceToSuggest = isCodableRequirement ? - "Codable" : // Codable is a typealias, easier to diagnose like that - serializationReq->getNameStr(); + "Codable" : // Codable is a typealias, easier to diagnose like that + serializationReq->getNameStr(); auto diag = valueDecl->diagnose( diag::distributed_actor_target_result_not_codable, @@ -418,12 +435,12 @@ static bool checkDistributedTargetResultType( } } } // end if: diagnose - + return true; } } - return false; + return false; } bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) { @@ -487,66 +504,42 @@ bool CheckDistributedFunctionRequest::evaluate( } auto &C = func->getASTContext(); - auto DC = func->getDeclContext(); auto module = func->getParentModule(); /// If no distributed module is available, then no reason to even try checks. if (!C.getLoadedModule(C.Id_Distributed)) return true; - // === All parameters and the result type must conform - // SerializationRequirement llvm::SmallPtrSet serializationRequirements; - if (auto extension = dyn_cast(DC)) { - serializationRequirements = extractDistributedSerializationRequirements( - C, extension->getGenericRequirements()); - } else if (auto actor = dyn_cast(DC)) { - serializationRequirements = getDistributedSerializationRequirementProtocols( - getDistributedActorSystemType(actor)->getAnyNominal(), - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); - } else if (isa(DC)) { - if (auto seqReqTy = - getConcreteReplacementForMemberSerializationRequirement(func)) { - auto layout = seqReqTy->getExistentialLayout(); - for (auto req : layout.getProtocols()) { - serializationRequirements.insert(req); - } - } - - // The distributed actor constrained protocol has no serialization requirements - // or actor system defined, so these will only be enforced, by implementations - // of DAs conforming to it, skip checks here. - if (serializationRequirements.empty()) { - return false; - } - } else { - llvm_unreachable("Distributed function detected in type other than extension, " - "distributed actor, or protocol! This should not be possible " - ", please file a bug."); - } - - // If the requirement is exactly `Codable` we diagnose it ia bit nicer. - auto serializationRequirementIsCodable = - checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); - - for (auto param : *func->getParameters()) { - // --- Check parameters for 'Codable' conformance - auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); - - for (auto req : serializationRequirements) { - if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { - auto diag = func->diagnose( - diag::distributed_actor_func_param_not_codable, - param->getArgumentName().str(), param->getInterfaceType(), - func->getDescriptiveKind(), - serializationRequirementIsCodable ? "Codable" - : req->getNameStr()); - - if (auto paramNominalTy = paramTy->getAnyNominal()) { - addCodableFixIt(paramNominalTy, diag); - } // else, no nominal type to suggest the fixit for, e.g. a closure - return true; + Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements); + + for (auto param: *func->getParameters()) { + // --- Check the parameter conforming to serialization requirements + if (serializationReqType && !serializationReqType->hasError()) { + // If the requirement is exactly `Codable` we diagnose it ia bit nicer. + auto serializationRequirementIsCodable = + checkDistributedSerializationRequirementIsExactlyCodable( + C, serializationReqType); + + // --- Check parameters for 'SerializationRequirement' conformance + auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); + + auto srl = serializationReqType->getExistentialLayout(); + for (auto req: srl.getProtocols()) { + if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { + auto diag = func->diagnose( + diag::distributed_actor_func_param_not_codable, + param->getArgumentName().str(), param->getInterfaceType(), + func->getDescriptiveKind(), + serializationRequirementIsCodable ? "Codable" + : req->getNameStr()); + + if (auto paramNominalTy = paramTy->getAnyNominal()) { + addCodableFixIt(paramNominalTy, diag); + } // else, no nominal type to suggest the fixit for, e.g. a closure + + return true; + } } } @@ -583,9 +576,10 @@ bool CheckDistributedFunctionRequest::evaluate( } } - // --- Result type must be either void or a codable type - if (checkDistributedTargetResultType(module, func, serializationRequirements, - /*diagnose=*/true)) { + // --- Result type must be either void or a serialization requirement conforming type + if (checkDistributedTargetResultType( + module, func, serializationReqType, serializationRequirements, + /*diagnose=*/true)) { return true; } @@ -639,8 +633,11 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) { systemDecl, C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirement = + getSerializationRequirementTypesForMember(systemVar, serializationRequirements); + auto module = var->getModuleContext(); - if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) { + if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) { return true; } @@ -740,13 +737,14 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal (void)nominal->getDistributedActorIDProperty(); } -void TypeChecker::checkDistributedFunc(FuncDecl *func) { +bool TypeChecker::checkDistributedFunc(FuncDecl *func) { if (!func->isDistributed()) - return; + return false; - swift::checkDistributedFunction(func); + return swift::checkDistributedFunction(func); } +// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks llvm::SmallPtrSet swift::getDistributedSerializationRequirementProtocols( NominalTypeDecl *nominal, ProtocolDecl *protocol) { diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 25bb1b0382cf9..5a4241cfbe60c 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -2775,6 +2775,14 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &eval, // So, build out the body now. ASTScope::expandFunctionBody(AFD); + if (AFD->isDistributedThunk()) { + if (auto func = dyn_cast(AFD)) { + if (TypeChecker::checkDistributedFunc(func)) { + return errorBody(); + } + } + } + // Type check the function body if needed. bool hadError = false; if (!alreadyTypeChecked) { diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 804aee03cacd2..5c2883533632d 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1127,7 +1127,9 @@ diagnosePotentialUnavailability(SourceRange ReferenceRange, void checkDistributedActor(SourceFile *SF, NominalTypeDecl *decl); /// Type check a single 'distributed func' declaration. -void checkDistributedFunc(FuncDecl *func); +/// +/// Returns `true` if there was an error. +bool checkDistributedFunc(FuncDecl *func); bool checkAvailability(SourceRange ReferenceRange, AvailabilityContext Availability, diff --git a/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift new file mode 100644 index 0000000000000..785accbb2e58d --- /dev/null +++ b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift @@ -0,0 +1,41 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/Inputs/FakeDistributedActorSystems.swift +// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/Inputs/FakeDistributedActorSystems.swift 2> %t/output.txt || echo 'failed expectedly' +// RUN: %FileCheck %s < %t/output.txt + +// REQUIRES: concurrency +// REQUIRES: distributed + +// rdar://76038845 +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime + +import Distributed + +// Notes: +// This test specifically is not just a -typecheck -verify test but attempts to generate the whole module. +// This is because we may be emitting errors but otherwise still attempt to emit a thunk for an "error-ed" +// distributed function, which would then crash in later phases of compilation when we try to get types +// of the `func` the THUNK is based on. + +typealias DefaultDistributedActorSystem = LocalTestingDistributedActorSystem + +distributed actor Service { +} + +extension Service { + distributed func boombox(_ id: Box) async throws {} + // CHECK: parameter '' of type 'Box' in distributed instance method does not conform to serialization requirement 'Codable' + + distributed func boxIt() async throws -> Box { fatalError() } + // CHECK: result type 'Box' of distributed instance method 'boxIt' does not conform to serialization requirement 'Codable' +} + +public enum Box: Hashable { case boom } + +@main struct Main { + static func main() async { + try? await Service(actorSystem: .init()).boombox(Box.boom) + } +} + diff --git a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift index af4fa1020ce58..eace8fc19ea16 100644 --- a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift +++ b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift @@ -82,7 +82,7 @@ extension NoSerializationRequirementYet extension NoSerializationRequirementYet where SerializationRequirement: Codable { - // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}} + // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Decodable'}} distributed func test4() -> NotCodable { .init() }