@@ -377,11 +377,16 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
377377static bool checkDistributedTargetResultType (
378378 ModuleDecl *module , ValueDecl *valueDecl,
379379 Type serializationRequirement,
380+ llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements,
380381 bool diagnose) {
381382 auto &C = valueDecl->getASTContext ();
382383
383- if (!serializationRequirement || serializationRequirement->hasError ())
384+ if (serializationRequirement && serializationRequirement->hasError ()) {
385+ return false ;
386+ }
387+ if ((!serializationRequirement || serializationRequirement->hasError ()) && serializationRequirements.empty ()) {
384388 return false ; // error of the type would be diagnosed elsewhere
389+ }
385390
386391 Type resultType;
387392 if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
@@ -395,37 +400,43 @@ static bool checkDistributedTargetResultType(
395400 if (resultType->isVoid ())
396401 return false ;
397402
403+
404+ // Collect extra "SerializationRequirement: SomeProtocol" requirements
405+ if (serializationRequirement && !serializationRequirement->hasError ()) {
406+ auto srl = serializationRequirement->getExistentialLayout ();
407+ for (auto s: srl.getProtocols ()) {
408+ serializationRequirements.insert (s);
409+ }
410+ }
411+
398412 auto isCodableRequirement =
399413 checkDistributedSerializationRequirementIsExactlyCodable (
400414 C, serializationRequirement);
401415
402- if (serializationRequirement && !serializationRequirement->hasError ()) {
403- auto srl = serializationRequirement->getExistentialLayout ();
404- for (auto serializationReq: srl.getProtocols ()) {
405- auto conformance =
406- TypeChecker::conformsToProtocol (resultType, serializationReq, module );
407- if (conformance.isInvalid ()) {
408- if (diagnose) {
409- llvm::StringRef conformanceToSuggest = isCodableRequirement ?
410- " Codable" : // Codable is a typealias, easier to diagnose like that
411- serializationReq->getNameStr ();
412-
413- auto diag = valueDecl->diagnose (
414- diag::distributed_actor_target_result_not_codable,
415- resultType,
416- valueDecl,
417- conformanceToSuggest
418- );
419-
420- if (isCodableRequirement) {
421- if (auto resultNominalType = resultType->getAnyNominal ()) {
422- addCodableFixIt (resultNominalType, diag);
423- }
416+ for (auto serializationReq: serializationRequirements) {
417+ auto conformance =
418+ TypeChecker::conformsToProtocol (resultType, serializationReq, module );
419+ if (conformance.isInvalid ()) {
420+ if (diagnose) {
421+ llvm::StringRef conformanceToSuggest = isCodableRequirement ?
422+ " Codable" : // Codable is a typealias, easier to diagnose like that
423+ serializationReq->getNameStr ();
424+
425+ auto diag = valueDecl->diagnose (
426+ diag::distributed_actor_target_result_not_codable,
427+ resultType,
428+ valueDecl,
429+ conformanceToSuggest
430+ );
431+
432+ if (isCodableRequirement) {
433+ if (auto resultNominalType = resultType->getAnyNominal ()) {
434+ addCodableFixIt (resultNominalType, diag);
424435 }
425- } // end if: diagnose
436+ }
437+ } // end if: diagnose
426438
427- return true ;
428- }
439+ return true ;
429440 }
430441 }
431442
@@ -493,16 +504,16 @@ bool CheckDistributedFunctionRequest::evaluate(
493504 }
494505
495506 auto &C = func->getASTContext ();
496- auto DC = func->getDeclContext ();
497507 auto module = func->getParentModule ();
498508
499509 // / If no distributed module is available, then no reason to even try checks.
500510 if (!C.getLoadedModule (C.Id_Distributed ))
501511 return true ;
502512
503- Type serializationReqType = getConcreteReplacementForMemberSerializationRequirement (func) ;
504- for ( auto param: *func-> getParameters ()) {
513+ llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements ;
514+ Type serializationReqType = getSerializationRequirementTypesForMember (func, serializationRequirements);
505515
516+ for (auto param: *func->getParameters ()) {
506517 // --- Check the parameter conforming to serialization requirements
507518 if (serializationReqType && !serializationReqType->hasError ()) {
508519 // If the requirement is exactly `Codable` we diagnose it ia bit nicer.
@@ -565,12 +576,11 @@ bool CheckDistributedFunctionRequest::evaluate(
565576 }
566577 }
567578
568- if (serializationReqType && !serializationReqType->hasError ()) {
569- // --- Result type must be either void or a codable type
570- if (checkDistributedTargetResultType (module , func, serializationReqType,
571- /* diagnose=*/ true )) {
572- return true ;
573- }
579+ // --- Result type must be either void or a serialization requirement conforming type
580+ if (checkDistributedTargetResultType (
581+ module , func, serializationReqType, serializationRequirements,
582+ /* diagnose=*/ true )) {
583+ return true ;
574584 }
575585
576586 return false ;
@@ -618,15 +628,16 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
618628 DC->getSelfNominalTypeDecl ()->getDistributedActorSystemProperty ();
619629 auto systemDecl = systemVar->getInterfaceType ()->getAnyNominal ();
620630
621- // auto serializationRequirements =
622- // getDistributedSerializationRequirementProtocols(
623- // systemDecl,
624- // C.getProtocol(KnownProtocolKind::DistributedActorSystem));
631+ auto serializationRequirements =
632+ getDistributedSerializationRequirementProtocols (
633+ systemDecl,
634+ C.getProtocol (KnownProtocolKind::DistributedActorSystem));
635+
625636 auto serializationRequirement =
626- getConcreteReplacementForMemberSerializationRequirement (systemVar);
637+ getSerializationRequirementTypesForMember (systemVar, serializationRequirements );
627638
628639 auto module = var->getModuleContext ();
629- if (checkDistributedTargetResultType (module , var, serializationRequirement, diagnose)) {
640+ if (checkDistributedTargetResultType (module , var, serializationRequirement, serializationRequirements, diagnose)) {
630641 return true ;
631642 }
632643
0 commit comments