From 357108c39a8212e7790e3190d4534ed4af0d3150 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Mon, 30 Oct 2023 19:44:24 +0900 Subject: [PATCH 01/11] [Distributed] Don't crash in thunk generation when missing SR conformance --- include/swift/AST/DistributedDecl.h | 6 ++- lib/AST/DistributedDecl.cpp | 9 ++-- lib/Sema/CodeSynthesisDistributedActor.cpp | 8 +++- lib/Sema/TypeCheckDeclOverride.cpp | 2 +- lib/Sema/TypeCheckDistributed.cpp | 25 ++++++++--- lib/Sema/TypeCheckStmt.cpp | 8 ++++ lib/Sema/TypeChecker.h | 4 +- ...r_func_param_not_conforming_req_full.swift | 41 +++++++++++++++++++ 8 files changed, 88 insertions(+), 15 deletions(-) create mode 100644 test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 2cef56e9e9ed3..7e7a09b0b3c77 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -123,9 +123,11 @@ getDistributedSerializationRequirements( /// Given any set of generic requirements, locate those which are about the /// `SerializationRequirement`. Those need to be applied in the parameter and /// return type checking of distributed targets. -llvm::SmallPtrSet +void extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements); + ASTContext &C, + ArrayRef allRequirements, + llvm::SmallPtrSet &into); } diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 2a896409226b4..716cc9072b1aa 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -1268,10 +1268,11 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const return true; } -llvm::SmallPtrSet +void swift::extractDistributedSerializationRequirements( - ASTContext &C, ArrayRef allRequirements) { - llvm::SmallPtrSet serializationReqs; + ASTContext &C, + ArrayRef allRequirements, + llvm::SmallPtrSet &into) { auto DA = C.getDistributedActorDecl(); auto daSerializationReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement); @@ -1308,8 +1309,6 @@ swift::extractDistributedSerializationRequirements( } } } - - return serializationReqs; } /******************************************************************************/ diff --git a/lib/Sema/CodeSynthesisDistributedActor.cpp b/lib/Sema/CodeSynthesisDistributedActor.cpp index 66b1ead965f2c..16ff8a8b0af23 100644 --- a/lib/Sema/CodeSynthesisDistributedActor.cpp +++ b/lib/Sema/CodeSynthesisDistributedActor.cpp @@ -842,9 +842,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator, if (!distributedTarget->isDistributed()) return nullptr; } - assert(distributedTarget); + // This evaluation type-check by now was already computed and cached; + // We need to check in order to avoid emitting a THUNK for a distributed func + // which had errors; as the thunk then may also cause un-addressable issues and confusion. + if (swift::checkDistributedFunction(distributedTarget)) { + return nullptr; + } + auto &C = distributedTarget->getASTContext(); if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) { diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 9aa62497aecb4..a6790c9253c3a 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -2067,7 +2067,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) { return (prop && prop->isFinal() && isa(prop->getDeclContext()) && - cast(prop->getDeclContext())->isActor() && + cast(prop->getDeclContext())->isAnyActor() && !prop->isStatic() && prop->getName() == ctx.Id_unownedExecutor && prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl()); diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 9623da08fd84b..a9909b8c43d1b 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -507,8 +507,22 @@ bool CheckDistributedFunctionRequest::evaluate( // SerializationRequirement llvm::SmallPtrSet serializationRequirements; if (auto extension = dyn_cast(DC)) { - serializationRequirements = extractDistributedSerializationRequirements( - C, extension->getGenericRequirements()); + auto actorOrProtocol = extension->getExtendedNominal(); + if (auto actor = dyn_cast(actorOrProtocol)) { + assert(actor->isAnyActor()); + serializationRequirements = getDistributedSerializationRequirementProtocols( + getDistributedActorSystemType(actor)->getAnyNominal(), + C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + } else if (auto protocol = dyn_cast(actorOrProtocol)) { + extractDistributedSerializationRequirements( + C, protocol->getGenericRequirements(), + /*into=*/serializationRequirements); + extractDistributedSerializationRequirements( + C, extension->getGenericRequirements(), + /*into=*/serializationRequirements); + } else { + // ignore + } } else if (auto actor = dyn_cast(DC)) { serializationRequirements = getDistributedSerializationRequirementProtocols( getDistributedActorSystemType(actor)->getAnyNominal(), @@ -555,6 +569,7 @@ bool CheckDistributedFunctionRequest::evaluate( if (auto paramNominalTy = paramTy->getAnyNominal()) { addCodableFixIt(paramNominalTy, diag); } // else, no nominal type to suggest the fixit for, e.g. a closure + return true; } } @@ -749,11 +764,11 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal (void)nominal->getDistributedActorIDProperty(); } -void TypeChecker::checkDistributedFunc(FuncDecl *func) { +bool TypeChecker::checkDistributedFunc(FuncDecl *func) { if (!func->isDistributed()) - return; + return false; - swift::checkDistributedFunction(func); + return swift::checkDistributedFunction(func); } llvm::SmallPtrSet diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index abde198ae90d8..54d46182b6605 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -2778,6 +2778,14 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &eval, // So, build out the body now. ASTScope::expandFunctionBody(AFD); + if (AFD->isDistributedThunk()) { + if (auto func = dyn_cast(AFD)) { + if (TypeChecker::checkDistributedFunc(func)) { + return errorBody(); + } + } + } + // Type check the function body if needed. bool hadError = false; if (!alreadyTypeChecked) { diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index fa77dae7e5967..5bff84aa1ad75 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1132,7 +1132,9 @@ diagnosePotentialUnavailability(SourceRange ReferenceRange, void checkDistributedActor(SourceFile *SF, NominalTypeDecl *decl); /// Type check a single 'distributed func' declaration. -void checkDistributedFunc(FuncDecl *func); +/// +/// Returns `true` if there was an error. +bool checkDistributedFunc(FuncDecl *func); bool checkAvailability(SourceRange ReferenceRange, AvailabilityContext Availability, diff --git a/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift new file mode 100644 index 0000000000000..785accbb2e58d --- /dev/null +++ b/test/Distributed/distributed_actor_func_param_not_conforming_req_full.swift @@ -0,0 +1,41 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-frontend-emit-module -emit-module-path %t/FakeDistributedActorSystems.swiftmodule -module-name FakeDistributedActorSystems -disable-availability-checking %S/Inputs/FakeDistributedActorSystems.swift +// RUN: %target-build-swift -module-name main -Xfrontend -disable-availability-checking -j2 -parse-as-library -I %t %s %S/Inputs/FakeDistributedActorSystems.swift 2> %t/output.txt || echo 'failed expectedly' +// RUN: %FileCheck %s < %t/output.txt + +// REQUIRES: concurrency +// REQUIRES: distributed + +// rdar://76038845 +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime + +import Distributed + +// Notes: +// This test specifically is not just a -typecheck -verify test but attempts to generate the whole module. +// This is because we may be emitting errors but otherwise still attempt to emit a thunk for an "error-ed" +// distributed function, which would then crash in later phases of compilation when we try to get types +// of the `func` the THUNK is based on. + +typealias DefaultDistributedActorSystem = LocalTestingDistributedActorSystem + +distributed actor Service { +} + +extension Service { + distributed func boombox(_ id: Box) async throws {} + // CHECK: parameter '' of type 'Box' in distributed instance method does not conform to serialization requirement 'Codable' + + distributed func boxIt() async throws -> Box { fatalError() } + // CHECK: result type 'Box' of distributed instance method 'boxIt' does not conform to serialization requirement 'Codable' +} + +public enum Box: Hashable { case boom } + +@main struct Main { + static func main() async { + try? await Service(actorSystem: .init()).boombox(Box.boom) + } +} + From 64c3d97fa44ad563133817bca0faeebe7a97d03f Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 22:56:04 -0400 Subject: [PATCH 02/11] Distributed: Remove unnecessary unwrapping of TypeAliasType --- lib/AST/DistributedDecl.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 716cc9072b1aa..ebcf4813ce6c8 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -178,13 +178,7 @@ Type swift::getDistributedActorSystemResultHandlerType( auto module = system->getParentModule(); Type selfType = system->getSelfInterfaceType(); auto conformance = module->lookupConformance(selfType, DAS); - auto witness = - conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler); - if (auto alias = dyn_cast(witness.getPointer())) { - return alias->getDecl()->getUnderlyingType(); - } else { - return witness; - } + return conformance.getTypeWitnessByName(selfType, ctx.Id_ResultHandler); } Type swift::getDistributedActorSystemInvocationEncoderType(NominalTypeDecl *system) { From 7991412b89ea209c2214e205d8a960a21c236ac1 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 14:53:44 -0400 Subject: [PATCH 03/11] Distributed: Some cleanups --- lib/AST/DistributedDecl.cpp | 65 ++++++++++--------------------- lib/Sema/TypeCheckDistributed.cpp | 3 +- 2 files changed, 21 insertions(+), 47 deletions(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index ebcf4813ce6c8..531aae5db0c86 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -340,21 +340,16 @@ swift::getDistributedSerializationRequirements( if (existentialRequirementTy->isAny()) return true; // we're done here, any means there are no requirements - if (!existentialRequirementTy->isExistentialType()) { - // SerializationRequirement must be an existential type - return false; - } - ExistentialType *serialReqType = existentialRequirementTy - ->castTo(); + ->getAs(); if (!serialReqType || serialReqType->hasError()) { return false; } - auto desugaredTy = serialReqType->getConstraintType()->getDesugaredType(); + auto desugaredTy = serialReqType->getConstraintType(); auto flattenedRequirements = flattenDistributedSerializationTypeToRequiredProtocols( - desugaredTy); + desugaredTy.getPointer()); for (auto p : flattenedRequirements) { requirementProtos.insert(p); } @@ -565,25 +560,19 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) // --- Check requirement: conforms_to: Act DistributedActor auto actorReq = requirements[0]; - auto distActorTy = C.getProtocol(KnownProtocolKind::DistributedActor) - ->getInterfaceType() - ->getMetatypeInstanceType(); if (actorReq.getKind() != RequirementKind::Conformance) { return false; } - if (!actorReq.getSecondType()->isEqual(distActorTy)) { + if (!actorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::DistributedActor)) { return false; } // --- Check requirement: conforms_to: Err Error auto errorReq = requirements[1]; - auto errorTy = C.getProtocol(KnownProtocolKind::Error) - ->getInterfaceType() - ->getMetatypeInstanceType(); if (errorReq.getKind() != RequirementKind::Conformance) { return false; } - if (!errorReq.getSecondType()->isEqual(errorTy)) { + if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) { return false; } @@ -598,10 +587,9 @@ bool AbstractFunctionDecl::isDistributedActorSystemRemoteCall(bool isVoidReturn) assert(ResParam && "Non void function, yet no Res generic parameter found"); if (auto func = dyn_cast(this)) { auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + ->getMetatypeInstanceType(); auto resultParamType = func->mapTypeIntoContext( - ResParam->getInterfaceType()->getMetatypeInstanceType()); + ResParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!resultType->isEqual(resultParamType)) { return false; @@ -797,12 +785,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordArgument() const // the of the RemoteCallArgument auto remoteCallArgValueGenericTy = - mapTypeIntoContext(argGenericParams[0]->getInterfaceType()) - ->getDesugaredType() - ->getMetatypeInstanceType(); + mapTypeIntoContext(argGenericParams[0]->getDeclaredInterfaceType()); // expected (the from the recordArgument) auto expectedGenericParamTy = mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); if (!remoteCallArgValueGenericTy->isEqual(expectedGenericParamTy)) { return false; @@ -932,11 +918,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordReturnType() con // ... auto resultType = func->mapTypeIntoContext(argumentParam->getInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + ->getMetatypeInstanceType(); auto resultParamType = func->mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!resultType->isEqual(resultParamType)) { @@ -1046,13 +1031,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationEncoderRecordErrorType() cons // --- Check requirement: conforms_to: Err Error auto errorReq = requirements[0]; - auto errorTy = C.getProtocol(KnownProtocolKind::Error) - ->getInterfaceType() - ->getMetatypeInstanceType(); if (errorReq.getKind() != RequirementKind::Conformance) { return false; } - if (!errorReq.getSecondType()->isEqual(errorTy)) { + if (!errorReq.getProtocolDecl()->isSpecificProtocol(KnownProtocolKind::Error)) { return false; } @@ -1139,10 +1121,9 @@ AbstractFunctionDecl::isDistributedTargetInvocationDecoderDecodeNextArgument() c // --- Check: Argument: SerializationRequirement GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0]; auto resultType = func->mapTypeIntoContext(func->getResultInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + ->getMetatypeInstanceType(); auto resultParamType = func->mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!resultType->isEqual(resultParamType)) { return false; @@ -1237,11 +1218,10 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const // === Check generic parameters in detail // --- Check: Argument: SerializationRequirement GenericTypeParamDecl *ArgumentParam = genericParams->getParams()[0]; - auto argumentType = func->mapTypeIntoContext(valueParam->getInterfaceType()) - ->getMetatypeInstanceType() - ->getDesugaredType(); + auto argumentType = func->mapTypeIntoContext( + valueParam->getInterfaceType()->getMetatypeInstanceType()); auto resultParamType = func->mapTypeIntoContext( - ArgumentParam->getInterfaceType()->getMetatypeInstanceType()); + ArgumentParam->getDeclaredInterfaceType()); // The result of the function must be the `Res` generic argument. if (!argumentType->isEqual(resultParamType)) { return false; @@ -1270,7 +1250,6 @@ swift::extractDistributedSerializationRequirements( auto DA = C.getDistributedActorDecl(); auto daSerializationReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement); - auto daSystemSerializationReqTy = daSerializationReqAssocType->getInterfaceType(); for (auto req : allRequirements) { if (req.getSecondType()->isAny()) { @@ -1281,21 +1260,17 @@ swift::extractDistributedSerializationRequirements( if (auto dependentMemberType = req.getFirstType()->castTo()) { - auto dependentTy = - dependentMemberType->getAssocType()->getInterfaceType(); - - if (dependentTy->isEqual(daSystemSerializationReqTy)) { + if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { auto requirementProto = req.getSecondType(); if (auto proto = dyn_cast_or_null( requirementProto->getAnyNominal())) { serializationReqs.insert(proto); } else { auto serialReqType = requirementProto->castTo() - ->getConstraintType() - ->getDesugaredType(); + ->getConstraintType(); auto flattenedRequirements = flattenDistributedSerializationTypeToRequiredProtocols( - serialReqType); + serialReqType.getPointer()); for (auto p : flattenedRequirements) { serializationReqs.insert(p); } diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index a9909b8c43d1b..4b68db31fed91 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -911,8 +911,7 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator, continue; auto paramTy = genericParamList->getParams()[0] - ->getInterfaceType() - ->getMetatypeInstanceType(); + ->getDeclaredInterfaceType(); // `decodeNextArgument` should return its generic parameter value if (!FD->getResultInterfaceType()->isEqual(paramTy)) From c71564d97507e29c458b3bdadedd9e9f762a3714 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 14:59:52 -0400 Subject: [PATCH 04/11] Distributed: Remove flattenDistributedSerializationTypeToRequiredProtocols() --- include/swift/AST/DistributedDecl.h | 6 ---- lib/AST/DistributedDecl.cpp | 50 ++++------------------------- lib/Sema/TypeCheckDistributed.cpp | 14 ++++---- 3 files changed, 14 insertions(+), 56 deletions(-) diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 7e7a09b0b3c77..ba916be08c748 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -91,12 +91,6 @@ llvm::SmallPtrSet getDistributedSerializationRequirementProtocols( NominalTypeDecl *decl, ProtocolDecl* protocol); -/// Desugar and flatten the `SerializationRequirement` type into a set of -/// specific protocol declarations. -llvm::SmallPtrSet -flattenDistributedSerializationTypeToRequiredProtocols( - TypeBase *serializationRequirement); - /// Check if the `allRequirements` represent *exactly* the /// `Encodable & Decodable` (also known as `Codable`) requirement. /// diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 531aae5db0c86..0f6c24082f081 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -340,47 +340,19 @@ swift::getDistributedSerializationRequirements( if (existentialRequirementTy->isAny()) return true; // we're done here, any means there are no requirements - ExistentialType *serialReqType = existentialRequirementTy - ->getAs(); + auto *serialReqType = existentialRequirementTy->getAs(); if (!serialReqType || serialReqType->hasError()) { return false; } - auto desugaredTy = serialReqType->getConstraintType(); - auto flattenedRequirements = - flattenDistributedSerializationTypeToRequiredProtocols( - desugaredTy.getPointer()); - for (auto p : flattenedRequirements) { + auto layout = serialReqType->getExistentialLayout(); + for (auto p : layout.getProtocols()) { requirementProtos.insert(p); } return true; } -llvm::SmallPtrSet -swift::flattenDistributedSerializationTypeToRequiredProtocols( - TypeBase *serializationRequirement) { - llvm::SmallPtrSet serializationReqs; - if (auto composition = - serializationRequirement->getAs()) { - for (auto member : composition->getMembers()) { - if (auto comp = member->getAs()) { - for (auto protocol : - flattenDistributedSerializationTypeToRequiredProtocols(comp)) { - serializationReqs.insert(protocol); - } - } else if (auto *protocol = member->getAs()) { - serializationReqs.insert(protocol->getDecl()); - } - } - } else { - auto protocol = serializationRequirement->castTo()->getDecl(); - serializationReqs.insert(protocol); - } - - return serializationReqs; -} - bool swift::checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, const llvm::SmallPtrSetImpl &allRequirements) { @@ -1261,19 +1233,9 @@ swift::extractDistributedSerializationRequirements( if (auto dependentMemberType = req.getFirstType()->castTo()) { if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { - auto requirementProto = req.getSecondType(); - if (auto proto = dyn_cast_or_null( - requirementProto->getAnyNominal())) { - serializationReqs.insert(proto); - } else { - auto serialReqType = requirementProto->castTo() - ->getConstraintType(); - auto flattenedRequirements = - flattenDistributedSerializationTypeToRequiredProtocols( - serialReqType.getPointer()); - for (auto p : flattenedRequirements) { - serializationReqs.insert(p); - } + auto layout = req.getSecondType()->getExistentialLayout(); + for (auto p : layout.getProtocols()) { + serializationReqs.insert(p); } } } diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 4b68db31fed91..978ffa63018e6 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -530,8 +530,8 @@ bool CheckDistributedFunctionRequest::evaluate( } else if (isa(DC)) { if (auto seqReqTy = getConcreteReplacementForMemberSerializationRequirement(func)) { - auto seqReqTyDes = seqReqTy->castTo()->getConstraintType()->getDesugaredType(); - for (auto req : flattenDistributedSerializationTypeToRequiredProtocols(seqReqTyDes)) { + auto layout = seqReqTy->getExistentialLayout(); + for (auto req : layout.getProtocols()) { serializationRequirements.insert(req); } } @@ -783,11 +783,13 @@ swift::getDistributedSerializationRequirementProtocols( return {}; } - auto serialReqType = - ty->castTo()->getConstraintType()->getDesugaredType(); - // TODO(distributed): check what happens with Any - return flattenDistributedSerializationTypeToRequiredProtocols(serialReqType); + auto layout = ty->getExistentialLayout(); + llvm::SmallPtrSet result; + for (auto p : layout.getProtocols()) { + result.insert(p); + } + return result; } ConstructorDecl* From fde0cde69e84c39cb3ea1e037ad57302ff61cc47 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 15:06:15 -0400 Subject: [PATCH 05/11] Distributed: Remove walk over requirements --- lib/Sema/TypeCheckDistributed.cpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 978ffa63018e6..270f06e0a38a2 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -921,20 +921,16 @@ GetDistributedActorArgumentDecodingMethodRequest::evaluate(Evaluator &evaluator, // Let's find out how many serialization requirements does this method cover // e.g. `Codable` is two requirements - `Encodable` and `Decodable`. - unsigned numSerializationReqsCovered = llvm::count_if( - FD->getGenericRequirements(), [&](const Requirement &requirement) { - if (!(requirement.getFirstType()->isEqual(paramTy) && - requirement.getKind() == RequirementKind::Conformance)) - return 0; - - return serializationReqs.count(requirement.getProtocolDecl()) ? 1 : 0; - }); + bool okay = llvm::all_of(serializationReqs, + [&](ProtocolDecl *p) -> bool { + return FD->getGenericSignature()->requiresProtocol(paramTy, p); + }); // If the current method covers all of the serialization requirements, // it's a match. Note that it might also have other requirements, but // we let that go as long as there are no two candidates that differ // only in generic requirements. - if (numSerializationReqsCovered == serializationReqs.size()) + if (okay) candidates.push_back(FD); } From 14b5d5bd7e2f26b64becee9e4c3dbba5bb37e55e Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Mon, 30 Oct 2023 15:08:09 -0400 Subject: [PATCH 06/11] Distributed: Simplify extractDistributedSerializationRequirements() --- lib/AST/DistributedDecl.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 0f6c24082f081..7d9235364714a 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -1224,14 +1224,13 @@ swift::extractDistributedSerializationRequirements( DA->getAssociatedType(C.Id_SerializationRequirement); for (auto req : allRequirements) { - if (req.getSecondType()->isAny()) { - continue; - } - if (!req.getFirstType()->hasDependentMember()) + // FIXME: Seems unprincipled + if (req.getKind() != RequirementKind::SameType && + req.getKind() != RequirementKind::Conformance) continue; if (auto dependentMemberType = - req.getFirstType()->castTo()) { + req.getFirstType()->getAs()) { if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { auto layout = req.getSecondType()->getExistentialLayout(); for (auto p : layout.getProtocols()) { From 9d569656be72f749976aa1d85ea1ec2149753d11 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Tue, 31 Oct 2023 14:48:09 +0900 Subject: [PATCH 07/11] Use more of getConcreteReplacementForMemberSerializationRequirement --- include/swift/AST/DistributedDecl.h | 2 +- lib/AST/DistributedDecl.cpp | 26 ++-- lib/Sema/TypeCheckDistributed.cpp | 180 ++++++++++++---------------- 3 files changed, 94 insertions(+), 114 deletions(-) diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index ba916be08c748..76758d3de461e 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -97,7 +97,7 @@ getDistributedSerializationRequirementProtocols( /// If so, we can emit slightly nicer diagnostics. bool checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements); + Type type); /// Get the `SerializationRequirement`, explode it into the specific /// protocol requirements and insert them into `requirements`. diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 7d9235364714a..1aae7b4bb33c0 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -106,8 +106,10 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl()); } - /// === Maybe the value is declared in a protocol? - if (auto protocol = DC->getSelfProtocolDecl()) { + auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) + ->getDeclaredInterfaceType(); + + if (DC->getSelfProtocolDecl() || isa(DC)) { GenericSignature signature; if (auto *genericContext = member->getAsGenericContext()) { signature = genericContext->getGenericSignature(); @@ -115,9 +117,6 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( signature = DC->getGenericSignatureOfContext(); } - auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) - ->getDeclaredInterfaceType(); - // Note that this may be null, e.g. if we're a distributed func inside // a protocol that did not declare a specific actor system requirement. return signature->getConcreteType(SerReqAssocType); @@ -355,15 +354,24 @@ swift::getDistributedSerializationRequirements( bool swift::checkDistributedSerializationRequirementIsExactlyCodable( ASTContext &C, - const llvm::SmallPtrSetImpl &allRequirements) { + Type type) { + if (!type) + return false; + + if (type->hasError()) + return false; + auto encodable = C.getProtocol(KnownProtocolKind::Encodable); auto decodable = C.getProtocol(KnownProtocolKind::Decodable); - if (allRequirements.size() != 2) + auto layout = type->getExistentialLayout(); + auto protocols = layout.getProtocols(); + + if (protocols.size() != 2) return false; - return allRequirements.count(encodable) && - allRequirements.count(decodable); + return std::count(protocols.begin(), protocols.end(), encodable) == 1 && + std::count(protocols.begin(), protocols.end(), decodable) == 1; } /******************************************************************************/ diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index 270f06e0a38a2..df17a42fa2275 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -385,10 +385,13 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements( static bool checkDistributedTargetResultType( ModuleDecl *module, ValueDecl *valueDecl, - const llvm::SmallPtrSetImpl &serializationRequirements, + Type serializationRequirement, bool diagnose) { auto &C = valueDecl->getASTContext(); + if (!serializationRequirement || serializationRequirement->hasError()) + return false; // error of the type would be diagnosed elsewhere + Type resultType; if (auto func = dyn_cast(valueDecl)) { resultType = func->mapTypeIntoContext(func->getResultInterfaceType()); @@ -403,36 +406,39 @@ static bool checkDistributedTargetResultType( auto isCodableRequirement = checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); - - for(auto serializationReq : serializationRequirements) { - auto conformance = - TypeChecker::conformsToProtocol(resultType, serializationReq, module); - if (conformance.isInvalid()) { - if (diagnose) { - llvm::StringRef conformanceToSuggest = isCodableRequirement ? - "Codable" : // Codable is a typealias, easier to diagnose like that - serializationReq->getNameStr(); - - auto diag = valueDecl->diagnose( - diag::distributed_actor_target_result_not_codable, - resultType, - valueDecl, - conformanceToSuggest - ); - - if (isCodableRequirement) { - if (auto resultNominalType = resultType->getAnyNominal()) { - addCodableFixIt(resultNominalType, diag); + C, serializationRequirement); + + if (serializationRequirement && !serializationRequirement->hasError()) { + auto srl = serializationRequirement->getExistentialLayout(); + for (auto serializationReq: srl.getProtocols()) { + auto conformance = + TypeChecker::conformsToProtocol(resultType, serializationReq, module); + if (conformance.isInvalid()) { + if (diagnose) { + llvm::StringRef conformanceToSuggest = isCodableRequirement ? + "Codable" : // Codable is a typealias, easier to diagnose like that + serializationReq->getNameStr(); + + auto diag = valueDecl->diagnose( + diag::distributed_actor_target_result_not_codable, + resultType, + valueDecl, + conformanceToSuggest + ); + + if (isCodableRequirement) { + if (auto resultNominalType = resultType->getAnyNominal()) { + addCodableFixIt(resultNominalType, diag); + } } - } - } // end if: diagnose - - return true; + } // end if: diagnose + + return true; + } } } - return false; + return false; } bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) { @@ -503,74 +509,35 @@ bool CheckDistributedFunctionRequest::evaluate( if (!C.getLoadedModule(C.Id_Distributed)) return true; - // === All parameters and the result type must conform - // SerializationRequirement - llvm::SmallPtrSet serializationRequirements; - if (auto extension = dyn_cast(DC)) { - auto actorOrProtocol = extension->getExtendedNominal(); - if (auto actor = dyn_cast(actorOrProtocol)) { - assert(actor->isAnyActor()); - serializationRequirements = getDistributedSerializationRequirementProtocols( - getDistributedActorSystemType(actor)->getAnyNominal(), - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); - } else if (auto protocol = dyn_cast(actorOrProtocol)) { - extractDistributedSerializationRequirements( - C, protocol->getGenericRequirements(), - /*into=*/serializationRequirements); - extractDistributedSerializationRequirements( - C, extension->getGenericRequirements(), - /*into=*/serializationRequirements); - } else { - // ignore - } - } else if (auto actor = dyn_cast(DC)) { - serializationRequirements = getDistributedSerializationRequirementProtocols( - getDistributedActorSystemType(actor)->getAnyNominal(), - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); - } else if (isa(DC)) { - if (auto seqReqTy = - getConcreteReplacementForMemberSerializationRequirement(func)) { - auto layout = seqReqTy->getExistentialLayout(); - for (auto req : layout.getProtocols()) { - serializationRequirements.insert(req); - } - } - - // The distributed actor constrained protocol has no serialization requirements - // or actor system defined, so these will only be enforced, by implementations - // of DAs conforming to it, skip checks here. - if (serializationRequirements.empty()) { - return false; - } - } else { - llvm_unreachable("Distributed function detected in type other than extension, " - "distributed actor, or protocol! This should not be possible " - ", please file a bug."); - } - - // If the requirement is exactly `Codable` we diagnose it ia bit nicer. - auto serializationRequirementIsCodable = - checkDistributedSerializationRequirementIsExactlyCodable( - C, serializationRequirements); - - for (auto param : *func->getParameters()) { - // --- Check parameters for 'Codable' conformance - auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); - - for (auto req : serializationRequirements) { - if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { - auto diag = func->diagnose( - diag::distributed_actor_func_param_not_codable, - param->getArgumentName().str(), param->getInterfaceType(), - func->getDescriptiveKind(), - serializationRequirementIsCodable ? "Codable" - : req->getNameStr()); - - if (auto paramNominalTy = paramTy->getAnyNominal()) { - addCodableFixIt(paramNominalTy, diag); - } // else, no nominal type to suggest the fixit for, e.g. a closure - - return true; + Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func); + for (auto param: *func->getParameters()) { + + // --- Check the parameter conforming to serialization requirements + if (serializationReqType && !serializationReqType->hasError()) { + // If the requirement is exactly `Codable` we diagnose it ia bit nicer. + auto serializationRequirementIsCodable = + checkDistributedSerializationRequirementIsExactlyCodable( + C, serializationReqType); + + // --- Check parameters for 'SerializationRequirement' conformance + auto paramTy = func->mapTypeIntoContext(param->getInterfaceType()); + + auto srl = serializationReqType->getExistentialLayout(); + for (auto req: srl.getProtocols()) { + if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) { + auto diag = func->diagnose( + diag::distributed_actor_func_param_not_codable, + param->getArgumentName().str(), param->getInterfaceType(), + func->getDescriptiveKind(), + serializationRequirementIsCodable ? "Codable" + : req->getNameStr()); + + if (auto paramNominalTy = paramTy->getAnyNominal()) { + addCodableFixIt(paramNominalTy, diag); + } // else, no nominal type to suggest the fixit for, e.g. a closure + + return true; + } } } @@ -607,10 +574,12 @@ bool CheckDistributedFunctionRequest::evaluate( } } - // --- Result type must be either void or a codable type - if (checkDistributedTargetResultType(module, func, serializationRequirements, - /*diagnose=*/true)) { - return true; + if (serializationReqType && !serializationReqType->hasError()) { + // --- Result type must be either void or a codable type + if (checkDistributedTargetResultType(module, func, serializationReqType, + /*diagnose=*/true)) { + return true; + } } return false; @@ -658,13 +627,15 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) { DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty(); auto systemDecl = systemVar->getInterfaceType()->getAnyNominal(); - auto serializationRequirements = - getDistributedSerializationRequirementProtocols( - systemDecl, - C.getProtocol(KnownProtocolKind::DistributedActorSystem)); +// auto serializationRequirements = +// getDistributedSerializationRequirementProtocols( +// systemDecl, +// C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirement = + getConcreteReplacementForMemberSerializationRequirement(systemVar); auto module = var->getModuleContext(); - if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) { + if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) { return true; } @@ -771,6 +742,7 @@ bool TypeChecker::checkDistributedFunc(FuncDecl *func) { return swift::checkDistributedFunction(func); } +// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks llvm::SmallPtrSet swift::getDistributedSerializationRequirementProtocols( NominalTypeDecl *nominal, ProtocolDecl *protocol) { From 3d47f7042f26e967c81a6fcb3a94a3c29226f7a1 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Tue, 14 Nov 2023 19:10:36 +0900 Subject: [PATCH 08/11] handle conformance requirement on extension in distributed funcs --- include/swift/AST/DistributedDecl.h | 14 +-- lib/AST/DistributedDecl.cpp | 44 +++------ lib/Sema/TypeCheckDistributed.cpp | 93 +++++++++++-------- ...uted_func_serialization_requirements.swift | 2 +- 4 files changed, 70 insertions(+), 83 deletions(-) diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 76758d3de461e..018c5f22587cc 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor); /// Similar to `getDistributedSerializationRequirementType`, however, from the /// perspective of a concrete function. This way we're able to get the /// serialization requirement for specific members, also in protocols. -Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member); +Type getSerializationRequirementTypesForMember( + ValueDecl *member, llvm::SmallPtrSet &serializationRequirements); /// Get specific 'SerializationRequirement' as defined in 'nominal' /// type, which must conform to the passed 'protocol' which is expected @@ -114,17 +115,6 @@ getDistributedSerializationRequirements( ProtocolDecl *protocol, llvm::SmallPtrSetImpl &requirementProtos); -/// Given any set of generic requirements, locate those which are about the -/// `SerializationRequirement`. Those need to be applied in the parameter and -/// return type checking of distributed targets. -void -extractDistributedSerializationRequirements( - ASTContext &C, - ArrayRef allRequirements, - llvm::SmallPtrSet &into); - -} - // ==== ------------------------------------------------------------------------ #endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */ diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 1aae7b4bb33c0..274d41701d878 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member) llvm_unreachable("Unable to fetch ActorSystem type!"); } -Type swift::getConcreteReplacementForMemberSerializationRequirement( - ValueDecl *member) { +Type swift::getSerializationRequirementTypesForMember( + ValueDecl *member, + llvm::SmallPtrSet &serializationRequirements) { auto &C = member->getASTContext(); auto *DC = member->getDeclContext(); auto DA = C.getDistributedActorDecl(); @@ -117,6 +118,18 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement( signature = DC->getGenericSignatureOfContext(); } + // Also store all `SerializationRequirement : SomeProtocol` requirements + for (auto requirement: signature.getRequirements()) { + if (requirement.getFirstType()->isEqual(SerReqAssocType) && + requirement.getKind() == RequirementKind::Conformance) { + if (auto nominal = requirement.getSecondType()->getAnyNominal()) { + if (auto protocol = dyn_cast(nominal)) { + serializationRequirements.insert(protocol); + } + } + } + } + // Note that this may be null, e.g. if we're a distributed func inside // a protocol that did not declare a specific actor system requirement. return signature->getConcreteType(SerReqAssocType); @@ -1222,33 +1235,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const return true; } -void -swift::extractDistributedSerializationRequirements( - ASTContext &C, - ArrayRef allRequirements, - llvm::SmallPtrSet &into) { - auto DA = C.getDistributedActorDecl(); - auto daSerializationReqAssocType = - DA->getAssociatedType(C.Id_SerializationRequirement); - - for (auto req : allRequirements) { - // FIXME: Seems unprincipled - if (req.getKind() != RequirementKind::SameType && - req.getKind() != RequirementKind::Conformance) - continue; - - if (auto dependentMemberType = - req.getFirstType()->getAs()) { - if (dependentMemberType->getAssocType() == daSerializationReqAssocType) { - auto layout = req.getSecondType()->getExistentialLayout(); - for (auto p : layout.getProtocols()) { - serializationReqs.insert(p); - } - } - } - } -} - /******************************************************************************/ /********************** Distributed Functions *********************************/ /******************************************************************************/ diff --git a/lib/Sema/TypeCheckDistributed.cpp b/lib/Sema/TypeCheckDistributed.cpp index df17a42fa2275..e4b6624856510 100644 --- a/lib/Sema/TypeCheckDistributed.cpp +++ b/lib/Sema/TypeCheckDistributed.cpp @@ -386,11 +386,16 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements( static bool checkDistributedTargetResultType( ModuleDecl *module, ValueDecl *valueDecl, Type serializationRequirement, + llvm::SmallPtrSet serializationRequirements, bool diagnose) { auto &C = valueDecl->getASTContext(); - if (!serializationRequirement || serializationRequirement->hasError()) + if (serializationRequirement && serializationRequirement->hasError()) { + return false; + } + if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) { return false; // error of the type would be diagnosed elsewhere + } Type resultType; if (auto func = dyn_cast(valueDecl)) { @@ -404,37 +409,43 @@ static bool checkDistributedTargetResultType( if (resultType->isVoid()) return false; + + // Collect extra "SerializationRequirement: SomeProtocol" requirements + if (serializationRequirement && !serializationRequirement->hasError()) { + auto srl = serializationRequirement->getExistentialLayout(); + for (auto s: srl.getProtocols()) { + serializationRequirements.insert(s); + } + } + auto isCodableRequirement = checkDistributedSerializationRequirementIsExactlyCodable( C, serializationRequirement); - if (serializationRequirement && !serializationRequirement->hasError()) { - auto srl = serializationRequirement->getExistentialLayout(); - for (auto serializationReq: srl.getProtocols()) { - auto conformance = - TypeChecker::conformsToProtocol(resultType, serializationReq, module); - if (conformance.isInvalid()) { - if (diagnose) { - llvm::StringRef conformanceToSuggest = isCodableRequirement ? - "Codable" : // Codable is a typealias, easier to diagnose like that - serializationReq->getNameStr(); - - auto diag = valueDecl->diagnose( - diag::distributed_actor_target_result_not_codable, - resultType, - valueDecl, - conformanceToSuggest - ); - - if (isCodableRequirement) { - if (auto resultNominalType = resultType->getAnyNominal()) { - addCodableFixIt(resultNominalType, diag); - } + for (auto serializationReq: serializationRequirements) { + auto conformance = + TypeChecker::conformsToProtocol(resultType, serializationReq, module); + if (conformance.isInvalid()) { + if (diagnose) { + llvm::StringRef conformanceToSuggest = isCodableRequirement ? + "Codable" : // Codable is a typealias, easier to diagnose like that + serializationReq->getNameStr(); + + auto diag = valueDecl->diagnose( + diag::distributed_actor_target_result_not_codable, + resultType, + valueDecl, + conformanceToSuggest + ); + + if (isCodableRequirement) { + if (auto resultNominalType = resultType->getAnyNominal()) { + addCodableFixIt(resultNominalType, diag); } - } // end if: diagnose + } + } // end if: diagnose - return true; - } + return true; } } @@ -502,16 +513,16 @@ bool CheckDistributedFunctionRequest::evaluate( } auto &C = func->getASTContext(); - auto DC = func->getDeclContext(); auto module = func->getParentModule(); /// If no distributed module is available, then no reason to even try checks. if (!C.getLoadedModule(C.Id_Distributed)) return true; - Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement(func); - for (auto param: *func->getParameters()) { + llvm::SmallPtrSet serializationRequirements; + Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements); + for (auto param: *func->getParameters()) { // --- Check the parameter conforming to serialization requirements if (serializationReqType && !serializationReqType->hasError()) { // If the requirement is exactly `Codable` we diagnose it ia bit nicer. @@ -574,12 +585,11 @@ bool CheckDistributedFunctionRequest::evaluate( } } - if (serializationReqType && !serializationReqType->hasError()) { - // --- Result type must be either void or a codable type - if (checkDistributedTargetResultType(module, func, serializationReqType, - /*diagnose=*/true)) { - return true; - } + // --- Result type must be either void or a serialization requirement conforming type + if (checkDistributedTargetResultType( + module, func, serializationReqType, serializationRequirements, + /*diagnose=*/true)) { + return true; } return false; @@ -627,15 +637,16 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) { DC->getSelfNominalTypeDecl()->getDistributedActorSystemProperty(); auto systemDecl = systemVar->getInterfaceType()->getAnyNominal(); -// auto serializationRequirements = -// getDistributedSerializationRequirementProtocols( -// systemDecl, -// C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirements = + getDistributedSerializationRequirementProtocols( + systemDecl, + C.getProtocol(KnownProtocolKind::DistributedActorSystem)); + auto serializationRequirement = - getConcreteReplacementForMemberSerializationRequirement(systemVar); + getSerializationRequirementTypesForMember(systemVar, serializationRequirements); auto module = var->getModuleContext(); - if (checkDistributedTargetResultType(module, var, serializationRequirement, diagnose)) { + if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) { return true; } diff --git a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift index af4fa1020ce58..eace8fc19ea16 100644 --- a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift +++ b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift @@ -82,7 +82,7 @@ extension NoSerializationRequirementYet extension NoSerializationRequirementYet where SerializationRequirement: Codable { - // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}} + // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Decodable'}} distributed func test4() -> NotCodable { .init() } From 332e20fb8a7908fa1f6d6c995545f5207a6c1fa1 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Tue, 14 Nov 2023 19:10:36 +0900 Subject: [PATCH 09/11] handle conformance requirement on extension in distributed funcs --- include/swift/AST/ASTSynthesis.h | 13 ++++++++++--- include/swift/AST/DistributedDecl.h | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/include/swift/AST/ASTSynthesis.h b/include/swift/AST/ASTSynthesis.h index e9a8f1c0647b3..350ea58071c02 100644 --- a/include/swift/AST/ASTSynthesis.h +++ b/include/swift/AST/ASTSynthesis.h @@ -41,7 +41,7 @@ enum SingletonTypeSynthesizer { _any, _bridgeObject, _error, - _executor, + _executor, // the 'BuiltinExecutor' type _job, _nativeObject, _never, @@ -49,7 +49,8 @@ enum SingletonTypeSynthesizer { _rawUnsafeContinuation, _void, _word, - _serialExecutor, + _executorProtocol, // the '_Concurrency.Executor' protocol + _serialExecutor, // the '_Concurrency.SerialExecutor' protocol }; inline Type synthesizeType(SynthesisContext &SC, SingletonTypeSynthesizer kind) { @@ -66,9 +67,15 @@ inline Type synthesizeType(SynthesisContext &SC, case _void: return SC.Context.TheEmptyTupleType; case _word: return BuiltinIntegerType::get(BuiltinIntegerWidth::pointer(), SC.Context); - case _serialExecutor: + case _executorProtocol: + return SC.Context.getProtocol(KnownProtocolKind::Executor) + ->getDeclaredInterfaceType(); + case _serialExecutor: return SC.Context.getProtocol(KnownProtocolKind::SerialExecutor) ->getDeclaredInterfaceType(); + case _taskExecutor: + return SC.Context.getProtocol(KnownProtocolKind::TaskExecutor) + ->getDeclaredInterfaceType(); } } diff --git a/include/swift/AST/DistributedDecl.h b/include/swift/AST/DistributedDecl.h index 018c5f22587cc..abd0d72a2c5a5 100644 --- a/include/swift/AST/DistributedDecl.h +++ b/include/swift/AST/DistributedDecl.h @@ -115,6 +115,6 @@ getDistributedSerializationRequirements( ProtocolDecl *protocol, llvm::SmallPtrSetImpl &requirementProtos); -// ==== ------------------------------------------------------------------------ +} #endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */ From 7cdeda77bc8bdb2e372144606a5db668c2e3f0a2 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Wed, 15 Nov 2023 08:16:54 +0900 Subject: [PATCH 10/11] [Distributed] Another fix for getting required protocols for SR --- include/swift/AST/ASTSynthesis.h | 13 +++---------- lib/AST/DistributedDecl.cpp | 11 ++--------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/include/swift/AST/ASTSynthesis.h b/include/swift/AST/ASTSynthesis.h index 350ea58071c02..e9a8f1c0647b3 100644 --- a/include/swift/AST/ASTSynthesis.h +++ b/include/swift/AST/ASTSynthesis.h @@ -41,7 +41,7 @@ enum SingletonTypeSynthesizer { _any, _bridgeObject, _error, - _executor, // the 'BuiltinExecutor' type + _executor, _job, _nativeObject, _never, @@ -49,8 +49,7 @@ enum SingletonTypeSynthesizer { _rawUnsafeContinuation, _void, _word, - _executorProtocol, // the '_Concurrency.Executor' protocol - _serialExecutor, // the '_Concurrency.SerialExecutor' protocol + _serialExecutor, }; inline Type synthesizeType(SynthesisContext &SC, SingletonTypeSynthesizer kind) { @@ -67,15 +66,9 @@ inline Type synthesizeType(SynthesisContext &SC, case _void: return SC.Context.TheEmptyTupleType; case _word: return BuiltinIntegerType::get(BuiltinIntegerWidth::pointer(), SC.Context); - case _executorProtocol: - return SC.Context.getProtocol(KnownProtocolKind::Executor) - ->getDeclaredInterfaceType(); - case _serialExecutor: + case _serialExecutor: return SC.Context.getProtocol(KnownProtocolKind::SerialExecutor) ->getDeclaredInterfaceType(); - case _taskExecutor: - return SC.Context.getProtocol(KnownProtocolKind::TaskExecutor) - ->getDeclaredInterfaceType(); } } diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 274d41701d878..435eae745866c 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -119,15 +119,8 @@ Type swift::getSerializationRequirementTypesForMember( } // Also store all `SerializationRequirement : SomeProtocol` requirements - for (auto requirement: signature.getRequirements()) { - if (requirement.getFirstType()->isEqual(SerReqAssocType) && - requirement.getKind() == RequirementKind::Conformance) { - if (auto nominal = requirement.getSecondType()->getAnyNominal()) { - if (auto protocol = dyn_cast(nominal)) { - serializationRequirements.insert(protocol); - } - } - } + for (auto proto: signature->getRequiredProtocols(SerReqAssocType)) { + serializationRequirements.insert(proto); } // Note that this may be null, e.g. if we're a distributed func inside From af211dd9ab07bd372ecf66bf26b03809d97d7725 Mon Sep 17 00:00:00 2001 From: Konrad `ktoso` Malawski Date: Thu, 16 Nov 2023 11:48:17 +0900 Subject: [PATCH 11/11] [Distributed] Remove redundant isa check in getting SR --- lib/AST/DistributedDecl.cpp | 2 +- ...ocols_distributed_func_serialization_requirements.swift | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/AST/DistributedDecl.cpp b/lib/AST/DistributedDecl.cpp index 435eae745866c..7d3c457e148a2 100644 --- a/lib/AST/DistributedDecl.cpp +++ b/lib/AST/DistributedDecl.cpp @@ -110,7 +110,7 @@ Type swift::getSerializationRequirementTypesForMember( auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement) ->getDeclaredInterfaceType(); - if (DC->getSelfProtocolDecl() || isa(DC)) { + if (DC->getSelfProtocolDecl()) { GenericSignature signature; if (auto *genericContext = member->getAsGenericContext()) { signature = genericContext->getGenericSignature(); diff --git a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift index eace8fc19ea16..ee3ce1d570797 100644 --- a/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift +++ b/test/Distributed/distributed_protocols_distributed_func_serialization_requirements.swift @@ -88,6 +88,13 @@ extension NoSerializationRequirementYet } } +extension ProtocolWithChecksSeqReqDA { + // expected-error@+1{{result type 'NotCodable' of distributed instance method 'test4' does not conform to serialization requirement 'Codable'}} + distributed func test4() -> NotCodable { + .init() + } +} + // FIXME(distributed): remove the -verify-ignore-unknown // :0: error: unexpected error produced: instance method 'recordReturnType' requires that 'NotCodable' conform to 'Decodable' // :0: error: unexpected error produced: instance method 'recordReturnType' requires that 'NotCodable' conform to 'Encodable'