From 30653a80912b3abee2443957945716d6cbf1ab83 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Mon, 30 Oct 2023 19:44:24 +0900 Subject: [PATCH 1/6] [Distributed] Don't crash in thunk generation when missing SR conformance --- include/swift/AST/DistributedDecl.h | 6 ++- lib/AST/DistributedDecl.cpp | 9 ++-- lib/Sema/CodeSynthesisDistributedActor.cpp | 8 +++- lib/Sema/TypeCheckDeclOverride.cpp | 2 +- lib/Sema/TypeCheckDistributed.cpp | 25 ++++++++--- lib/Sema/TypeCheckStmt.cpp | 8 ++++ lib/Sema/TypeChecker.h | 4 +- ...r_func_param_not_conforming_req_full.swift | 41 +++++++++++++++++++ 8 files changed, 88 insertions(+), 15 deletions(-) create mode 100644 test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 8acd37d62ef49..ba916be08c748 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -117,9 +117,11 @@ getDistributedSerializationRequirements( /// Given any set of generic requirements, locate those which are about the /// `SerializationRequirement`. Those need to be applied in the parameter and /// return type checking of distributed targets. -llvm::SmallPtrSet +void extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements); + ASTContext &C, + ArrayRef allRequirements, + llvm::SmallPtrSet &into); } diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 111c6bc43625b..7d9235364714a 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -1214,10 +1214,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const return true; } -llvm::SmallPtrSet +void swift::extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements) { - llvm::SmallPtrSet serializationReqs; + ASTContext &C, + ArrayRef allRequirements, + llvm::SmallPtrSet &into) { auto DA = C.getDistributedActorDecl(); auto daSerializationReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement); @@ -1238,8 +1239,6 @@ swift::extractDistributedSerializationRequirements( } } } - - return serializationReqs; } /******************************************************************************/ diff --git a/lib/Sema/CodeSynthesisDistributedActor.cpp b/lib/Sema/CodeSynthesisDistributedActor.cpp index f5cceeacd846c..c466d30c3ac67 100644 --- a/lib/Sema/CodeSynthesisDistributedActor.cpp +++ b/lib/Sema/CodeSynthesisDistributedActor.cpp @@ -842,9 +842,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator, if (!distributedTarget->isDistributed()) return nullptr; } - assert(distributedTarget); + // This evaluation type-check by now was already computed and cached; + // We need to check in order to avoid emitting a THUNK for a distributed func + // which had errors; as the thunk then may also cause un-addressable issues and confusion. + if (swift::checkDistributedFunction(distributedTarget)) { + return nullptr; + } + auto &C = distributedTarget->getASTContext(); if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) { diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index d5d916ab3494b..4db76bd3a9804 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -2071,7 +2071,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) { return (prop && prop->isFinal() && isa(prop->getDeclContext()) && - cast(prop->getDeclContext())->isActor() && + cast(prop->getDeclContext())->isAnyActor() && !prop->isStatic() && prop->getName() == ctx.Id_unownedExecutor && prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl()); diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 7a57d23b631de..2aee52525bff1 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -498,8 +498,22 @@ bool CheckDistributedFunctionRequest::evaluate( // SerializationRequirement llvm::SmallPtrSet serializationRequirements; if (auto extension = dyn_cast(DC)) { - serializationRequirements = extractDistributedSerializationRequirements( - C, extension->getGenericRequirements()); + auto actorOrProtocol = extension->getExtendedNominal(); + if (auto actor = dyn_cast(actorOrProtocol)) { + assert(actor->isAnyActor()); + serializationRequirements = getDistributedSerializationRequirementProtocols( + getDistributedActorSystemType(actor)->getAnyNominal(), + C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + } else if (auto protocol = dyn_cast(actorOrProtocol)) { + extractDistributedSerializationRequirements( + C, protocol->getGenericRequirements(), + /*into=*/serializationRequirements); + extractDistributedSerializationRequirements( + C, extension->getGenericRequirements(), + /*into=*/serializationRequirements); + } else { + // ignore + } } else if (auto actor = dyn_cast(DC)) { serializationRequirements = getDistributedSerializationRequirementProtocols( getDistributedActorSystemType(actor)->getAnyNominal(), @@ -546,6 +560,7 @@ bool CheckDistributedFunctionRequest::evaluate( if (auto paramNominalTy = paramTy->getAnyNominal()) { addCodableFixIt(paramNominalTy, diag); } // else, no nominal type to suggest the fixit for, e.g. a closure + return true; } } @@ -740,11 +755,11 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal (void)nominal->getDistributedActorIDProperty(); } -void TypeChecker::checkDistributedFunc(FuncDecl *func) { +bool TypeChecker::checkDistributedFunc(FuncDecl *func) { if (!func->isDistributed()) - return; + return false; - swift::checkDistributedFunction(func); + return swift::checkDistributedFunction(func); } llvm::SmallPtrSet diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 25bb1b0382cf9..5a4241cfbe60c 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -2775,6 +2775,14 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &eval, // So, build out the body now. ASTScope::expandFunctionBody(AFD); + if (AFD->isDistributedThunk()) { + if (auto func = dyn_cast(AFD)) { + if (TypeChecker::checkDistributedFunc(func)) { + return errorBody(); + } + } + } + // Type check the function body if needed. bool hadError = false; if (!alreadyTypeChecked) { diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 804aee03cacd2..5c2883533632d 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1127,7 +1127,9 @@ diagnosePotentialUnavailability(SourceRange ReferenceRange, void checkDistributedActor(SourceFile *SF, NominalTypeDecl *decl); /// Type check a single 'distributed func' declaration. -void checkDistributedFunc(FuncDecl *func); +/// +/// Returns `true` if there was an error. +bool checkDistributedFunc(FuncDecl *func); bool checkAvailability(SourceRange ReferenceRange, AvailabilityContext Availability, diff --git a/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift new file mode 100644 index 0000000000000..785accbb2e58d --- /dev/null +++ b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift @@ -0,0 +1,41 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/Inputs/FakeDistributedActorSystems.swift +// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/Inputs/FakeDistributedActorSystems.swift 2> %t/output.txt || echo 'failed expectedly' +// RUN: %FileCheck %s < %t/output.txt + +// REQUIRES: concurrency +// REQUIRES: distributed + +// rdar://76038845 +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime + +import Distributed + +// Notes: +// This test specifically is not just a -typecheck -verify test but attempts to generate the whole module. +// This is because we may be emitting errors but otherwise still attempt to emit a thunk for an "error-ed" +// distributed function, which would then crash in later phases of compilation when we try to get types +// of the `func` the THUNK is based on. + +typealias DefaultDistributedActorSystem = LocalTestingDistributedActorSystem + +distributed actor Service { +} + +extension Service { + distributed func boombox(_ id: Box) async throws {} + // CHECK: parameter '' of type 'Box' in distributed instance method does not conform to serialization requirement 'Codable' + + distributed func boxIt() async throws -> Box { fatalError() } + // CHECK: result type 'Box' of distributed instance method 'boxIt' does not conform to serialization requirement 'Codable' +} + +public enum Box: Hashable { case boom } + +@main struct Main { + static func main() async { + try? await Service(actorSystem: .init()).boombox(Box.boom) + } +} + From cbde18680fb74fa0f3773bbc936dc44dbb7ea770 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 14:53:44 -0400 Subject: [PATCH 2/6] Distributed: Some cleanups --- lib/AST/DistributedDecl.cpp | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 7d9235364714a..08ca83090240c 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -341,11 +341,12 @@ swift::getDistributedSerializationRequirements( return true; // we're done here, any means there are no requirements auto *serialReqType = existentialRequirementTy->getAs(); + ->getAs(); if (!serialReqType || serialReqType->hasError()) { return false; } - auto layout = serialReqType->getExistentialLayout(); + auto desugaredTy = serialReqType->getConstraintType(); for (auto p : layout.getProtocols()) { requirementProtos.insert(p); } @@ -1224,15 +1225,25 @@ swift::extractDistributedSerializationRequirements( DA->getAssociatedType(C.Id_SerializationRequirement); for (auto req : allRequirements) { - // FIXME: Seems unprincipled - if (req.getKind() != RequirementKind::SameType && - req.getKind() != RequirementKind::Conformance) + if (req.getSecondType()->isAny()) { + continue; + } + if (!req.getFirstType()->hasDependentMember()) continue; if (auto dependentMemberType = - req.getFirstType()->getAs()) { + req.getFirstType()->castTo()) { if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { - auto layout = req.getSecondType()->getExistentialLayout(); + auto requirementProto = req.getSecondType(); + if (auto proto = dyn_cast_or_null( + requirementProto->getAnyNominal())) { + into.insert(proto); + } else { + auto serialReqType = requirementProto->castTo() + ->getConstraintType(); + auto flattenedRequirements = + flattenDistributedSerializationTypeToRequiredProtocols( + serialReqType.getPointer()); for (auto p : layout.getProtocols()) { serializationReqs.insert(p); } From 4c22395d94c1e691f3c6765ac7905f164ab76671 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 14:59:52 -0400 Subject: [PATCH 3/6] Distributed: Remove flattenDistributedSerializationTypeToRequiredProtocols() --- lib/AST/DistributedDecl.cpp | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 08ca83090240c..4d8377a5fbf0f 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -341,12 +341,11 @@ swift::getDistributedSerializationRequirements( return true; // we're done here, any means there are no requirements auto *serialReqType = existentialRequirementTy->getAs(); - ->getAs(); if (!serialReqType || serialReqType->hasError()) { return false; } - auto desugaredTy = serialReqType->getConstraintType(); + auto layout = serialReqType->getExistentialLayout(); for (auto p : layout.getProtocols()) { requirementProtos.insert(p); } @@ -1234,16 +1233,20 @@ swift::extractDistributedSerializationRequirements( if (auto dependentMemberType = req.getFirstType()->castTo()) { if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { - auto requirementProto = req.getSecondType(); - if (auto proto = dyn_cast_or_null( - requirementProto->getAnyNominal())) { - into.insert(proto); - } else { - auto serialReqType = requirementProto->castTo() - ->getConstraintType(); - auto flattenedRequirements = - flattenDistributedSerializationTypeToRequiredProtocols( - serialReqType.getPointer()); + // auto requirementProto = req.getSecondType(); + // if (auto proto = dyn_cast_or_null( + // requirementProto->getAnyNominal())) { + // into.insert(proto); + // } else { + // auto serialReqType = requirementProto->castTo() + // ->getConstraintType(); + // auto flattenedRequirements = + // flattenDistributedSerializationTypeToRequiredProtocols( + // serialReqType.getPointer()); + // for (auto p : flattenedRequirements) { + // into.insert(p); + // } + auto layout = req.getSecondType()->getExistentialLayout(); for (auto p : layout.getProtocols()) { serializationReqs.insert(p); } From f91b12bd0ec5a17e60eaeee91588cbc2336bbfa2 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 15:08:09 -0400 Subject: [PATCH 4/6] Distributed: Simplify extractDistributedSerializationRequirements() --- lib/AST/DistributedDecl.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 4d8377a5fbf0f..06ea5d4b0d562 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -1224,14 +1224,13 @@ swift::extractDistributedSerializationRequirements( DA->getAssociatedType(C.Id_SerializationRequirement); for (auto req : allRequirements) { - if (req.getSecondType()->isAny()) { - continue; - } - if (!req.getFirstType()->hasDependentMember()) + // FIXME: Seems unprincipled + if (req.getKind() != RequirementKind::SameType && + req.getKind() != RequirementKind::Conformance) continue; if (auto dependentMemberType = - req.getFirstType()->castTo()) { + req.getFirstType()->getAs()) { if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { // auto requirementProto = req.getSecondType(); // if (auto proto = dyn_cast_or_null( @@ -1248,7 +1247,7 @@ swift::extractDistributedSerializationRequirements( // } auto layout = req.getSecondType()->getExistentialLayout(); for (auto p : layout.getProtocols()) { - serializationReqs.insert(p); + into.insert(p); } } } From 436ecb240be37b1471f8e27eac7efff4dcfc6d5a Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Tue, 31 Oct 2023 14:48:09 +0900 Subject: [PATCH 5/6] Use more of getConcreteReplacementForMemberSerializationRequirement --- include/swift/AST/DistributedDecl.h | 2 +- lib/AST/DistributedDecl.cpp | 26 ++-- lib/Sema/TypeCheckDistributed.cpp | 180 ++++++++++++---------------- 3 files changed, 94 insertions(+), 114 deletions(-) diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index ba916be08c748..76758d3de461e 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -97,7 +97,7 @@ getDistributedSerializationRequirementProtocols( /// If so, we can emit slightly nicer diagnostics. bool checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements); + Type type); /// Get the `SerializationRequirement`, explode it into the specific /// protocol requirements and insert them into `requirements`. diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 06ea5d4b0d562..45cc7eafe1cb6 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -106,8 +106,10 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl()); } - /// === Maybe the value is declared in a protocol? - if (auto protocol = DC->getSelfProtocolDecl()) { + auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) + ->getDeclaredInterfaceType(); + + if (DC->getSelfProtocolDecl() || isa(DC)) { GenericSignature signature; if (auto *genericContext = member->getAsGenericContext()) { signature = genericContext->getGenericSignature(); @@ -115,9 +117,6 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( signature = DC->getGenericSignatureOfContext(); } - auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) - ->getDeclaredInterfaceType(); - // Note that this may be null, e.g. if we're a distributed func inside // a protocol that did not declare a specific actor system requirement. return signature->getConcreteType(SerReqAssocType); @@ -355,15 +354,24 @@ swift::getDistributedSerializationRequirements( bool swift::checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements) { + Type type) { + if (!type) + return false; + + if (type->hasError()) + return false; + auto encodable = C.getProtocol(KnownProtocolKind::Encodable); auto decodable = C.getProtocol(KnownProtocolKind::Decodable); - if (allRequirements.size() != 2) + auto layout = type->getExistentialLayout(); + auto protocols = layout.getProtocols(); + + if (protocols.size() != 2) return false; - return allRequirements.count(encodable) && - allRequirements.count(decodable); + return std::count(protocols.begin(), protocols.end(), encodable) == 1 && + std::count(protocols.begin(), protocols.end(), decodable) == 1; } /******************************************************************************/ diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 2aee52525bff1..5af17897d155b 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -376,10 +376,13 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements( static bool checkDistributedTargetResultType( ModuleDecl *module, ValueDecl *valueDecl, - const llvm::SmallPtrSetImpl &serializationRequirements, + Type serializationRequirement, bool diagnose) { auto &C = valueDecl->getASTContext(); + if (!serializationRequirement || serializationRequirement->hasError()) + return false; // error of the type would be diagnosed elsewhere + Type resultType; if (auto func = dyn_cast(valueDecl)) { resultType = func->mapTypeIntoContext(func->getResultInterfaceType()); @@ -394,36 +397,39 @@ static bool checkDistributedTargetResultType( auto isCodableRequirement = checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); - - for(auto serializationReq : serializationRequirements) { - auto conformance = - TypeChecker::conformsToProtocol(resultType, serializationReq, module); - if (conformance.isInvalid()) { - if (diagnose) { - llvm::StringRef conformanceToSuggest = isCodableRequirement ? - "Codable" : // Codable is a typealias, easier to diagnose like that - serializationReq->getNameStr(); - - auto diag = valueDecl->diagnose( - diag::distributed_actor_target_result_not_codable, - resultType, - valueDecl, - conformanceToSuggest - ); - - if (isCodableRequirement) { - if (auto resultNominalType = resultType->getAnyNominal()) { - addCodableFixIt(resultNominalType, diag); + C, serializationRequirement); + + if (serializationRequirement && !serializationRequirement->hasError()) { + auto srl = serializationRequirement->getExistentialLayout(); + for (auto serializationReq: srl.getProtocols()) { + auto conformance = + TypeChecker::conformsToProtocol(resultType, serializationReq, module); + if (conformance.isInvalid()) { + if (diagnose) { + llvm::StringRef conformanceToSuggest = isCodableRequirement ? + "Codable" : // Codable is a typealias, easier to diagnose like that + serializationReq->getNameStr(); + + auto diag = valueDecl->diagnose( + diag::distributed_actor_target_result_not_codable, + resultType, + valueDecl, + conformanceToSuggest + ); + + if (isCodableRequirement) { + if (auto resultNominalType = resultType->getAnyNominal()) { + addCodableFixIt(resultNominalType, diag); + } } - } - } // end if: diagnose - - return true; + } // end if: diagnose + + return true; + } } } - return false; + return false; } bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) { @@ -494,74 +500,35 @@ bool CheckDistributedFunctionRequest::evaluate( if (!C.getLoadedModule(C.Id_Distributed)) return true; - // === All parameters and the result type must conform - // SerializationRequirement - llvm::SmallPtrSet serializationRequirements; - if (auto extension = dyn_cast(DC)) { - auto actorOrProtocol = extension->getExtendedNominal(); - if (auto actor = dyn_cast(actorOrProtocol)) { - assert(actor->isAnyActor()); - serializationRequirements = getDistributedSerializationRequirementProtocols( - getDistributedActorSystemType(actor)->getAnyNominal(), - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); - } else if (auto protocol = dyn_cast(actorOrProtocol)) { - extractDistributedSerializationRequirements( - C, protocol->getGenericRequirements(), - /*into=*/serializationRequirements); - extractDistributedSerializationRequirements( - C, extension->getGenericRequirements(), - /*into=*/serializationRequirements); - } else { - // ignore - } - } else if (auto actor = dyn_cast(DC)) { - serializationRequirements = getDistributedSerializationRequirementProtocols( - getDistributedActorSystemType(actor)->getAnyNominal(), - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); - } else if (isa(DC)) { - if (auto seqReqTy = - getConcreteReplacementForMemberSerializationRequirement(func)) { - auto layout = seqReqTy->getExistentialLayout(); - for (auto req : layout.getProtocols()) { - serializationRequirements.insert(req); - } - } - - // The distributed actor constrained protocol has no serialization requirements - // or actor system defined, so these will only be enforced, by implementations - // of DAs conforming to it, skip checks here. - if (serializationRequirements.empty()) { - return false; - } - } else { - llvm_unreachable("Distributed function detected in type other than extension, " - "distributed actor, or protocol! This should not be possible " - ", please file a bug."); - } - - // If the requirement is exactly `Codable` we diagnose it ia bit nicer. - auto serializationRequirementIsCodable = - checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); - - for (auto param : *func->getParameters()) { - // --- Check parameters for 'Codable' conformance - auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); - - for (auto req : serializationRequirements) { - if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { - auto diag = func->diagnose( - diag::distributed_actor_func_param_not_codable, - param->getArgumentName().str(), param->getInterfaceType(), - func->getDescriptiveKind(), - serializationRequirementIsCodable ? "Codable" - : req->getNameStr()); - - if (auto paramNominalTy = paramTy->getAnyNominal()) { - addCodableFixIt(paramNominalTy, diag); - } // else, no nominal type to suggest the fixit for, e.g. a closure - - return true; + Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func); + for (auto param: *func->getParameters()) { + + // --- Check the parameter conforming to serialization requirements + if (serializationReqType && !serializationReqType->hasError()) { + // If the requirement is exactly `Codable` we diagnose it ia bit nicer. + auto serializationRequirementIsCodable = + checkDistributedSerializationRequirementIsExactlyCodable( + C, serializationReqType); + + // --- Check parameters for 'SerializationRequirement' conformance + auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); + + auto srl = serializationReqType->getExistentialLayout(); + for (auto req: srl.getProtocols()) { + if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { + auto diag = func->diagnose( + diag::distributed_actor_func_param_not_codable, + param->getArgumentName().str(), param->getInterfaceType(), + func->getDescriptiveKind(), + serializationRequirementIsCodable ? "Codable" + : req->getNameStr()); + + if (auto paramNominalTy = paramTy->getAnyNominal()) { + addCodableFixIt(paramNominalTy, diag); + } // else, no nominal type to suggest the fixit for, e.g. a closure + + return true; + } } } @@ -598,10 +565,12 @@ bool CheckDistributedFunctionRequest::evaluate( } } - // --- Result type must be either void or a codable type - if (checkDistributedTargetResultType(module, func, serializationRequirements, - /*diagnose=*/true)) { - return true; + if (serializationReqType && !serializationReqType->hasError()) { + // --- Result type must be either void or a codable type + if (checkDistributedTargetResultType(module, func, serializationReqType, + /*diagnose=*/true)) { + return true; + } } return false; @@ -649,13 +618,15 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) { DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty(); auto systemDecl = systemVar->getInterfaceType()->getAnyNominal(); - auto serializationRequirements = - getDistributedSerializationRequirementProtocols( - systemDecl, - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); +// auto serializationRequirements = +// getDistributedSerializationRequirementProtocols( +// systemDecl, +// C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirement = + getConcreteReplacementForMemberSerializationRequirement(systemVar); auto module = var->getModuleContext(); - if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) { + if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) { return true; } @@ -762,6 +733,7 @@ bool TypeChecker::checkDistributedFunc(FuncDecl *func) { return swift::checkDistributedFunction(func); } +// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks llvm::SmallPtrSet swift::getDistributedSerializationRequirementProtocols( NominalTypeDecl *nominal, ProtocolDecl *protocol) { From 0f5e564bbff8edc31fdafaa53db49c6d513e21c0 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Tue, 14 Nov 2023 19:10:36 +0900 Subject: [PATCH 6/6] handle conformance requirement on extension in distributed funcs --- include/swift/AST/DistributedDecl.h | 14 +-- lib/AST/DistributedDecl.cpp | 57 +++--------- lib/Sema/TypeCheckDistributed.cpp | 93 +++++++++++-------- ...uted_func_serialization_requirements.swift | 2 +- 4 files changed, 70 insertions(+), 96 deletions(-) diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 76758d3de461e..abd0d72a2c5a5 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor); /// Similar to `getDistributedSerializationRequirementType`, however, from the /// perspective of a concrete function. This way we're able to get the /// serialization requirement for specific members, also in protocols. -Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member); +Type getSerializationRequirementTypesForMember( + ValueDecl *member, llvm::SmallPtrSet &serializationRequirements); /// Get specific 'SerializationRequirement' as defined in 'nominal' /// type, which must conform to the passed 'protocol' which is expected @@ -114,17 +115,6 @@ getDistributedSerializationRequirements( ProtocolDecl *protocol, llvm::SmallPtrSetImpl &requirementProtos); -/// Given any set of generic requirements, locate those which are about the -/// `SerializationRequirement`. Those need to be applied in the parameter and -/// return type checking of distributed targets. -void -extractDistributedSerializationRequirements( - ASTContext &C, - ArrayRef allRequirements, - llvm::SmallPtrSet &into); - } -// ==== ------------------------------------------------------------------------ - #endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */ diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 45cc7eafe1cb6..274d41701d878 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member) llvm_unreachable("Unable to fetch ActorSystem type!"); } -Type swift::getConcreteReplacementForMemberSerializationRequirement( - ValueDecl *member) { +Type swift::getSerializationRequirementTypesForMember( + ValueDecl *member, + llvm::SmallPtrSet &serializationRequirements) { auto &C = member->getASTContext(); auto *DC = member->getDeclContext(); auto DA = C.getDistributedActorDecl(); @@ -117,6 +118,18 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( signature = DC->getGenericSignatureOfContext(); } + // Also store all `SerializationRequirement : SomeProtocol` requirements + for (auto requirement: signature.getRequirements()) { + if (requirement.getFirstType()->isEqual(SerReqAssocType) && + requirement.getKind() == RequirementKind::Conformance) { + if (auto nominal = requirement.getSecondType()->getAnyNominal()) { + if (auto protocol = dyn_cast(nominal)) { + serializationRequirements.insert(protocol); + } + } + } + } + // Note that this may be null, e.g. if we're a distributed func inside // a protocol that did not declare a specific actor system requirement. return signature->getConcreteType(SerReqAssocType); @@ -1222,46 +1235,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const return true; } -void -swift::extractDistributedSerializationRequirements( - ASTContext &C, - ArrayRef allRequirements, - llvm::SmallPtrSet &into) { - auto DA = C.getDistributedActorDecl(); - auto daSerializationReqAssocType = - DA->getAssociatedType(C.Id_SerializationRequirement); - - for (auto req : allRequirements) { - // FIXME: Seems unprincipled - if (req.getKind() != RequirementKind::SameType && - req.getKind() != RequirementKind::Conformance) - continue; - - if (auto dependentMemberType = - req.getFirstType()->getAs()) { - if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { - // auto requirementProto = req.getSecondType(); - // if (auto proto = dyn_cast_or_null( - // requirementProto->getAnyNominal())) { - // into.insert(proto); - // } else { - // auto serialReqType = requirementProto->castTo() - // ->getConstraintType(); - // auto flattenedRequirements = - // flattenDistributedSerializationTypeToRequiredProtocols( - // serialReqType.getPointer()); - // for (auto p : flattenedRequirements) { - // into.insert(p); - // } - auto layout = req.getSecondType()->getExistentialLayout(); - for (auto p : layout.getProtocols()) { - into.insert(p); - } - } - } - } -} - /******************************************************************************/ /********************** Distributed Functions *********************************/ /******************************************************************************/ diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 5af17897d155b..8a2310660ae66 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -377,11 +377,16 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements( static bool checkDistributedTargetResultType( ModuleDecl *module, ValueDecl *valueDecl, Type serializationRequirement, + llvm::SmallPtrSet serializationRequirements, bool diagnose) { auto &C = valueDecl->getASTContext(); - if (!serializationRequirement || serializationRequirement->hasError()) + if (serializationRequirement && serializationRequirement->hasError()) { + return false; + } + if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) { return false; // error of the type would be diagnosed elsewhere + } Type resultType; if (auto func = dyn_cast(valueDecl)) { @@ -395,37 +400,43 @@ static bool checkDistributedTargetResultType( if (resultType->isVoid()) return false; + + // Collect extra "SerializationRequirement: SomeProtocol" requirements + if (serializationRequirement && !serializationRequirement->hasError()) { + auto srl = serializationRequirement->getExistentialLayout(); + for (auto s: srl.getProtocols()) { + serializationRequirements.insert(s); + } + } + auto isCodableRequirement = checkDistributedSerializationRequirementIsExactlyCodable( C, serializationRequirement); - if (serializationRequirement && !serializationRequirement->hasError()) { - auto srl = serializationRequirement->getExistentialLayout(); - for (auto serializationReq: srl.getProtocols()) { - auto conformance = - TypeChecker::conformsToProtocol(resultType, serializationReq, module); - if (conformance.isInvalid()) { - if (diagnose) { - llvm::StringRef conformanceToSuggest = isCodableRequirement ? - "Codable" : // Codable is a typealias, easier to diagnose like that - serializationReq->getNameStr(); - - auto diag = valueDecl->diagnose( - diag::distributed_actor_target_result_not_codable, - resultType, - valueDecl, - conformanceToSuggest - ); - - if (isCodableRequirement) { - if (auto resultNominalType = resultType->getAnyNominal()) { - addCodableFixIt(resultNominalType, diag); - } + for (auto serializationReq: serializationRequirements) { + auto conformance = + TypeChecker::conformsToProtocol(resultType, serializationReq, module); + if (conformance.isInvalid()) { + if (diagnose) { + llvm::StringRef conformanceToSuggest = isCodableRequirement ? + "Codable" : // Codable is a typealias, easier to diagnose like that + serializationReq->getNameStr(); + + auto diag = valueDecl->diagnose( + diag::distributed_actor_target_result_not_codable, + resultType, + valueDecl, + conformanceToSuggest + ); + + if (isCodableRequirement) { + if (auto resultNominalType = resultType->getAnyNominal()) { + addCodableFixIt(resultNominalType, diag); } - } // end if: diagnose + } + } // end if: diagnose - return true; - } + return true; } } @@ -493,16 +504,16 @@ bool CheckDistributedFunctionRequest::evaluate( } auto &C = func->getASTContext(); - auto DC = func->getDeclContext(); auto module = func->getParentModule(); /// If no distributed module is available, then no reason to even try checks. if (!C.getLoadedModule(C.Id_Distributed)) return true; - Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func); - for (auto param: *func->getParameters()) { + llvm::SmallPtrSet serializationRequirements; + Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements); + for (auto param: *func->getParameters()) { // --- Check the parameter conforming to serialization requirements if (serializationReqType && !serializationReqType->hasError()) { // If the requirement is exactly `Codable` we diagnose it ia bit nicer. @@ -565,12 +576,11 @@ bool CheckDistributedFunctionRequest::evaluate( } } - if (serializationReqType && !serializationReqType->hasError()) { - // --- Result type must be either void or a codable type - if (checkDistributedTargetResultType(module, func, serializationReqType, - /*diagnose=*/true)) { - return true; - } + // --- Result type must be either void or a serialization requirement conforming type + if (checkDistributedTargetResultType( + module, func, serializationReqType, serializationRequirements, + /*diagnose=*/true)) { + return true; } return false; @@ -618,15 +628,16 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) { DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty(); auto systemDecl = systemVar->getInterfaceType()->getAnyNominal(); -// auto serializationRequirements = -// getDistributedSerializationRequirementProtocols( -// systemDecl, -// C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirements = + getDistributedSerializationRequirementProtocols( + systemDecl, + C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirement = - getConcreteReplacementForMemberSerializationRequirement(systemVar); + getSerializationRequirementTypesForMember(systemVar, serializationRequirements); auto module = var->getModuleContext(); - if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) { + if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) { return true; } diff --git a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift index af4fa1020ce58..eace8fc19ea16 100644 --- a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift +++ b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift @@ -82,7 +82,7 @@ extension NoSerializationRequirementYet extension NoSerializationRequirementYet where SerializationRequirement: Codable { - // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}} + // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Decodable'}} distributed func test4() -> NotCodable { .init() }